Skip to content

Commit d782b55

Browse files
Restore 58_ConvTranspose3d_LogSumExp_HardSwish_Subtract_Clamp_Max.py
1 parent f1d877c commit d782b55

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
class Model(nn.Module):
5+
"""
6+
Model that performs a 3D transposed convolution, LogSumExp, HardSwish, subtraction, clamp, and maximum operations.
7+
"""
8+
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias_shape):
9+
super(Model, self).__init__()
10+
self.conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
11+
self.bias = nn.Parameter(torch.randn(bias_shape))
12+
13+
def forward(self, x):
14+
x = self.conv_transpose(x)
15+
x = torch.logsumexp(x, dim=1, keepdim=True)
16+
x = x * torch.sigmoid(x + 3) / 6
17+
x = x - self.bias
18+
x = torch.clamp(x, min=-1, max=1)
19+
x = torch.max(x, dim=1, keepdim=True)[0]
20+
return x
21+
22+
batch_size = 128
23+
in_channels = 3
24+
out_channels = 16
25+
depth, height, width = 16, 32, 32
26+
kernel_size = 3
27+
stride = 2
28+
padding = 1
29+
bias_shape = (out_channels, 1, 1, 1)
30+
31+
def get_inputs():
32+
return [torch.randn(batch_size, in_channels, depth, height, width)]
33+
34+
def get_init_inputs():
35+
return [in_channels, out_channels, kernel_size, stride, padding, bias_shape]

0 commit comments

Comments
 (0)