Skip to content

Commit af9e8ec

Browse files
Restore 89_ConvTranspose3d_MaxPool_Softmax_Subtract_Swish_Max.py
1 parent d6149e7 commit af9e8ec

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
class Model(nn.Module):
5+
"""
6+
A model that performs a sequence of operations:
7+
- ConvTranspose3d
8+
- MaxPool3d
9+
- Softmax
10+
- Subtract
11+
- Swish
12+
- Max
13+
"""
14+
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, pool_kernel_size, pool_stride, pool_padding):
15+
super(Model, self).__init__()
16+
self.conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding)
17+
self.max_pool = nn.MaxPool3d(kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding)
18+
self.subtract = nn.Parameter(torch.randn(out_channels)) # Assuming subtraction is element-wise across channels
19+
20+
def forward(self, x):
21+
x = self.conv_transpose(x)
22+
x = self.max_pool(x)
23+
x = torch.softmax(x, dim=1) # Apply softmax across channels (dim=1)
24+
x = x - self.subtract.view(1, -1, 1, 1, 1) # Subtract across channels
25+
x = torch.sigmoid(x) * x # Swish activation
26+
x = torch.max(x, dim=1)[0] # Max pooling across channels
27+
return x
28+
29+
batch_size = 128
30+
in_channels = 3
31+
out_channels = 16
32+
depth, height, width = 16, 32, 32
33+
kernel_size = 3
34+
stride = 2
35+
padding = 1
36+
output_padding = 1
37+
pool_kernel_size = 2
38+
pool_stride = 2
39+
pool_padding = 0
40+
41+
def get_inputs():
42+
return [torch.randn(batch_size, in_channels, depth, height, width)]
43+
44+
def get_init_inputs():
45+
return [in_channels, out_channels, kernel_size, stride, padding, output_padding, pool_kernel_size, pool_stride, pool_padding]

0 commit comments

Comments
 (0)