@@ -1,7 +1,9 @@
 import os.path as osp
 
+import torch.distributed as dist
 from mmcv.runner import DistEvalHook as _DistEvalHook
 from mmcv.runner import EvalHook as _EvalHook
+from torch.nn.modules.batchnorm import _BatchNorm
 
 
 class EvalHook(_EvalHook):
@@ -23,33 +25,17 @@ def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
         super().__init__(*args, by_epoch=by_epoch, **kwargs)
         self.efficient_test = efficient_test
 
-    def after_train_iter(self, runner):
-        """After train epoch hook.
-
-        Override default ``single_gpu_test``.
-        """
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        if not self._should_evaluate(runner):
             return
-        from mmseg.apis import single_gpu_test
-        runner.log_buffer.clear()
-        results = single_gpu_test(
-            runner.model,
-            self.dataloader,
-            show=False,
-            efficient_test=self.efficient_test)
-        self.evaluate(runner, results)
 
-    def after_train_epoch(self, runner):
-        """After train epoch hook.
-
-        Override default ``single_gpu_test``.
-        """
-        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
-            return
         from mmseg.apis import single_gpu_test
-        runner.log_buffer.clear()
         results = single_gpu_test(runner.model, self.dataloader, show=False)
-        self.evaluate(runner, results)
+        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+        key_score = self.evaluate(runner, results)
+        if self.save_best:
+            self._save_ckpt(runner, key_score)
 
 
 class DistEvalHook(_DistEvalHook):
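The per-iteration and per-epoch overrides can be dropped because recent mmcv `EvalHook` base classes route both `after_train_iter` and `after_train_epoch` through a single `_do_evaluate`, with `_should_evaluate` encapsulating the `by_epoch`/interval gating. A minimal sketch of that dispatch pattern, using a simplified stand-in base class (the counter checks are illustrative, not mmcv's exact implementation):

```python
# Illustrative stand-in for the mmcv base-class dispatch this refactor
# relies on: the base hook decides *when* to evaluate, subclasses like
# EvalHook above only implement *how* (_do_evaluate).
class BaseEvalHook:

    def __init__(self, interval=1, by_epoch=False):
        self.interval = interval
        self.by_epoch = by_epoch

    def _should_evaluate(self, runner):
        # Gate on the configured interval; runner.epoch/runner.iter are
        # 0-based counters, hence the +1.
        if self.by_epoch:
            return (runner.epoch + 1) % self.interval == 0
        return (runner.iter + 1) % self.interval == 0

    def after_train_iter(self, runner):
        # Only iteration-based schedules evaluate here ...
        if not self.by_epoch:
            self._do_evaluate(runner)

    def after_train_epoch(self, runner):
        # ... and only epoch-based schedules evaluate here.
        if self.by_epoch:
            self._do_evaluate(runner)

    def _do_evaluate(self, runner):
        raise NotImplementedError
```

With the schedule handled upstream, the subclass body shrinks to the evaluation call itself plus the `save_best` checkpointing.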
@@ -71,39 +57,38 @@ def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
         super().__init__(*args, by_epoch=by_epoch, **kwargs)
         self.efficient_test = efficient_test
 
-    def after_train_iter(self, runner):
-        """After train epoch hook.
-
-        Override default ``multi_gpu_test``.
-        """
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        # Synchronization of BatchNorm's buffer (running_mean
+        # and running_var) is not supported in the DDP of pytorch,
+        # which may cause the inconsistent performance of models in
+        # different ranks, so we broadcast BatchNorm's buffers
+        # of rank 0 to other ranks to avoid this.
+        if self.broadcast_bn_buffer:
+            model = runner.model
+            for name, module in model.named_modules():
+                if isinstance(module,
+                              _BatchNorm) and module.track_running_stats:
+                    dist.broadcast(module.running_var, 0)
+                    dist.broadcast(module.running_mean, 0)
+
+        if not self._should_evaluate(runner):
             return
-        from mmseg.apis import multi_gpu_test
-        runner.log_buffer.clear()
-        results = multi_gpu_test(
-            runner.model,
-            self.dataloader,
-            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
-            gpu_collect=self.gpu_collect,
-            efficient_test=self.efficient_test)
-        if runner.rank == 0:
-            print('\n')
-            self.evaluate(runner, results)
 
-    def after_train_epoch(self, runner):
-        """After train epoch hook.
+        tmpdir = self.tmpdir
+        if tmpdir is None:
+            tmpdir = osp.join(runner.work_dir, '.eval_hook')
 
-        Override default ``multi_gpu_test``.
-        """
-        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
-            return
         from mmseg.apis import multi_gpu_test
-        runner.log_buffer.clear()
         results = multi_gpu_test(
             runner.model,
             self.dataloader,
-            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+            tmpdir=tmpdir,
             gpu_collect=self.gpu_collect)
         if runner.rank == 0:
             print('\n')
-            self.evaluate(runner, results)
+            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+            key_score = self.evaluate(runner, results)
+
+            if self.save_best:
+                self._save_ckpt(runner, key_score)
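As a side note, the BatchNorm-buffer broadcast added in `_do_evaluate` can be factored into a standalone helper; a minimal sketch, assuming `torch.distributed` has already been initialized (`broadcast_bn_buffers` is a hypothetical name, not part of mmseg):

```python
import torch.distributed as dist
from torch.nn.modules.batchnorm import _BatchNorm


def broadcast_bn_buffers(model, src=0):
    """Copy rank-`src` BatchNorm running statistics to all other ranks.

    DDP keeps gradients in sync but never touches BN buffers
    (running_mean/running_var), so ranks can drift apart over training;
    broadcasting before evaluation makes every rank test the same model.
    """
    for _, module in model.named_modules():
        # Only BN layers that actually track running stats carry buffers.
        if isinstance(module, _BatchNorm) and module.track_running_stats:
            dist.broadcast(module.running_var, src)
            dist.broadcast(module.running_mean, src)
```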