77from tests .common import assert_equal
88
99
10- @pytest .mark .xfail (reason = "TODO: clamp logits to ensure finite values" )
1110@pytest .mark .parametrize ('init_tensor_type' , [torch .DoubleTensor , torch .FloatTensor ])
1211def test_bernoulli_underflow_gradient (init_tensor_type ):
1312 p = Variable (init_tensor_type ([0 ]), requires_grad = True )
@@ -18,7 +17,6 @@ def test_bernoulli_underflow_gradient(init_tensor_type):
1817 assert_equal (p .grad .data [0 ], 0 )
1918
2019
21- @pytest .mark .xfail (reason = "TODO: clamp logits to ensure finite values" )
2220@pytest .mark .parametrize ('init_tensor_type' , [torch .DoubleTensor , torch .FloatTensor ])
2321def test_bernoulli_overflow_gradient (init_tensor_type ):
2422 p = Variable (init_tensor_type ([1e32 ]), requires_grad = True )
@@ -29,6 +27,7 @@ def test_bernoulli_overflow_gradient(init_tensor_type):
2927 assert_equal (p .grad .data [0 ], 0 )
3028
3129
30+ @pytest .mark .xfail (reason = "TODO: clamp logits to ensure finite values" )
3231@pytest .mark .parametrize ('init_tensor_type' , [torch .FloatTensor ])
3332def test_bernoulli_with_logits_underflow_gradient (init_tensor_type ):
3433 p = Variable (init_tensor_type ([- 1e40 ]), requires_grad = True )
@@ -39,6 +38,7 @@ def test_bernoulli_with_logits_underflow_gradient(init_tensor_type):
3938 assert_equal (p .grad .data [0 ], 0 )
4039
4140
41+ @pytest .mark .xfail (reason = "TODO: clamp logits to ensure finite values" )
4242@pytest .mark .parametrize ('init_tensor_type' , [torch .DoubleTensor , torch .FloatTensor ])
4343def test_bernoulli_with_logits_overflow_gradient (init_tensor_type ):
4444 p = Variable (init_tensor_type ([1e40 ]), requires_grad = True )
0 commit comments