Skip to content

Commit 539c6bd

Browse files
Rename and fix 41_Gemm_BatchNorm_GELU_GroupNorm_Mean_ReLU.py → 41_Gemm_BatchNorm_GELU_ReLU.py
1 parent 3006786 commit 539c6bd

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

KernelBench/level2/41_Gemm_BatchNorm_GELU_GroupNorm_Mean_ReLU.py renamed to KernelBench/level2/41_Gemm_BatchNorm_GELU_ReLU.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33

44
class Model(nn.Module):
55
"""
6-
Model that performs a GEMM, BatchNorm, GELU, GroupNorm, Mean, and ReLU operations in sequence.
6+
Model that performs a GEMM, BatchNorm, GELU, and ReLU in sequence.
77
"""
88
def __init__(self, in_features, out_features, num_groups):
99
super(Model, self).__init__()
1010
self.gemm = nn.Linear(in_features, out_features)
1111
self.batch_norm = nn.BatchNorm1d(out_features)
12-
self.group_norm = nn.GroupNorm(num_groups, out_features)
1312

1413
def forward(self, x):
1514
"""
@@ -21,18 +20,16 @@ def forward(self, x):
2120
x = self.gemm(x)
2221
x = self.batch_norm(x)
2322
x = torch.nn.functional.gelu(x)
24-
x = self.group_norm(x)
25-
x = torch.mean(x, dim=1, keepdim=True)
2623
x = torch.relu(x)
2724
return x
2825

2926
batch_size = 128
3027
in_features = 512
3128
out_features = 1024
32-
num_groups = 8
29+
num_groups = 8 # Not used anymore
3330

3431
def get_inputs():
3532
return [torch.randn(batch_size, in_features)]
3633

3734
def get_init_inputs():
38-
return [in_features, out_features, num_groups]
35+
return [in_features, out_features, num_groups]

0 commit comments

Comments
 (0)