@@ -146,8 +146,49 @@ def __init__(self,
         # The ret[0] of build_norm_layer is norm name.
         self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
 
+        # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
+        from mmseg import digit_version, mmcv_version
+        if mmcv_version < digit_version('1.3.17'):
+            warnings.warn('The legacy version of the forward function in '
+                          'EfficientMultiheadAttention is deprecated in '
+                          'mmcv>=1.3.17 and will no longer be supported in '
+                          'the future. Please upgrade your mmcv.')
+            self.forward = self.legacy_forward
+
     def forward(self, x, hw_shape, identity=None):
 
+        x_q = x
+        if self.sr_ratio > 1:
+            x_kv = nlc_to_nchw(x, hw_shape)
+            x_kv = self.sr(x_kv)
+            x_kv = nchw_to_nlc(x_kv)
+            x_kv = self.norm(x_kv)
+        else:
+            x_kv = x
+
+        if identity is None:
+            identity = x_q
+
+        # Because the dataflow ('key', 'query', 'value') of
+        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
+        # embed_dims), we should adjust the shape of dataflow from
+        # batch_first (batch, num_query, embed_dims) to num_query_first
+        # (num_query, batch, embed_dims), and recover ``attn_output``
+        # from num_query_first to batch_first.
+        if self.batch_first:
+            x_q = x_q.transpose(0, 1)
+            x_kv = x_kv.transpose(0, 1)
+
+        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
+
+        if self.batch_first:
+            out = out.transpose(0, 1)
+
+        return identity + self.dropout_layer(self.proj_drop(out))
+
+    def legacy_forward(self, x, hw_shape, identity=None):
+        """multi head attention forward in mmcv version < 1.3.17."""
+
         x_q = x
         if self.sr_ratio > 1:
             x_kv = nlc_to_nchw(x, hw_shape)
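
Side note (not part of the commit): the sr_ratio > 1 branch shortens the key/value sequence before attention. Below is a minimal, self-contained sketch of that path. The real nlc_to_nchw/nchw_to_nlc helpers live in mmseg.models.utils and are assumed here to be the plain reshape/transpose shown; `sr` is assumed to be a strided Conv2d with kernel_size = stride = sr_ratio, and the tensor sizes are illustrative.

import torch
import torch.nn as nn


def nlc_to_nchw(x, hw_shape):
    # (B, L, C) -> (B, C, H, W), where L == H * W
    H, W = hw_shape
    B, L, C = x.shape
    assert L == H * W
    return x.transpose(1, 2).reshape(B, C, H, W)


def nchw_to_nlc(x):
    # (B, C, H, W) -> (B, H*W, C)
    return x.flatten(2).transpose(1, 2).contiguous()


embed_dims, sr_ratio, hw_shape = 64, 4, (32, 32)
sr = nn.Conv2d(embed_dims, embed_dims, kernel_size=sr_ratio, stride=sr_ratio)

x = torch.randn(2, hw_shape[0] * hw_shape[1], embed_dims)   # (B, 1024, 64)
x_kv = nchw_to_nlc(sr(nlc_to_nchw(x, hw_shape)))            # (B, 64, 64)
print(x.shape, x_kv.shape)  # keys/values are sr_ratio**2 times shorter than the queries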
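And a second illustrative sketch of the batch_first handling described in the comment above: torch.nn.MultiheadAttention expects (num_query, batch, embed_dims) by default, so batch-first tensors are transposed on the way in and recovered on the way out. Dimensions are made up for the example.

import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=64, num_heads=4)  # sequence-first layout by default

x_q = torch.randn(2, 1024, 64)   # batch_first: (batch, num_query, embed_dims)
x_kv = torch.randn(2, 64, 64)    # shortened keys/values, e.g. from the sr branch

out = attn(query=x_q.transpose(0, 1),
           key=x_kv.transpose(0, 1),
           value=x_kv.transpose(0, 1))[0]
out = out.transpose(0, 1)        # recover batch_first: (2, 1024, 64)
print(out.shape)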