program.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. import yaml
  20. import time
  21. import shutil
  22. import paddle
  23. import paddle.distributed as dist
  24. from tqdm import tqdm
  25. from argparse import ArgumentParser, RawDescriptionHelpFormatter
  26. from ppocr.utils.stats import TrainingStats
  27. from ppocr.utils.save_load import save_model
  28. from ppocr.utils.utility import print_dict
  29. from ppocr.utils.logging import get_logger
  30. from ppocr.data import build_dataloader
  31. import numpy as np
  32. class ArgsParser(ArgumentParser):
  33. def __init__(self):
  34. super(ArgsParser, self).__init__(
  35. formatter_class=RawDescriptionHelpFormatter)
  36. self.add_argument("-c", "--config", help="configuration file to use")
  37. self.add_argument(
  38. "-o", "--opt", nargs='+', help="set configuration options")
  39. def parse_args(self, argv=None):
  40. args = super(ArgsParser, self).parse_args(argv)
  41. assert args.config is not None, \
  42. "Please specify --config=configure_file_path."
  43. args.opt = self._parse_opt(args.opt)
  44. return args
  45. def _parse_opt(self, opts):
  46. config = {}
  47. if not opts:
  48. return config
  49. for s in opts:
  50. s = s.strip()
  51. k, v = s.split('=')
  52. config[k] = yaml.load(v, Loader=yaml.Loader)
  53. return config
  54. class AttrDict(dict):
  55. """Single level attribute dict, NOT recursive"""
  56. def __init__(self, **kwargs):
  57. super(AttrDict, self).__init__()
  58. super(AttrDict, self).update(kwargs)
  59. def __getattr__(self, key):
  60. if key in self:
  61. return self[key]
  62. raise AttributeError("object has no attribute '{}'".format(key))
  63. global_config = AttrDict()
  64. default_config = {'Global': {'debug': False, }}
  65. def load_config(file_path):
  66. """
  67. Load config from yml/yaml file.
  68. Args:
  69. file_path (str): Path of the config file to be loaded.
  70. Returns: global config
  71. """
  72. merge_config(default_config)
  73. _, ext = os.path.splitext(file_path)
  74. assert ext in ['.yml', '.yaml'], "only support yaml files for now"
  75. merge_config(yaml.load(open(file_path, 'rb'), Loader=yaml.Loader))
  76. return global_config
  77. def merge_config(config):
  78. """
  79. Merge config into global config.
  80. Args:
  81. config (dict): Config to be merged.
  82. Returns: global config
  83. """
  84. for key, value in config.items():
  85. if "." not in key:
  86. if isinstance(value, dict) and key in global_config:
  87. global_config[key].update(value)
  88. else:
  89. global_config[key] = value
  90. else:
  91. sub_keys = key.split('.')
  92. assert (
  93. sub_keys[0] in global_config
  94. ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
  95. global_config.keys(), sub_keys[0])
  96. cur = global_config[sub_keys[0]]
  97. for idx, sub_key in enumerate(sub_keys[1:]):
  98. if idx == len(sub_keys) - 2:
  99. cur[sub_key] = value
  100. else:
  101. cur = cur[sub_key]
  102. def check_gpu(use_gpu):
  103. """
  104. Log error and exit when set use_gpu=true in paddlepaddle
  105. cpu version.
  106. """
  107. err = "Config use_gpu cannot be set as true while you are " \
  108. "using paddlepaddle cpu version ! \nPlease try: \n" \
  109. "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
  110. "\t2. Set use_gpu as false in config file to run " \
  111. "model on CPU"
  112. try:
  113. if use_gpu and not paddle.is_compiled_with_cuda():
  114. print(err)
  115. sys.exit(1)
  116. except Exception as e:
  117. pass
  118. def train(config,
  119. train_dataloader,
  120. valid_dataloader,
  121. device,
  122. model,
  123. loss_class,
  124. optimizer,
  125. lr_scheduler,
  126. post_process_class,
  127. eval_class,
  128. pre_best_model_dict,
  129. logger,
  130. vdl_writer=None):
  131. cal_metric_during_train = config['Global'].get('cal_metric_during_train',
  132. False)
  133. log_smooth_window = config['Global']['log_smooth_window']
  134. epoch_num = config['Global']['epoch_num']
  135. print_batch_step = config['Global']['print_batch_step']
  136. eval_batch_step = config['Global']['eval_batch_step']
  137. global_step = 0
  138. start_eval_step = 0
  139. if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
  140. start_eval_step = eval_batch_step[0]
  141. eval_batch_step = eval_batch_step[1]
  142. if len(valid_dataloader) == 0:
  143. logger.info(
  144. 'No Images in eval dataset, evaluation during training will be disabled'
  145. )
  146. start_eval_step = 1e111
  147. logger.info(
  148. "During the training process, after the {}th iteration, an evaluation is run every {} iterations".
  149. format(start_eval_step, eval_batch_step))
  150. save_epoch_step = config['Global']['save_epoch_step']
  151. save_model_dir = config['Global']['save_model_dir']
  152. if not os.path.exists(save_model_dir):
  153. os.makedirs(save_model_dir)
  154. main_indicator = eval_class.main_indicator
  155. best_model_dict = {main_indicator: 0}
  156. best_model_dict.update(pre_best_model_dict)
  157. train_stats = TrainingStats(log_smooth_window, ['lr'])
  158. model_average = False
  159. model.train()
  160. use_srn = config['Architecture']['algorithm'] == "SRN"
  161. if 'start_epoch' in best_model_dict:
  162. start_epoch = best_model_dict['start_epoch']
  163. else:
  164. start_epoch = 1
  165. for epoch in range(start_epoch, epoch_num + 1):
  166. train_dataloader = build_dataloader(
  167. config, 'Train', device, logger, seed=epoch)
  168. train_batch_cost = 0.0
  169. train_reader_cost = 0.0
  170. batch_sum = 0
  171. batch_start = time.time()
  172. for idx, batch in enumerate(train_dataloader):
  173. train_reader_cost += time.time() - batch_start
  174. if idx >= len(train_dataloader):
  175. break
  176. lr = optimizer.get_lr()
  177. images = batch[0]
  178. if use_srn:
  179. others = batch[-4:]
  180. preds = model(images, others)
  181. model_average = True
  182. else:
  183. preds = model(images)
  184. loss = loss_class(preds, batch)
  185. avg_loss = loss['loss']
  186. avg_loss.backward()
  187. optimizer.step()
  188. optimizer.clear_grad()
  189. train_batch_cost += time.time() - batch_start
  190. batch_sum += len(images)
  191. if not isinstance(lr_scheduler, float):
  192. lr_scheduler.step()
  193. # logger and visualdl
  194. stats = {k: v.numpy().mean() for k, v in loss.items()}
  195. stats['lr'] = lr
  196. train_stats.update(stats)
  197. if cal_metric_during_train: # only rec and cls need
  198. batch = [item.numpy() for item in batch]
  199. post_result = post_process_class(preds, batch[1])
  200. eval_class(post_result, batch)
  201. metric = eval_class.get_metric()
  202. train_stats.update(metric)
  203. if vdl_writer is not None and dist.get_rank() == 0:
  204. for k, v in train_stats.get().items():
  205. vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
  206. vdl_writer.add_scalar('TRAIN/lr', lr, global_step)
  207. if dist.get_rank(
  208. ) == 0 and global_step > 0 and global_step % print_batch_step == 0:
  209. logs = train_stats.log()
  210. strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
  211. epoch, epoch_num, global_step, logs, train_reader_cost /
  212. print_batch_step, train_batch_cost / print_batch_step,
  213. batch_sum, batch_sum / train_batch_cost)
  214. logger.info(strs)
  215. train_batch_cost = 0.0
  216. train_reader_cost = 0.0
  217. batch_sum = 0
  218. # eval
  219. if global_step > start_eval_step and \
  220. (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
  221. if model_average:
  222. Model_Average = paddle.incubate.optimizer.ModelAverage(
  223. 0.15,
  224. parameters=model.parameters(),
  225. min_average_window=10000,
  226. max_average_window=15625)
  227. Model_Average.apply()
  228. cur_metric = eval(
  229. model,
  230. valid_dataloader,
  231. post_process_class,
  232. eval_class,
  233. use_srn=use_srn)
  234. cur_metric_str = 'cur metric, {}'.format(', '.join(
  235. ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
  236. logger.info(cur_metric_str)
  237. # logger metric
  238. if vdl_writer is not None:
  239. for k, v in cur_metric.items():
  240. if isinstance(v, (float, int)):
  241. vdl_writer.add_scalar('EVAL/{}'.format(k),
  242. cur_metric[k], global_step)
  243. if cur_metric[main_indicator] >= best_model_dict[
  244. main_indicator]:
  245. best_model_dict.update(cur_metric)
  246. best_model_dict['best_epoch'] = epoch
  247. save_model(
  248. model,
  249. optimizer,
  250. save_model_dir,
  251. logger,
  252. is_best=True,
  253. prefix='best_accuracy',
  254. best_model_dict=best_model_dict,
  255. epoch=epoch)
  256. best_str = 'best metric, {}'.format(', '.join([
  257. '{}: {}'.format(k, v) for k, v in best_model_dict.items()
  258. ]))
  259. logger.info(best_str)
  260. # logger best metric
  261. if vdl_writer is not None:
  262. vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
  263. best_model_dict[main_indicator],
  264. global_step)
  265. global_step += 1
  266. optimizer.clear_grad()
  267. batch_start = time.time()
  268. if dist.get_rank() == 0:
  269. save_model(
  270. model,
  271. optimizer,
  272. save_model_dir,
  273. logger,
  274. is_best=False,
  275. prefix='latest',
  276. best_model_dict=best_model_dict,
  277. epoch=epoch)
  278. if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
  279. save_model(
  280. model,
  281. optimizer,
  282. save_model_dir,
  283. logger,
  284. is_best=False,
  285. prefix='iter_epoch_{}'.format(epoch),
  286. best_model_dict=best_model_dict,
  287. epoch=epoch)
  288. best_str = 'best metric, {}'.format(', '.join(
  289. ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
  290. logger.info(best_str)
  291. if dist.get_rank() == 0 and vdl_writer is not None:
  292. vdl_writer.close()
  293. return
  294. def eval(model, valid_dataloader, post_process_class, eval_class,
  295. use_srn=False):
  296. model.eval()
  297. with paddle.no_grad():
  298. total_frame = 0.0
  299. total_time = 0.0
  300. pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
  301. for idx, batch in enumerate(valid_dataloader):
  302. if idx >= len(valid_dataloader):
  303. break
  304. images = batch[0]
  305. start = time.time()
  306. if use_srn:
  307. others = batch[-4:]
  308. preds = model(images, others)
  309. else:
  310. preds = model(images)
  311. batch = [item.numpy() for item in batch]
  312. # Obtain usable results from post-processing methods
  313. post_result = post_process_class(preds, batch[1])
  314. total_time += time.time() - start
  315. # Evaluate the results of the current batch
  316. eval_class(post_result, batch)
  317. pbar.update(1)
  318. total_frame += len(images)
  319. # Get final metric,eg. acc or hmean
  320. metric = eval_class.get_metric()
  321. pbar.close()
  322. model.train()
  323. metric['fps'] = total_frame / total_time
  324. return metric
  325. def preprocess(is_train=False):
  326. FLAGS = ArgsParser().parse_args()
  327. config = load_config(FLAGS.config)
  328. merge_config(FLAGS.opt)
  329. # check if set use_gpu=True in paddlepaddle cpu version
  330. use_gpu = config['Global']['use_gpu']
  331. check_gpu(use_gpu)
  332. alg = config['Architecture']['algorithm']
  333. assert alg in [
  334. 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
  335. ]
  336. device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
  337. device = paddle.set_device(device)
  338. config['Global']['distributed'] = dist.get_world_size() != 1
  339. if is_train:
  340. # save_config
  341. save_model_dir = config['Global']['save_model_dir']
  342. os.makedirs(save_model_dir, exist_ok=True)
  343. with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
  344. yaml.dump(
  345. dict(config), f, default_flow_style=False, sort_keys=False)
  346. log_file = '{}/train.log'.format(save_model_dir)
  347. else:
  348. log_file = None
  349. logger = get_logger(name='root', log_file=log_file)
  350. if config['Global']['use_visualdl']:
  351. from visualdl import LogWriter
  352. save_model_dir = config['Global']['save_model_dir']
  353. vdl_writer_path = '{}/vdl/'.format(save_model_dir)
  354. os.makedirs(vdl_writer_path, exist_ok=True)
  355. vdl_writer = LogWriter(logdir=vdl_writer_path)
  356. else:
  357. vdl_writer = None
  358. print_dict(config, logger)
  359. logger.info('train with paddle {} and device {}'.format(paddle.__version__,
  360. device))
  361. return config, device, logger, vdl_writer