@@ -146,8 +146,49 @@ def __init__(self,
         # The ret[0] of build_norm_layer is norm name.
         self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

+        # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
+        from mmseg import digit_version, mmcv_version
+        if mmcv_version < digit_version('1.3.17'):
+            warnings.warn('The legacy version of forward function in '
+                          'EfficientMultiheadAttention is deprecated in '
+                          'mmcv>=1.3.17 and will no longer be supported in '
+                          'the future. Please upgrade your mmcv.')
+            self.forward = self.legacy_forward
+
     def forward(self, x, hw_shape, identity=None):

+        x_q = x
+        if self.sr_ratio > 1:
+            x_kv = nlc_to_nchw(x, hw_shape)
+            x_kv = self.sr(x_kv)
+            x_kv = nchw_to_nlc(x_kv)
+            x_kv = self.norm(x_kv)
+        else:
+            x_kv = x
+
+        if identity is None:
+            identity = x_q
+
+        # Because the dataflow('key', 'query', 'value') of
+        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
+        # embed_dims), we should adjust the shape of dataflow from
+        # batch_first (batch, num_query, embed_dims) to num_query_first
+        # (num_query, batch, embed_dims), and recover ``attn_output``
+        # from num_query_first to batch_first.
+        if self.batch_first:
+            x_q = x_q.transpose(0, 1)
+            x_kv = x_kv.transpose(0, 1)
+
+        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
+
+        if self.batch_first:
+            out = out.transpose(0, 1)
+
+        return identity + self.dropout_layer(self.proj_drop(out))
+
+    def legacy_forward(self, x, hw_shape, identity=None):
+        """multi head attention forward in mmcv version < 1.3.17."""
+
         x_q = x
         if self.sr_ratio > 1:
             x_kv = nlc_to_nchw(x, hw_shape)
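
Note: the new `forward` above implements the spatial-reduction ("efficient") attention used by SegFormer. When `sr_ratio > 1`, the key/value sequence is reshaped back to a feature map, downsampled by a strided conv (`self.sr`), normalized, and only then fed to multi-head attention, so the attention cost scales with the shortened key/value length. The sketch below is a minimal, self-contained illustration of that idea only, not the mmseg class: the name `SpatialReductionAttention` and the helper functions are hypothetical, it calls `torch.nn.MultiheadAttention` directly, hard-codes LayerNorm, and omits `identity`, dropout, `proj_drop`, and config/init handling.

```python
import torch
import torch.nn as nn


def nlc_to_nchw(x, hw_shape):
    """(B, H*W, C) -> (B, C, H, W); stand-in for the mmseg helper."""
    H, W = hw_shape
    B, L, C = x.shape
    assert L == H * W, 'sequence length must match H * W'
    return x.transpose(1, 2).reshape(B, C, H, W)


def nchw_to_nlc(x):
    """(B, C, H, W) -> (B, H*W, C); stand-in for the mmseg helper."""
    return x.flatten(2).transpose(1, 2)


class SpatialReductionAttention(nn.Module):
    """Toy, hypothetical stand-in for EfficientMultiheadAttention."""

    def __init__(self, embed_dims, num_heads, sr_ratio=1):
        super().__init__()
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # Strided conv shrinks H and W by sr_ratio before attention.
            self.sr = nn.Conv2d(embed_dims, embed_dims,
                                kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(embed_dims)
        self.attn = nn.MultiheadAttention(embed_dims, num_heads)

    def forward(self, x, hw_shape):
        x_q = x
        if self.sr_ratio > 1:
            x_kv = nlc_to_nchw(x, hw_shape)  # back to a feature map
            x_kv = self.sr(x_kv)             # downsample keys/values
            x_kv = nchw_to_nlc(x_kv)         # flatten to a shorter sequence
            x_kv = self.norm(x_kv)
        else:
            x_kv = x

        # nn.MultiheadAttention without batch_first expects
        # (num_query, batch, embed_dims), hence the transposes in and out,
        # mirroring the batch_first branch in the diff above.
        x_q = x_q.transpose(0, 1)
        x_kv = x_kv.transpose(0, 1)
        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
        out = out.transpose(0, 1)
        return x + out  # residual connection


if __name__ == '__main__':
    # 32x32 feature map with 64 channels: 1024 queries attend to only
    # (32 / 8) * (32 / 8) = 16 keys/values thanks to sr_ratio=8.
    x = torch.randn(2, 32 * 32, 64)
    attn = SpatialReductionAttention(embed_dims=64, num_heads=4, sr_ratio=8)
    print(attn(x, (32, 32)).shape)  # torch.Size([2, 1024, 64])
```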