Skip to content

Commit 8bc2556

Browse files
gokceneraslanfritzo
authored andcommitted
Make pyro.distributions.util.log_gamma compatible with Tensors (pyro-ppl#509)
Current implementation of log_gamma is not compatible with Tensors.
1 parent ece8ccc commit 8bc2556

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

pyro/distributions/util.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66

77

88
def log_gamma(xx):
9-
if isinstance(xx, Variable):
10-
ttype = xx.data.type()
11-
elif isinstance(xx, torch.Tensor):
12-
ttype = xx.type()
9+
if isinstance(xx, torch.Tensor):
10+
xx = Variable(xx)
11+
ttype = xx.data.type()
1312
gamma_coeff = [
1413
76.18009172947146,
1514
-86.50532032941677,

0 commit comments

Comments
 (0)