1+ import torch
2+ import torch .nn as nn
3+
4+ class Model (nn .Module ):
5+ """
6+ Model that performs a 3D transposed convolution, LogSumExp, HardSwish, subtraction, clamp, and maximum operations.
7+ """
8+ def __init__ (self , in_channels , out_channels , kernel_size , stride , padding , bias_shape ):
9+ super (Model , self ).__init__ ()
10+ self .conv_transpose = nn .ConvTranspose3d (in_channels , out_channels , kernel_size , stride = stride , padding = padding )
11+ self .bias = nn .Parameter (torch .randn (bias_shape ))
12+
13+ def forward (self , x ):
14+ x = self .conv_transpose (x )
15+ x = torch .logsumexp (x , dim = 1 , keepdim = True )
16+ x = x * torch .sigmoid (x + 3 ) / 6
17+ x = x - self .bias
18+ x = torch .clamp (x , min = - 1 , max = 1 )
19+ x = torch .max (x , dim = 1 , keepdim = True )[0 ]
20+ return x
21+
22+ batch_size = 128
23+ in_channels = 3
24+ out_channels = 16
25+ depth , height , width = 16 , 32 , 32
26+ kernel_size = 3
27+ stride = 2
28+ padding = 1
29+ bias_shape = (out_channels , 1 , 1 , 1 )
30+
31+ def get_inputs ():
32+ return [torch .randn (batch_size , in_channels , depth , height , width )]
33+
34+ def get_init_inputs ():
35+ return [in_channels , out_channels , kernel_size , stride , padding , bias_shape ]
0 commit comments