import torch
import torch.nn as nn


class Model(nn.Module):
    """
    A model that performs a sequence of operations:
    - ConvTranspose3d
    - MaxPool3d
    - Softmax
    - Subtract
    - Swish
    - Max
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, pool_kernel_size, pool_stride, pool_padding):
        super(Model, self).__init__()
        self.conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding)
        self.max_pool = nn.MaxPool3d(kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding)
        self.subtract = nn.Parameter(torch.randn(out_channels))  # Learnable per-channel subtraction term

    def forward(self, x):
        x = self.conv_transpose(x)
        x = self.max_pool(x)
        x = torch.softmax(x, dim=1)  # Apply softmax across channels (dim=1)
        x = x - self.subtract.view(1, -1, 1, 1, 1)  # Subtract per-channel term
        x = torch.sigmoid(x) * x  # Swish activation
        x = torch.max(x, dim=1)[0]  # Max reduction across channels
        return x

batch_size = 128
in_channels = 3
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
pool_kernel_size = 2
pool_stride = 2
pool_padding = 0


def get_inputs():
    return [torch.randn(batch_size, in_channels, depth, height, width)]


def get_init_inputs():
    return [in_channels, out_channels, kernel_size, stride, padding, output_padding, pool_kernel_size, pool_stride, pool_padding]
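
# Minimal usage sketch (an addition, not part of the original module): assumes the
# helpers above are meant to construct and exercise the Model defined here. With these
# settings, ConvTranspose3d doubles each spatial dim (16x32x32 -> 32x64x64), MaxPool3d
# halves them back (-> 16x32x32), and the final channel-wise max removes the channel
# dim, so the output is expected to have shape (batch_size, 16, 32, 32).
if __name__ == "__main__":
    model = Model(*get_init_inputs())
    (x,) = get_inputs()
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # expected: torch.Size([128, 16, 32, 32])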