Skip to content

Commit 4356a08

Browse files
Restore 45_Gemm_Sigmoid_Sum_LogSumExp.py
1 parent fff5f09 commit 4356a08

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
class Model(nn.Module):
5+
"""
6+
Model that performs a matrix multiplication (Gemm), applies Sigmoid, sums the result, and calculates the LogSumExp.
7+
"""
8+
def __init__(self, input_size, hidden_size, output_size):
9+
super(Model, self).__init__()
10+
self.linear1 = nn.Linear(input_size, hidden_size)
11+
self.linear2 = nn.Linear(hidden_size, output_size)
12+
13+
def forward(self, x):
14+
x = self.linear1(x)
15+
x = torch.sigmoid(x)
16+
x = torch.sum(x, dim=1)
17+
x = torch.logsumexp(x, dim=0)
18+
return x
19+
20+
batch_size = 128
21+
input_size = 10
22+
hidden_size = 20
23+
output_size = 5
24+
25+
def get_inputs():
26+
return [torch.randn(batch_size, input_size)]
27+
28+
def get_init_inputs():
29+
return [input_size, hidden_size, output_size]

0 commit comments

Comments
 (0)