11# Copyright (c) OpenMMLab. All rights reserved.
22from numbers import Number
3- from typing import List , Optional , Sequence , Tuple
3+ from typing import Any , Dict , List , Optional , Sequence
44
55import torch
66from mmengine .model import BaseDataPreprocessor
7- from torch import Tensor
87
98from mmseg .registry import MODELS
10- from mmseg .utils import OptSampleList , stack_batch
9+ from mmseg .utils import stack_batch
1110
1211
1312@MODELS .register_module ()
@@ -87,22 +86,20 @@ def __init__(self,
8786 # TODO: support batch augmentations.
8887 self .batch_augments = batch_augments
8988
90- def forward (self ,
91- data : Sequence [dict ],
92- training : bool = False ) -> Tuple [Tensor , OptSampleList ]:
89+ def forward (self , data : dict , training : bool = False ) -> Dict [str , Any ]:
9390 """Perform normalization、padding and bgr2rgb conversion based on
9491 ``BaseDataPreprocessor``.
9592
9693 Args:
97- data (Sequence[ dict] ): data sampled from dataloader.
94+ data (dict): data sampled from dataloader.
9895 training (bool): Whether to enable training time augmentation.
9996
10097 Returns:
101- Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
102- model input.
98+ Dict: Data in the same format as the model input.
10399 """
104- inputs , batch_data_samples = self .collate_data (data )
105-
100+ data = self .cast_data (data ) # type: ignore
101+ inputs = data ['inputs' ]
102+ data_samples = data .get ('data_samples' , None )
106103 # TODO: whether normalize should be after stack_batch
107104 if self .channel_conversion and inputs [0 ].size (0 ) == 3 :
108105 inputs = [_input [[2 , 1 , 0 ], ...] for _input in inputs ]
@@ -113,20 +110,23 @@ def forward(self,
113110 inputs = [_input .float () for _input in inputs ]
114111
115112 if training :
116- batch_inputs , batch_data_samples = stack_batch (
113+ assert data_samples is not None , ('During training, ' ,
114+ '`data_samples` must be define.' )
115+ inputs , data_samples = stack_batch (
117116 inputs = inputs ,
118- batch_data_samples = batch_data_samples ,
117+ data_samples = data_samples ,
119118 size = self .size ,
120119 size_divisor = self .size_divisor ,
121120 pad_val = self .pad_val ,
122121 seg_pad_val = self .seg_pad_val )
123122
124123 if self .batch_augments is not None :
125- inputs , batch_data_samples = self .batch_augments (
126- inputs , batch_data_samples )
127- return batch_inputs , batch_data_samples
124+ inputs , data_samples = self .batch_augments (
125+ inputs , data_samples )
126+ return dict ( inputs = inputs , data_samples = data_samples )
128127 else :
129128 assert len (inputs ) == 1 , (
130129 'Batch inference is not support currently, '
131130 'as the image size might be different in a batch' )
132- return torch .stack (inputs , dim = 0 ), batch_data_samples
131+ return dict (
132+ inputs = torch .stack (inputs , dim = 0 ), data_samples = data_samples )
0 commit comments