 import os.path as osp
 
+import torch.distributed as dist
 from mmcv.runner import DistEvalHook as _DistEvalHook
 from mmcv.runner import EvalHook as _EvalHook
+from torch.nn.modules.batchnorm import _BatchNorm
 
 
 class EvalHook(_EvalHook):
@@ -23,33 +25,17 @@ def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
         super().__init__(*args, by_epoch=by_epoch, **kwargs)
         self.efficient_test = efficient_test
 
-    def after_train_iter(self, runner):
-        """After train epoch hook.
-
-        Override default ``single_gpu_test``.
-        """
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+    def _do_evaluate(self, runner):
+        """Perform evaluation and save ckpt."""
+        if not self._should_evaluate(runner):
             return
-        from mmseg.apis import single_gpu_test
-        runner.log_buffer.clear()
-        results = single_gpu_test(
-            runner.model,
-            self.dataloader,
-            show=False,
-            efficient_test=self.efficient_test)
-        self.evaluate(runner, results)
 
-    def after_train_epoch(self, runner):
-        """After train epoch hook.
-
-        Override default ``single_gpu_test``.
-        """
-        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
-            return
         from mmseg.apis import single_gpu_test
-        runner.log_buffer.clear()
         results = single_gpu_test(runner.model, self.dataloader, show=False)
-        self.evaluate(runner, results)
+        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+        key_score = self.evaluate(runner, results)
+        if self.save_best:
+            self._save_ckpt(runner, key_score)
 
 
 class DistEvalHook(_DistEvalHook):
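
Both hooks in this commit funnel everything through a single `_do_evaluate`, which works because the mmcv base classes dispatch the per-iteration and per-epoch callbacks according to `by_epoch`. A minimal sketch of that assumed dispatch (illustrative only, not the actual mmcv source):

class EvalHookBaseSketch:
    """Hypothetical stand-in for mmcv.runner.EvalHook's dispatch."""

    def __init__(self, by_epoch=False):
        self.by_epoch = by_epoch

    def after_train_iter(self, runner):
        # Iteration-based training: delegate to the single entry point.
        if not self.by_epoch:
            self._do_evaluate(runner)

    def after_train_epoch(self, runner):
        # Epoch-based training: same entry point, so subclasses (as in
        # the diff above) only need to override _do_evaluate.
        if self.by_epoch:
            self._do_evaluate(runner)

    def _do_evaluate(self, runner):
        raise NotImplementedError
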
@@ -71,39 +57,38 @@ def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
         super().__init__(*args, by_epoch=by_epoch, **kwargs)
         self.efficient_test = efficient_test
 
-    def after_train_iter(self, runner):
-        """After train epoch hook.
-
-        Override default ``multi_gpu_test``.
-        """
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+    def _do_evaluate(self, runner):
+        """Perform evaluation and save ckpt."""
+        # Synchronization of BatchNorm's buffers (running_mean and
+        # running_var) is not supported in PyTorch's DDP, which can
+        # leave the running statistics inconsistent across ranks, so
+        # we broadcast rank 0's BatchNorm buffers to the other ranks
+        # before evaluating.
+        if self.broadcast_bn_buffer:
+            model = runner.model
+            for name, module in model.named_modules():
+                if isinstance(module,
+                              _BatchNorm) and module.track_running_stats:
+                    dist.broadcast(module.running_var, 0)
+                    dist.broadcast(module.running_mean, 0)
+
+        if not self._should_evaluate(runner):
             return
-        from mmseg.apis import multi_gpu_test
-        runner.log_buffer.clear()
-        results = multi_gpu_test(
-            runner.model,
-            self.dataloader,
-            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
-            gpu_collect=self.gpu_collect,
-            efficient_test=self.efficient_test)
-        if runner.rank == 0:
-            print('\n')
-            self.evaluate(runner, results)
 
-    def after_train_epoch(self, runner):
-        """After train epoch hook.
+        tmpdir = self.tmpdir
+        if tmpdir is None:
+            tmpdir = osp.join(runner.work_dir, '.eval_hook')
 
-        Override default ``multi_gpu_test``.
-        """
-        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
-            return
         from mmseg.apis import multi_gpu_test
-        runner.log_buffer.clear()
         results = multi_gpu_test(
             runner.model,
             self.dataloader,
-            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+            tmpdir=tmpdir,
             gpu_collect=self.gpu_collect)
         if runner.rank == 0:
             print('\n')
-            self.evaluate(runner, results)
+            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+            key_score = self.evaluate(runner, results)
+
+            if self.save_best:
+                self._save_ckpt(runner, key_score)
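
The BatchNorm handling added in `DistEvalHook._do_evaluate` can also be read in isolation. A self-contained sketch, assuming an already-initialized `torch.distributed` process group (setup omitted):

import torch.distributed as dist
from torch.nn.modules.batchnorm import _BatchNorm


def broadcast_bn_buffers(model, src_rank=0):
    """Copy rank ``src_rank``'s BatchNorm running stats to all ranks.

    DDP synchronizes gradients but not BatchNorm buffers, so the ranks
    can drift apart; broadcasting before evaluation keeps every rank's
    results consistent.
    """
    for _, module in model.named_modules():
        if isinstance(module, _BatchNorm) and module.track_running_stats:
            dist.broadcast(module.running_var, src_rank)
            dist.broadcast(module.running_mean, src_rank)
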
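Finally, a hypothetical registration sketch showing where the new `save_best` branch fires; `val_dataloader`, the interval value, and the runner setup are placeholders rather than part of this change (mmseg normally builds this hook from the `evaluation` config field):

eval_hook = DistEvalHook(
    val_dataloader,      # placeholder DataLoader over the val split
    interval=4000,       # evaluate every 4000 iterations
    by_epoch=False,
    save_best='mIoU',    # enables the _save_ckpt branch above
    gpu_collect=True)
runner.register_hook(eval_hook, priority='LOW')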