@@ -48,9 +48,6 @@ class SegDataPreProcessor(BaseDataPreprocessor):
4848 rgb_to_bgr (bool): whether to convert image from RGB to RGB.
4949 Defaults to False.
5050 batch_augments (list[dict], optional): Batch-level augmentations
51- train_cfg (dict, optional): The padding size config in training, if
52- not specify, will use `size` and `size_divisor` params as default.
53- Defaults to None, only supports keys `size` or `size_divisor`.
5451 test_cfg (dict, optional): The padding size config in testing, if not
5552 specify, will use `size` and `size_divisor` params as default.
5653 Defaults to None, only supports keys `size` or `size_divisor`.
@@ -67,7 +64,6 @@ def __init__(
6764 bgr_to_rgb : bool = False ,
6865 rgb_to_bgr : bool = False ,
6966 batch_augments : Optional [List [dict ]] = None ,
70- train_cfg : dict = None ,
7167 test_cfg : dict = None ,
7268 ):
7369 super ().__init__ ()
@@ -96,10 +92,8 @@ def __init__(
9692 # TODO: support batch augmentations.
9793 self .batch_augments = batch_augments
9894
99- # Support different padding methods in training and testing
100- default_size_cfg = dict (size = size , size_divisor = size_divisor )
101- self .train_cfg = train_cfg if train_cfg else default_size_cfg
102- self .test_cfg = test_cfg if test_cfg else default_size_cfg
95+ # Support different padding methods in testing
96+ self .test_cfg = test_cfg
10397
10498 def forward (self , data : dict , training : bool = False ) -> Dict [str , Any ]:
10599 """Perform normalization、padding and bgr2rgb conversion based on
@@ -126,24 +120,31 @@ def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
126120 if training :
127121 assert data_samples is not None , ('During training, ' ,
128122 '`data_samples` must be define.' )
123+ inputs , data_samples = stack_batch (
124+ inputs = inputs ,
125+ data_samples = data_samples ,
126+ size = self .size ,
127+ size_divisor = self .size_divisor ,
128+ pad_val = self .pad_val ,
129+ seg_pad_val = self .seg_pad_val )
130+
131+ if self .batch_augments is not None :
132+ inputs , data_samples = self .batch_augments (
133+ inputs , data_samples )
129134 else :
130135 assert len (inputs ) == 1 , (
131136 'Batch inference is not support currently, '
132137 'as the image size might be different in a batch' )
133-
134- size_cfg = self .train_cfg if training else self .test_cfg
135- size = size_cfg .get ('size' , None )
136- size_divisor = size_cfg .get ('size_divisor' , None )
137-
138- inputs , data_samples = stack_batch (
139- inputs = inputs ,
140- data_samples = data_samples ,
141- size = size ,
142- size_divisor = size_divisor ,
143- pad_val = self .pad_val ,
144- seg_pad_val = self .seg_pad_val )
145-
146- if self .batch_augments is not None :
147- inputs , data_samples = self .batch_augments (inputs , data_samples )
138+ # pad images when testing
139+ if self .test_cfg :
140+ inputs , data_samples = stack_batch (
141+ inputs = inputs ,
142+ data_samples = data_samples ,
143+ size = self .test_cfg .get ('size' , None ),
144+ size_divisor = self .test_cfg .get ('size_divisor' , None ),
145+ pad_val = self .pad_val ,
146+ seg_pad_val = self .seg_pad_val )
147+ else :
148+ inputs = torch .stack (inputs , dim = 0 )
148149
149150 return dict (inputs = inputs , data_samples = data_samples )
0 commit comments