Commit d665f6b

fix EfficientMultiheadAttention in SegFormer (open-mmlab#1037)
1 parent 7a1c9a5 commit d665f6b

File tree

1 file changed: +41 -0 lines changed

mmseg/models/backbones/mit.py

Lines changed: 41 additions & 0 deletions
@@ -146,8 +146,49 @@ def __init__(self,
         # The ret[0] of build_norm_layer is norm name.
         self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
 
+        # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
+        from mmseg import digit_version, mmcv_version
+        if mmcv_version < digit_version('1.3.17'):
+            warnings.warn('The legacy version of the forward function in '
+                          'EfficientMultiheadAttention is deprecated in '
+                          'mmcv>=1.3.17 and will no longer be supported in '
+                          'the future. Please upgrade your mmcv.')
+            self.forward = self.legacy_forward
+
     def forward(self, x, hw_shape, identity=None):
 
+        x_q = x
+        if self.sr_ratio > 1:
+            x_kv = nlc_to_nchw(x, hw_shape)
+            x_kv = self.sr(x_kv)
+            x_kv = nchw_to_nlc(x_kv)
+            x_kv = self.norm(x_kv)
+        else:
+            x_kv = x
+
+        if identity is None:
+            identity = x_q
+
+        # Because the dataflow ('key', 'query', 'value') of
+        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
+        # embed_dims), we should adjust the shape of the dataflow from
+        # batch_first (batch, num_query, embed_dims) to num_query_first
+        # (num_query, batch, embed_dims), and recover ``attn_output``
+        # from num_query_first back to batch_first.
+        if self.batch_first:
+            x_q = x_q.transpose(0, 1)
+            x_kv = x_kv.transpose(0, 1)
+
+        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
+
+        if self.batch_first:
+            out = out.transpose(0, 1)
+
+        return identity + self.dropout_layer(self.proj_drop(out))
+
+    def legacy_forward(self, x, hw_shape, identity=None):
+        """Multi head attention forward in mmcv version < 1.3.17."""
+
         x_q = x
         if self.sr_ratio > 1:
             x_kv = nlc_to_nchw(x, hw_shape)
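For readers new to SegFormer's attention, the sr_ratio > 1 branch above shrinks only the key/value sequence before attention. The sketch below is a standalone illustration, not code from this repository: nlc_to_nchw and nchw_to_nlc are simple stand-ins for the mmseg helpers of the same name, and the Conv2d is assumed to mirror the self.sr layer built in __init__ (kernel_size = stride = sr_ratio).

# Standalone sketch of the spatial-reduction path for key/value (illustrative only).
import torch
import torch.nn as nn


def nlc_to_nchw(x, hw_shape):
    # (batch, H*W, C) -> (batch, C, H, W); stand-in for mmseg's helper
    h, w = hw_shape
    return x.transpose(1, 2).reshape(x.size(0), -1, h, w)


def nchw_to_nlc(x):
    # (batch, C, H, W) -> (batch, H*W, C); stand-in for mmseg's helper
    return x.flatten(2).transpose(1, 2)


batch, embed_dims, sr_ratio = 2, 32, 4
hw_shape = (32, 32)
# Assumed to mirror ``self.sr``: a strided conv that downsamples H and W by sr_ratio.
sr = nn.Conv2d(embed_dims, embed_dims, kernel_size=sr_ratio, stride=sr_ratio)

x = torch.randn(batch, hw_shape[0] * hw_shape[1], embed_dims)   # (2, 1024, 32)
x_kv = nchw_to_nlc(sr(nlc_to_nchw(x, hw_shape)))
print(x_kv.shape)  # torch.Size([2, 64, 32]): 1024 key/value tokens reduced to 64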

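The comment block in the new forward about torch.nn.MultiheadAttention expecting (num_query, batch, embed_dims) is the core of the fix. Here is a minimal, self-contained sketch of the same transpose round-trip; the tensor sizes are made up for illustration and are not taken from SegFormer configs.

# Minimal sketch of the batch_first <-> num_query_first round-trip (illustrative only).
import torch
import torch.nn as nn

batch, num_query, num_key, embed_dims = 2, 16, 4, 64
# torch.nn.MultiheadAttention defaults to the (num_query, batch, embed_dims) layout.
attn = nn.MultiheadAttention(embed_dims, num_heads=8)

x_q = torch.randn(batch, num_query, embed_dims)   # batch_first query
x_kv = torch.randn(batch, num_key, embed_dims)    # batch_first (reduced) key/value

# batch_first -> num_query_first before calling the attention module
q = x_q.transpose(0, 1)
kv = x_kv.transpose(0, 1)
out = attn(query=q, key=kv, value=kv)[0]

# ... and back to batch_first afterwards, as the new forward does
out = out.transpose(0, 1)
assert out.shape == (batch, num_query, embed_dims)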
0 commit comments
