从 tools/train.py
开始:
mmdet.apis.train_detector
,传入刚才 build 好的 model,datasets,配置参数等。进入 mmdet.apis.train_detector
:
MMDistributedDataParallel
封装 model,单 GPU 训练则 MMDataParallel
;runner.run()
,训练从此处开始。runner 是 EpochBasedRunner
类的实例,进入 EpochBasedRunner
类的定义,可以看到最主要的是 run 方法:
def run(self, data_loaders, workflow, max_epochs, **kwargs):
#...
while self.epoch < max_epochs:
for i, flow in enumerate(workflow):
mode, epochs = flow
if isinstance(mode, str): # self.train()
if not hasattr(self, mode):
raise ValueError(
f‘runner has no method named "{mode}" to run an ‘
‘epoch‘)
epoch_runner = getattr(self, mode)
else:
raise TypeError(
‘mode in workflow must be a str, but got {}‘.format(
type(mode)))
for _ in range(epochs):
if mode == ‘train‘ and self.epoch >= max_epochs:
break
epoch_runner(data_loaders[i], **kwargs)
workflow
变量的注释:
workflow (list[tuple]): A list of (phase, epochs) to specify the running order and epochs. E.g, [(‘train‘, 2), (‘val‘, 1)] means
running 2 epochs for training and 1 epoch for validation,
最后 4 行是重点,根据每个 workflow 的 mode 和 epochs 调用 epochs 次相应的函数,比如:
for _ in range(epochs):
if mode == ‘train‘ and self.epoch >= max_epochs:
break
#when mode == ‘train‘
# `epoch_runner(data_loaders[i], **kwargs)` ==
self.train(data_loaders[i], **kwargs)
一个 epoch 相当于遍历一遍数据集的所有数据。
接下来看看 EpochBasedRunner.train() :
before_train_epoch
, before_train_iter
, after_train_iter
, after_train_epoch
。执行反向传播是在 after_train_iter
处。 (先不纠结 hook 是个啥)解析一下配置文件中出现的 TensorboardLoggerHook
EpochBasedRunner.register_hook()
EpochBasedRunner.call_hook(fn_name)
fn_name
,调用所有 hook 里的相应函数,将 runner 作为参数传进去。TensorboardLoggerHook
作用:每次 iter 或 epoch 完记录训练结果到 tensorboard (即写到 summary 文件里)
TensorboardLoggerHook.after_train_iter(runner)
判断是否达到 interval,比如在配置文件中指定了每 50 个 iter 才 log 训练结果,如果达到 50 个 iter,则对 50 个 iter 的结果求平均,再调用自己的 log 函数。 50 个 iter 的结果存放在 runner.log_buffer 里。
TensorboardLoggerHook.log(runner)
将 runner.log_buffer 里的结果值,添加到 summary_writer 里。
除了 Logger 这种形式的 hook 之外,还有其他一些功能也以 hook 的形式实现,比如 optimizer 对应的 OptimizerHook
,或者 training 过程中的 eval 也是通过 EvaluationHook
调用。
原文:https://www.cnblogs.com/notesbyY/p/13475892.html