File tree Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -859,7 +859,7 @@ def test_full_loop_no_noise(self):
859859 result_mean = jnp .mean (jnp .abs (sample ))
860860
861861 if jax_device == "tpu" :
862- assert abs (result_sum - 198.1542 ) < 1e-2
862+ assert abs (result_sum - 198.1275 ) < 1e-2
863863 assert abs (result_mean - 0.2580 ) < 1e-3
864864 else :
865865 assert abs (result_sum - 198.1318 ) < 1e-2
@@ -872,8 +872,8 @@ def test_full_loop_with_set_alpha_to_one(self):
872872 result_mean = jnp .mean (jnp .abs (sample ))
873873
874874 if jax_device == "tpu" :
875- assert abs (result_sum - 185.4352 ) < 1e-2
876- assert abs (result_mean - 0.24145 ) < 1e-3
875+ assert abs (result_sum - 186.83226 ) < 1e-2
876+ assert abs (result_mean - 0.24327 ) < 1e-3
877877 else :
878878 assert abs (result_sum - 186.9466 ) < 1e-2
879879 assert abs (result_mean - 0.24342 ) < 1e-3
@@ -885,8 +885,8 @@ def test_full_loop_with_no_set_alpha_to_one(self):
885885 result_mean = jnp .mean (jnp .abs (sample ))
886886
887887 if jax_device == "tpu" :
888- assert abs (result_sum - 185.4352 ) < 1e-2
889- assert abs (result_mean - 0.2414 ) < 1e-3
888+ assert abs (result_sum - 186.83226 ) < 1e-2
889+ assert abs (result_mean - 0.24327 ) < 1e-3
890890 else :
891891 assert abs (result_sum - 186.9482 ) < 1e-2
892892 assert abs (result_mean - 0.2434 ) < 1e-3
You can’t perform that action at this time.
0 commit comments