Skip to content

Commit 2e980ac

Browse files
authored
[Tests] Adjust TPU test values (huggingface#1233)
* [Tests] Adjust TPU test values * slow tests * remaining refs
1 parent 0feb21a commit 2e980ac

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tests/test_scheduler_flax.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -859,7 +859,7 @@ def test_full_loop_no_noise(self):
859859
result_mean = jnp.mean(jnp.abs(sample))
860860

861861
if jax_device == "tpu":
862-
assert abs(result_sum - 198.1542) < 1e-2
862+
assert abs(result_sum - 198.1275) < 1e-2
863863
assert abs(result_mean - 0.2580) < 1e-3
864864
else:
865865
assert abs(result_sum - 198.1318) < 1e-2
@@ -872,8 +872,8 @@ def test_full_loop_with_set_alpha_to_one(self):
872872
result_mean = jnp.mean(jnp.abs(sample))
873873

874874
if jax_device == "tpu":
875-
assert abs(result_sum - 185.4352) < 1e-2
876-
assert abs(result_mean - 0.24145) < 1e-3
875+
assert abs(result_sum - 186.83226) < 1e-2
876+
assert abs(result_mean - 0.24327) < 1e-3
877877
else:
878878
assert abs(result_sum - 186.9466) < 1e-2
879879
assert abs(result_mean - 0.24342) < 1e-3
@@ -885,8 +885,8 @@ def test_full_loop_with_no_set_alpha_to_one(self):
885885
result_mean = jnp.mean(jnp.abs(sample))
886886

887887
if jax_device == "tpu":
888-
assert abs(result_sum - 185.4352) < 1e-2
889-
assert abs(result_mean - 0.2414) < 1e-3
888+
assert abs(result_sum - 186.83226) < 1e-2
889+
assert abs(result_mean - 0.24327) < 1e-3
890890
else:
891891
assert abs(result_sum - 186.9482) < 1e-2
892892
assert abs(result_mean - 0.2434) < 1e-3

0 commit comments

Comments
 (0)