import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load_inline

# Custom CUDA kernel for fused ConvTranspose3d + clamp + divide
conv_clamp_div_source = """
#include <torch/extension.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAGuard.h>

__global__ void conv_transpose3d_clamp_div_kernel(
    const float* input,
    const float* weight,
    const float* bias,
    float* output,
    const int batch_size,
    const int in_channels,
    const int out_channels,
    const int input_depth,
    const int input_height,
    const int input_width,
    const int output_depth,
    const int output_height,
    const int output_width,
    const int kernel_size,
    const int stride,
    const int padding,
    const float min_value,
    const float divisor
) {
    // One thread per output element
    int out_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total_output_elements = batch_size * out_channels * output_depth * output_height * output_width;

    if (out_idx >= total_output_elements) return;

    // Decompose the flat output index into (n, c_out, d_out, h_out, w_out)
    int temp = out_idx;
    const int w_out = temp % output_width;
    temp /= output_width;
    const int h_out = temp % output_height;
    temp /= output_height;
    const int d_out = temp % output_depth;
    temp /= output_depth;
    const int c_out = temp % out_channels;
    const int n = temp / out_channels;

    float sum = 0.0f;

    // Gather formulation of the transposed convolution: input position d_in
    // contributes to output position d_out through kernel tap kd iff
    //   d_in * stride - padding + kd == d_out,
    // i.e. d_in = (d_out + padding - kd) / stride, valid only when the
    // division is exact and d_in lies inside the input (likewise for h, w).
    for (int kd = 0; kd < kernel_size; kd++) {
        const int d_tmp = d_out + padding - kd;
        if (d_tmp < 0 || d_tmp % stride != 0) continue;
        const int d_in = d_tmp / stride;
        if (d_in >= input_depth) continue;

        for (int kh = 0; kh < kernel_size; kh++) {
            const int h_tmp = h_out + padding - kh;
            if (h_tmp < 0 || h_tmp % stride != 0) continue;
            const int h_in = h_tmp / stride;
            if (h_in >= input_height) continue;

            for (int kw = 0; kw < kernel_size; kw++) {
                const int w_tmp = w_out + padding - kw;
                if (w_tmp < 0 || w_tmp % stride != 0) continue;
                const int w_in = w_tmp / stride;
                if (w_in >= input_width) continue;

                // Accumulate over input channels. ConvTranspose3d stores its
                // weight as (in_channels, out_channels, kD, kH, kW).
                for (int c_in = 0; c_in < in_channels; c_in++) {
                    const int weight_idx = c_in * (out_channels * kernel_size * kernel_size * kernel_size) +
                                           c_out * (kernel_size * kernel_size * kernel_size) +
                                           kd * (kernel_size * kernel_size) +
                                           kh * kernel_size +
                                           kw;

                    const int input_idx = n * (in_channels * input_depth * input_height * input_width) +
                                          c_in * (input_depth * input_height * input_width) +
                                          d_in * (input_height * input_width) +
                                          h_in * input_width +
                                          w_in;

                    sum += input[input_idx] * weight[weight_idx];
                }
            }
        }
    }

    // Add bias, clamp from below, then divide
    sum += bias[c_out];
    sum = fmaxf(sum, min_value);
    output[out_idx] = sum / divisor;
}

torch::Tensor conv_transpose3d_clamp_div_cuda(
    torch::Tensor input,
    torch::Tensor weight,
    torch::Tensor bias,
    int kernel_size,
    int stride,
    int padding,
    float min_value,
    float divisor
) {
    TORCH_CHECK(input.is_cuda() && weight.is_cuda() && bias.is_cuda(), "all tensors must be CUDA tensors");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "kernel only supports float32");
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));

    // Input dimensions
    const int batch_size = input.size(0);
    const int in_channels = input.size(1);
    const int input_depth = input.size(2);
    const int input_height = input.size(3);
    const int input_width = input.size(4);

    // ConvTranspose3d weight layout is (in_channels, out_channels, kD, kH, kW),
    // so the output channel count is size(1), not size(0)
    const int out_channels = weight.size(1);

    // Output size of a transposed convolution (output_padding = 0, dilation = 1)
    const int output_depth = (input_depth - 1) * stride - 2 * padding + kernel_size;
    const int output_height = (input_height - 1) * stride - 2 * padding + kernel_size;
    const int output_width = (input_width - 1) * stride - 2 * padding + kernel_size;

    // Every output element is written by the kernel, so empty() suffices
    auto output = torch::empty({batch_size, out_channels, output_depth, output_height, output_width},
                               torch::TensorOptions().dtype(input.dtype()).device(input.device()));

    const int total_output_elements = batch_size * out_channels * output_depth * output_height * output_width;

    // One thread per output element
    const int block_size = 256;
    const int num_blocks = (total_output_elements + block_size - 1) / block_size;

    conv_transpose3d_clamp_div_kernel<<<num_blocks, block_size>>>(
        input.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        output.data_ptr<float>(),
        batch_size,
        in_channels,
        out_channels,
        input_depth,
        input_height,
        input_width,
        output_depth,
        output_height,
        output_width,
        kernel_size,
        stride,
        padding,
        min_value,
        divisor
    );

    return output;
}
"""

conv_clamp_div_cpp_source = """
torch::Tensor conv_transpose3d_clamp_div_cuda(
    torch::Tensor input,
    torch::Tensor weight,
    torch::Tensor bias,
    int kernel_size,
    int stride,
    int padding,
    float min_value,
    float divisor
);
"""

# Compile the inline CUDA code
conv_clamp_div = load_inline(
    name="conv_clamp_div",
    cpp_sources=conv_clamp_div_cpp_source,
    cuda_sources=conv_clamp_div_source,
    functions=["conv_transpose3d_clamp_div_cuda"],
    verbose=False,
)

class ModelNew(nn.Module):
    """
    Optimized model that fuses ConvTranspose3d, clamp(min=min_value), and
    division by a constant into a single custom CUDA kernel.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, min_value, divisor):
        super().__init__()
        # nn.ConvTranspose3d is kept only as a parameter container; the actual
        # computation runs in the fused kernel.
        self.conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        self.min_value = min_value
        self.divisor = divisor
        self.fused_op = conv_clamp_div

    def forward(self, x):
        # The kernel assumes a cubic kernel and isotropic stride/padding, so
        # only index [0] of each size tuple is passed. The input is made
        # contiguous because the kernel indexes raw pointers directly.
        return self.fused_op.conv_transpose3d_clamp_div_cuda(
            x.contiguous(),
            self.conv_transpose.weight,
            self.conv_transpose.bias,
            self.conv_transpose.kernel_size[0],
            self.conv_transpose.stride[0],
            self.conv_transpose.padding[0],
            self.min_value,
            self.divisor
        )
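

# ---------------------------------------------------------------------------
# Quick correctness check: a minimal sketch, not part of the original module.
# The shapes and hyperparameters below are illustrative assumptions. It
# compares the fused kernel against the equivalent eager-mode sequence
# ConvTranspose3d -> clamp(min) -> div on random input.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # Hypothetical configuration for the smoke test
    in_channels, out_channels = 3, 8
    kernel_size, stride, padding = 3, 2, 1
    min_value, divisor = -1.0, 2.0

    model = ModelNew(in_channels, out_channels, kernel_size, stride, padding,
                     min_value, divisor).cuda()
    x = torch.randn(2, in_channels, 8, 8, 8, device="cuda")

    with torch.no_grad():
        fused = model(x)
        # Reference path reuses the same weights held by the parameter container
        reference = model.conv_transpose(x).clamp(min=min_value) / divisor

    print("max abs difference:", (fused - reference).abs().max().item())
    assert torch.allclose(fused, reference, atol=1e-4), "fused kernel diverges from eager reference"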