@@ -27,11 +27,41 @@ def test_bernoulli_overflow_gradient(init_tensor_type):
2727 assert_equal (p .grad .data [0 ], 0 )
2828
2929
30+ @pytest .mark .parametrize ('init_tensor_type' , [torch .FloatTensor ])
31+ def test_bernoulli_with_logits_underflow_gradient (init_tensor_type ):
32+ p = Variable (init_tensor_type ([- 1e40 ]), requires_grad = True )
33+ bernoulli = Bernoulli (logits = p )
34+ log_pdf = bernoulli .batch_log_pdf (Variable (init_tensor_type ([0 ])))
35+ log_pdf .sum ().backward ()
36+ assert_equal (log_pdf .data [0 ], 0 )
37+ assert_equal (p .grad .data [0 ], 0 )
38+
39+
40+ @pytest .mark .parametrize ('init_tensor_type' , [torch .DoubleTensor , torch .FloatTensor ])
41+ def test_bernoulli_with_logits_overflow_gradient (init_tensor_type ):
42+ p = Variable (init_tensor_type ([1e40 ]), requires_grad = True )
43+ bernoulli = Bernoulli (logits = p )
44+ log_pdf = bernoulli .batch_log_pdf (Variable (init_tensor_type ([1 ])))
45+ log_pdf .sum ().backward ()
46+ assert_equal (log_pdf .data [0 ], 0 )
47+ assert_equal (p .grad .data [0 ], 0 )
48+
49+
3050@pytest .mark .parametrize ('init_tensor_type' , [torch .DoubleTensor , torch .FloatTensor ])
3151def test_categorical_gradient (init_tensor_type ):
3252 p = Variable (init_tensor_type ([0 , 1 ]), requires_grad = True )
33- bernoulli = Categorical (p )
34- log_pdf = bernoulli .batch_log_pdf (Variable (init_tensor_type ([0 , 1 ])))
53+ categorical = Categorical (p )
54+ log_pdf = categorical .batch_log_pdf (Variable (init_tensor_type ([0 , 1 ])))
55+ log_pdf .sum ().backward ()
56+ assert_equal (log_pdf .data [0 ], 0 )
57+ assert_equal (p .grad .data [0 ], 0 )
58+
59+
60+ @pytest .mark .parametrize ('init_tensor_type' , [torch .DoubleTensor , torch .FloatTensor ])
61+ def test_categorical_gradient_with_logits (init_tensor_type ):
62+ p = Variable (init_tensor_type ([- float ('inf' ), 0 ]), requires_grad = True )
63+ categorical = Categorical (logits = p )
64+ log_pdf = categorical .batch_log_pdf (Variable (init_tensor_type ([0 , 1 ])))
3565 log_pdf .sum ().backward ()
3666 assert_equal (log_pdf .data [0 ], 0 )
3767 assert_equal (p .grad .data [0 ], 0 )
0 commit comments