Skip to content

Commit d18a901

Browse files
Fix 38_ConvTranspose3d_AvgPool_Clamp_Softmax_Multiply.py
1 parent e000f01 commit d18a901

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

KernelBench/level2/38_ConvTranspose3d_AvgPool_Clamp_Softmax_Multiply.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@
33

44
class Model(nn.Module):
55
"""
6-
Model that performs a 3D transposed convolution, average pooling, clamping, softmax, and multiplication.
6+
Model that performs average pooling, 3D transposed convolution, clamping,
7+
spatial softmax, and multiplication by a learnable scale.
78
"""
89
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, pool_kernel_size, clamp_min, clamp_max):
910
super(Model, self).__init__()
10-
self.conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding)
1111
self.avg_pool = nn.AvgPool3d(pool_kernel_size)
12+
self.conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding)
1213
self.clamp_min = clamp_min
1314
self.clamp_max = clamp_max
15+
self.scale = nn.Parameter(torch.ones(1, out_channels, 1, 1, 1))
1416

1517
def forward(self, x):
1618
"""
@@ -20,11 +22,14 @@ def forward(self, x):
2022
Returns:
2123
torch.Tensor: Output tensor of shape (batch_size, out_channels, depth, height, width).
2224
"""
23-
x = self.conv_transpose(x)
2425
x = self.avg_pool(x)
26+
x = self.conv_transpose(x)
2527
x = torch.clamp(x, self.clamp_min, self.clamp_max)
26-
x = torch.softmax(x, dim=1)
27-
x = x * 2
28+
b, c, d, h, w = x.shape
29+
x = x.view(b, c, -1) # flatten spatial dims
30+
x = torch.softmax(x, dim=2)
31+
x = x.view(b, c, d, h, w)
32+
x = x * self.scale
2833
return x
2934

3035
batch_size = 16
@@ -43,4 +48,4 @@ def get_inputs():
4348
return [torch.randn(batch_size, in_channels, depth, height, width)]
4449

4550
def get_init_inputs():
46-
return [in_channels, out_channels, kernel_size, stride, padding, output_padding, pool_kernel_size, clamp_min, clamp_max]
51+
return [in_channels, out_channels, kernel_size, stride, padding, output_padding, pool_kernel_size, clamp_min, clamp_max]

0 commit comments

Comments
 (0)