@@ -306,25 +306,31 @@ def __init__(self,
         elif pretrained is not None:
             raise TypeError('pretrained must be a str or None')
 
+        self.in_channels = in_channels
         self.img_size = img_size
         self.patch_size = patch_size
         self.norm_eval = norm_eval
         self.pretrained = pretrained
-
-        self.patch_embed = PatchEmbed(
-            in_channels=in_channels,
-            embed_dims=embed_dims,
-            conv_type='Conv2d',
-            kernel_size=patch_size,
-            stride=patch_size,
-            padding=0,
-            norm_cfg=norm_cfg if patch_norm else None,
-            init_cfg=None)
-
-        window_size = (img_size[0] // patch_size, img_size[1] // patch_size)
-        self.patch_shape = window_size
+        self.num_layers = num_layers
+        self.embed_dims = embed_dims
+        self.num_heads = num_heads
+        self.mlp_ratio = mlp_ratio
+        self.attn_drop_rate = attn_drop_rate
+        self.drop_path_rate = drop_path_rate
+        self.num_fcs = num_fcs
+        self.qv_bias = qv_bias
+        self.act_cfg = act_cfg
+        self.norm_cfg = norm_cfg
+        self.patch_norm = patch_norm
+        self.init_values = init_values
+        self.window_size = (img_size[0] // patch_size,
+                            img_size[1] // patch_size)
+        self.patch_shape = self.window_size
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
 
+        self._build_patch_embedding()
+        self._build_layers()
+
         if isinstance(out_indices, int):
             if out_indices == -1:
                 out_indices = num_layers - 1
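The `window_size` stored above is the patch-grid resolution, which `patch_shape` aliases. A quick sketch of the arithmetic, assuming the usual BEiT-base inputs (`img_size=(224, 224)`, `patch_size=16`, neither taken from this diff):

```python
# Patch-grid arithmetic behind self.window_size / self.patch_shape above.
# The (224, 224) input size and 16x16 patches are assumed defaults.
img_size, patch_size = (224, 224), 16
window_size = (img_size[0] // patch_size, img_size[1] // patch_size)
print(window_size)  # (14, 14) -> 196 patch tokens, plus one cls token
```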
@@ -334,29 +340,47 @@ def __init__(self,
         else:
             raise TypeError('out_indices must be type of int, list or tuple')
 
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
-        self.layers = ModuleList()
-        for i in range(num_layers):
-            self.layers.append(
-                BEiTTransformerEncoderLayer(
-                    embed_dims=embed_dims,
-                    num_heads=num_heads,
-                    feedforward_channels=mlp_ratio * embed_dims,
-                    attn_drop_rate=attn_drop_rate,
-                    drop_path_rate=dpr[i],
-                    num_fcs=num_fcs,
-                    bias='qv_bias' if qv_bias else False,
-                    act_cfg=act_cfg,
-                    norm_cfg=norm_cfg,
-                    window_size=window_size,
-                    init_values=init_values))
-
         self.final_norm = final_norm
         if final_norm:
             self.norm1_name, norm1 = build_norm_layer(
                 norm_cfg, embed_dims, postfix=1)
             self.add_module(self.norm1_name, norm1)
 
+    def _build_patch_embedding(self):
+        """Build patch embedding layer."""
+        self.patch_embed = PatchEmbed(
+            in_channels=self.in_channels,
+            embed_dims=self.embed_dims,
+            conv_type='Conv2d',
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            padding=0,
+            norm_cfg=self.norm_cfg if self.patch_norm else None,
+            init_cfg=None)
+
+    def _build_layers(self):
+        """Build transformer encoding layers."""
+
+        dpr = [
+            x.item()
+            for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
+        ]
+        self.layers = ModuleList()
+        for i in range(self.num_layers):
+            self.layers.append(
+                BEiTTransformerEncoderLayer(
+                    embed_dims=self.embed_dims,
+                    num_heads=self.num_heads,
+                    feedforward_channels=self.mlp_ratio * self.embed_dims,
+                    attn_drop_rate=self.attn_drop_rate,
+                    drop_path_rate=dpr[i],
+                    num_fcs=self.num_fcs,
+                    bias='qv_bias' if self.qv_bias else False,
+                    act_cfg=self.act_cfg,
+                    norm_cfg=self.norm_cfg,
+                    window_size=self.window_size,
+                    init_values=self.init_values))
+
     @property
     def norm1(self):
         return getattr(self, self.norm1_name)
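The `dpr` list in `_build_layers` implements the stochastic depth decay rule: the drop-path probability grows linearly from 0 at the first encoder layer to `drop_path_rate` at the last, so deeper layers are skipped more often during training. A minimal sketch, assuming 12 layers and `drop_path_rate=0.1` (typical BEiT-base values, not taken from this diff):

```python
import torch

# Linearly spaced per-layer drop-path rates (stochastic depth decay rule).
drop_path_rate, num_layers = 0.1, 12
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
print([round(r, 3) for r in dpr])
# [0.0, 0.009, 0.018, 0.027, 0.036, 0.045, 0.055, 0.064, 0.073, 0.082,
#  0.091, 0.1]
```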
@@ -419,7 +443,6 @@ def resize_rel_pos_embed(self, checkpoint):
         https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/checkpoint.py. # noqa: E501
         Copyright (c) Microsoft Corporation
         Licensed under the MIT License
-
         Args:
             checkpoint (dict): Key and value of the pretrain model.
         Returns:
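`resize_rel_pos_embed` adapts the pretrained relative position bias tables when the fine-tuning patch grid differs from the pretraining one. The Microsoft BEiT code linked above does this with a geometric-spacing interpolation; purely as an illustration, here is a simplified sketch that uses plain bicubic interpolation instead. The function name and shapes are hypothetical; only the table layout (`(2H-1)*(2W-1)` body rows plus 3 cls-related rows) follows the BEiT convention, and this is not the scheme the diff itself uses.

```python
import torch
import torch.nn.functional as F

def resize_bias_table_sketch(table: torch.Tensor, src_size: int,
                             dst_size: int) -> torch.Tensor:
    """Resize a ((2*src_size-1)**2 + 3, num_heads) bias table (sketch)."""
    num_extra = 3  # cls-to-token, token-to-cls, cls-to-cls rows stay as-is
    extra = table[-num_extra:, :]
    body = table[:-num_extra, :]
    src_len, dst_len = 2 * src_size - 1, 2 * dst_size - 1
    num_heads = body.shape[1]
    # (L*L, H) -> (1, H, L, L) so F.interpolate can treat it as an image.
    body = body.t().reshape(1, num_heads, src_len, src_len)
    body = F.interpolate(
        body, size=(dst_len, dst_len), mode='bicubic', align_corners=False)
    body = body.reshape(num_heads, dst_len * dst_len).t()
    return torch.cat([body, extra], dim=0)

# e.g. going from a 14x14 to a 16x16 patch grid with 12 attention heads:
old = torch.randn((2 * 14 - 1) ** 2 + 3, 12)
new = resize_bias_table_sketch(old, src_size=14, dst_size=16)
assert new.shape == ((2 * 16 - 1) ** 2 + 3, 12)
```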