File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -479,7 +479,7 @@ class SwinTransformer(BaseModule):
479479 embed_dims (int): The feature dimension. Default: 96.
480480 patch_size (int | tuple[int]): Patch size. Default: 4.
481481 window_size (int): Window size. Default: 7.
482- mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
482+ mlp_ratio (int | float ): Ratio of mlp hidden dim to embedding dim.
483483 Default: 4.
484484 depths (tuple[int]): Depths of each Swin Transformer stage.
485485 Default: (2, 2, 6, 2).
@@ -610,7 +610,7 @@ def __init__(self,
610610 stage = SwinBlockSequence (
611611 embed_dims = in_channels ,
612612 num_heads = num_heads [i ],
613- feedforward_channels = mlp_ratio * in_channels ,
613+ feedforward_channels = int ( mlp_ratio * in_channels ) ,
614614 depth = depths [i ],
615615 window_size = window_size ,
616616 qkv_bias = qkv_bias ,
You can’t perform that action at this time.
0 commit comments