@@ -3,14 +3,16 @@
 
 class Model(nn.Module):
     """
-    Model that performs a 3D transposed convolution, average pooling, clamping, softmax, and multiplication.
+    Model that performs average pooling, 3D transposed convolution, clamping,
+    spatial softmax, and multiplication by a learnable scale.
     """
     def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, pool_kernel_size, clamp_min, clamp_max):
         super(Model, self).__init__()
-        self.conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding)
         self.avg_pool = nn.AvgPool3d(pool_kernel_size)
+        self.conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding)
         self.clamp_min = clamp_min
         self.clamp_max = clamp_max
+        self.scale = nn.Parameter(torch.ones(1, out_channels, 1, 1, 1))
 
     def forward(self, x):
         """
@@ -20,11 +22,14 @@ def forward(self, x):
         Returns:
             torch.Tensor: Output tensor of shape (batch_size, out_channels, depth, height, width).
         """
-        x = self.conv_transpose(x)
         x = self.avg_pool(x)
+        x = self.conv_transpose(x)
         x = torch.clamp(x, self.clamp_min, self.clamp_max)
-        x = torch.softmax(x, dim=1)
-        x = x * 2
+        b, c, d, h, w = x.shape
+        x = x.view(b, c, -1)  # flatten spatial dims
+        x = torch.softmax(x, dim=2)
+        x = x.view(b, c, d, h, w)
+        x = x * self.scale
         return x
 
 batch_size = 16
@@ -43,4 +48,4 @@ def get_inputs():
     return [torch.randn(batch_size, in_channels, depth, height, width)]
 
 def get_init_inputs():
-    return [in_channels, out_channels, kernel_size, stride, padding, output_padding, pool_kernel_size, clamp_min, clamp_max]
+    return [in_channels, out_channels, kernel_size, stride, padding, output_padding, pool_kernel_size, clamp_min, clamp_max]
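
For a quick end-to-end sanity check of the reordered pipeline (pool, then transposed conv, then clamp, spatial softmax, and scale), the sketch below instantiates the model and verifies that every (batch, channel) slice of the output sums to 1 at initialization, since self.scale starts as ones. The hyperparameter values here are illustrative stand-ins; the file's actual in_channels, out_channels, etc. are collapsed out of this diff. Note that with stride=2, padding=1, kernel_size=3, and output_padding=1, ConvTranspose3d maps each spatial size n to (n - 1) * 2 - 2 + 3 + 1 = 2n, exactly doubling the pooled volume.

import torch

# Illustrative values only: the real in_channels, out_channels, depth, etc.
# live in the part of the file collapsed out of this diff.
model = Model(in_channels=8, out_channels=16, kernel_size=3, stride=2,
              padding=1, output_padding=1, pool_kernel_size=2,
              clamp_min=0.0, clamp_max=1.0)

x = torch.randn(16, 8, 16, 32, 32)  # (batch, channels, depth, height, width)
out = model(x)                      # pool halves, transposed conv doubles: (16, 16, 16, 32, 32)

# Each channel's spatial volume was softmax-normalized, and self.scale is
# initialized to ones, so every (batch, channel) slice sums to ~1 here.
print(out.sum(dim=(2, 3, 4)))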
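The semantic change in the softmax is easy to miss when reading the diff: the old torch.softmax(x, dim=1) normalized across channels at each voxel, while the new view/softmax/view sequence normalizes each channel's entire spatial volume. A small standalone contrast, independent of the Model class:

import torch

x = torch.randn(2, 3, 2, 2, 2)  # (b, c, d, h, w)

# Old behavior: the channel values at every spatial location sum to 1.
old = torch.softmax(x, dim=1)
print(old.sum(dim=1))  # all ones, shape (2, 2, 2, 2)

# New behavior: each channel's d*h*w voxels sum to 1.
b, c, d, h, w = x.shape
new = torch.softmax(x.view(b, c, -1), dim=2).view(b, c, d, h, w)
print(new.sum(dim=(2, 3, 4)))  # all ones, shape (2, 3)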