Skip to content

Commit 0048f22

Browse files
desimoneapaszke
authored andcommitted
Add spatial test for LogSoftmax
1 parent 2748b92 commit 0048f22

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

test/common_nn.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,12 @@
9797
input_size=(10, 20),
9898
reference_fn=lambda i,_: torch.exp(i).div_(torch.exp(i).sum(1).expand(10, 20)).log_()
9999
),
100+
dict(
101+
module_name='LogSoftmax',
102+
input_size=(1, 3, 10, 20),
103+
reference_fn=lambda i,_: torch.exp(i).div_(torch.exp(i).sum(1).expand_as(i)).log_(),
104+
desc='multiparam'
105+
),
100106
dict(
101107
module_name='ELU',
102108
constructor_args=(2.,),

0 commit comments

Comments
 (0)