Skip to content

Commit 2d7893e

Browse files
neerajprad
authored and fritzo committed
Add tests to check for nans with logits (pyro-ppl#454)
1 parent 7e9886d commit 2d7893e

File tree

3 files changed

+44
-11
lines changed

3 files changed

+44
-11
lines changed

pyro/distributions/categorical.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ class Categorical(Distribution):
1919
:param ps: Probabilities. These should be non-negative and normalized
2020
along the rightmost axis.
2121
:type ps: `torch.autograd.Variable`.
22-
:param logits: Non-normalized log probability values. Either `ps` or `logits`
23-
should be specified but not both.
22+
:param logits: Log probability values. When exponentiated, these should
23+
sum to 1 along the last axis. Either `ps` or `logits` should be
24+
specified but not both.
2425
:type logits: `torch.autograd.Variable`.
2526
:param vs: Optional list of values in the support.
2627
:type vs: `list` or `numpy.array` or `torch.autograd.Variable`

pyro/distributions/util.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,18 +150,20 @@ def _get_clamping_buffer(tensor):
150150

151151
def get_probs_and_logits(ps=None, logits=None, is_multidimensional=True):
152152
"""
153-
Convert probability values to logits, or vice-versa. Either `ps` or
154-
`logits` should be specified, but not both.
153+
Convert probability values to logits, or vice-versa. Either ``ps`` or
154+
``logits`` should be specified, but not both.
155155
156156
:param ps: tensor of probabilities. Should be in the interval *[0, 1]*.
157-
If, `is_multidimensional = True`, then must be normalized along
157+
If ``is_multidimensional = True``, then must be normalized along
158158
axis -1.
159-
:param logits: tensor of logit values.
159+
:param logits: tensor of logit values. For the multidimensional case,
160+
the values, when exponentiated along the last dimension, must sum
161+
to 1.
160162
:param is_multidimensional: determines the computation of ps from logits,
161163
and vice-versa. For the multi-dimensional case, logit values are
162-
assumed to be non-normalized log probabilities, whereas for the uni-
163-
dimensional case, it specifically refers to log odds.
164-
:return: tuple containing raw probabilities and logits as tensors
164+
assumed to be log probabilities, whereas for the uni-dimensional case,
165+
it specifically refers to log odds.
166+
:return: tuple containing raw probabilities and logits as tensors.
165167
"""
166168
assert (ps is None) != (logits is None)
167169
if ps is not None:

tests/distributions/test_gradient_flow.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,41 @@ def test_bernoulli_overflow_gradient(init_tensor_type):
2727
assert_equal(p.grad.data[0], 0)
2828

2929

30+
@pytest.mark.parametrize('init_tensor_type', [torch.FloatTensor])
31+
def test_bernoulli_with_logits_underflow_gradient(init_tensor_type):
32+
p = Variable(init_tensor_type([-1e40]), requires_grad=True)
33+
bernoulli = Bernoulli(logits=p)
34+
log_pdf = bernoulli.batch_log_pdf(Variable(init_tensor_type([0])))
35+
log_pdf.sum().backward()
36+
assert_equal(log_pdf.data[0], 0)
37+
assert_equal(p.grad.data[0], 0)
38+
39+
40+
@pytest.mark.parametrize('init_tensor_type', [torch.DoubleTensor, torch.FloatTensor])
41+
def test_bernoulli_with_logits_overflow_gradient(init_tensor_type):
42+
p = Variable(init_tensor_type([1e40]), requires_grad=True)
43+
bernoulli = Bernoulli(logits=p)
44+
log_pdf = bernoulli.batch_log_pdf(Variable(init_tensor_type([1])))
45+
log_pdf.sum().backward()
46+
assert_equal(log_pdf.data[0], 0)
47+
assert_equal(p.grad.data[0], 0)
48+
49+
3050
@pytest.mark.parametrize('init_tensor_type', [torch.DoubleTensor, torch.FloatTensor])
3151
def test_categorical_gradient(init_tensor_type):
3252
p = Variable(init_tensor_type([0, 1]), requires_grad=True)
33-
bernoulli = Categorical(p)
34-
log_pdf = bernoulli.batch_log_pdf(Variable(init_tensor_type([0, 1])))
53+
categorical = Categorical(p)
54+
log_pdf = categorical.batch_log_pdf(Variable(init_tensor_type([0, 1])))
55+
log_pdf.sum().backward()
56+
assert_equal(log_pdf.data[0], 0)
57+
assert_equal(p.grad.data[0], 0)
58+
59+
60+
@pytest.mark.parametrize('init_tensor_type', [torch.DoubleTensor, torch.FloatTensor])
61+
def test_categorical_gradient_with_logits(init_tensor_type):
62+
p = Variable(init_tensor_type([-float('inf'), 0]), requires_grad=True)
63+
categorical = Categorical(logits=p)
64+
log_pdf = categorical.batch_log_pdf(Variable(init_tensor_type([0, 1])))
3565
log_pdf.sum().backward()
3666
assert_equal(log_pdf.data[0], 0)
3767
assert_equal(p.grad.data[0], 0)

0 commit comments

Comments
 (0)