
Commit 69b28e0

[Refactor] Add build_pos_embed and build_layers for BEiT (open-mmlab#1517)
* [Refactor] Add build_pos_embed and build_layers for BEiT
* Update mmseg/models/backbones/beit.py
1 parent f16bb06 commit 69b28e0

File tree

1 file changed: +54 -31 lines

mmseg/models/backbones/beit.py

Lines changed: 54 additions & 31 deletions
@@ -306,25 +306,31 @@ def __init__(self,
         elif pretrained is not None:
             raise TypeError('pretrained must be a str or None')
 
+        self.in_channels = in_channels
         self.img_size = img_size
         self.patch_size = patch_size
         self.norm_eval = norm_eval
         self.pretrained = pretrained
-
-        self.patch_embed = PatchEmbed(
-            in_channels=in_channels,
-            embed_dims=embed_dims,
-            conv_type='Conv2d',
-            kernel_size=patch_size,
-            stride=patch_size,
-            padding=0,
-            norm_cfg=norm_cfg if patch_norm else None,
-            init_cfg=None)
-
-        window_size = (img_size[0] // patch_size, img_size[1] // patch_size)
-        self.patch_shape = window_size
+        self.num_layers = num_layers
+        self.embed_dims = embed_dims
+        self.num_heads = num_heads
+        self.mlp_ratio = mlp_ratio
+        self.attn_drop_rate = attn_drop_rate
+        self.drop_path_rate = drop_path_rate
+        self.num_fcs = num_fcs
+        self.qv_bias = qv_bias
+        self.act_cfg = act_cfg
+        self.norm_cfg = norm_cfg
+        self.patch_norm = patch_norm
+        self.init_values = init_values
+        self.window_size = (img_size[0] // patch_size,
+                            img_size[1] // patch_size)
+        self.patch_shape = self.window_size
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
 
+        self._build_patch_embedding()
+        self._build_layers()
+
         if isinstance(out_indices, int):
             if out_indices == -1:
                 out_indices = num_layers - 1
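For instance, with img_size=(224, 224) and patch_size=16, this hunk stores self.window_size = self.patch_shape = (14, 14); its main effect is moving such values from local variables onto attributes so that the builder methods introduced below can reach them.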
@@ -334,29 +340,47 @@ def __init__(self,
         else:
             raise TypeError('out_indices must be type of int, list or tuple')
 
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
-        self.layers = ModuleList()
-        for i in range(num_layers):
-            self.layers.append(
-                BEiTTransformerEncoderLayer(
-                    embed_dims=embed_dims,
-                    num_heads=num_heads,
-                    feedforward_channels=mlp_ratio * embed_dims,
-                    attn_drop_rate=attn_drop_rate,
-                    drop_path_rate=dpr[i],
-                    num_fcs=num_fcs,
-                    bias='qv_bias' if qv_bias else False,
-                    act_cfg=act_cfg,
-                    norm_cfg=norm_cfg,
-                    window_size=window_size,
-                    init_values=init_values))
-
         self.final_norm = final_norm
         if final_norm:
             self.norm1_name, norm1 = build_norm_layer(
                 norm_cfg, embed_dims, postfix=1)
             self.add_module(self.norm1_name, norm1)
 
+    def _build_patch_embedding(self):
+        """Build patch embedding layer."""
+        self.patch_embed = PatchEmbed(
+            in_channels=self.in_channels,
+            embed_dims=self.embed_dims,
+            conv_type='Conv2d',
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            padding=0,
+            norm_cfg=self.norm_cfg if self.patch_norm else None,
+            init_cfg=None)
+
+    def _build_layers(self):
+        """Build transformer encoding layers."""
+
+        dpr = [
+            x.item()
+            for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
+        ]
+        self.layers = ModuleList()
+        for i in range(self.num_layers):
+            self.layers.append(
+                BEiTTransformerEncoderLayer(
+                    embed_dims=self.embed_dims,
+                    num_heads=self.num_heads,
+                    feedforward_channels=self.mlp_ratio * self.embed_dims,
+                    attn_drop_rate=self.attn_drop_rate,
+                    drop_path_rate=dpr[i],
+                    num_fcs=self.num_fcs,
+                    bias='qv_bias' if self.qv_bias else False,
+                    act_cfg=self.act_cfg,
+                    norm_cfg=self.norm_cfg,
+                    window_size=self.window_size,
+                    init_values=self.init_values))
+
     @property
     def norm1(self):
         return getattr(self, self.norm1_name)
@@ -419,7 +443,6 @@ def resize_rel_pos_embed(self, checkpoint):
         https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/checkpoint.py. # noqa: E501
         Copyright (c) Microsoft Corporation
         Licensed under the MIT License
-
         Args:
             checkpoint (dict): Key and value of the pretrain model.
         Returns:
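The point of the refactor: BEiT's __init__ now stores its constructor arguments as attributes and delegates construction to overridable hook methods, so a subclass can replace a single stage without duplicating the whole constructor. A minimal sketch of how a downstream backbone could use the new hooks (MyBEiT and the freezing tweak are hypothetical, not part of this commit):

    from mmseg.models.backbones import BEiT

    class MyBEiT(BEiT):
        """Hypothetical BEiT variant that customizes layer construction."""

        def _build_layers(self):
            # Build the default encoder stack; every constructor argument
            # is reachable as an attribute (self.num_layers,
            # self.embed_dims, ...) thanks to this commit.
            super()._build_layers()
            # Then adjust it, e.g. freeze the first half of the layers.
            for layer in self.layers[:self.num_layers // 2]:
                for param in layer.parameters():
                    param.requires_grad = False

Because _build_patch_embedding() and _build_layers() are invoked from __init__, such overrides take effect during normal construction; variants that share BEiT's interface can diverge in a single method.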
