@@ -917,7 +917,12 @@ def test_switch_mixture():
917
917
i_vv = I_rv .clone ()
918
918
i_vv .name = "i"
919
919
920
+ # When I_rv == True, X_rv flows through otherwise Y_rv does
920
921
Z1_rv = pt .switch (I_rv , X_rv , Y_rv )
922
+
923
+ assert Z1_rv .eval ({I_rv : 0 }) > 5
924
+ assert Z1_rv .eval ({I_rv : 1 }) < - 5
925
+
921
926
z_vv = Z1_rv .clone ()
922
927
z_vv .name = "z1"
923
928
@@ -935,7 +940,10 @@ def test_switch_mixture():
935
940
936
941
# building the identical graph but with a stack to check that mixture computations are identical
937
942
938
- Z2_rv = pt .stack ((X_rv , Y_rv ))[I_rv ]
943
+ Z2_rv = pt .stack ((Y_rv , X_rv ))[I_rv ]
944
+
945
+ assert Z2_rv .eval ({I_rv : 0 }) > 5
946
+ assert Z2_rv .eval ({I_rv : 1 }) < - 5
939
947
940
948
fgraph2 , _ , _ = construct_ir_fgraph ({Z2_rv : z_vv , I_rv : i_vv })
941
949
@@ -949,8 +957,8 @@ def test_switch_mixture():
949
957
# below should follow immediately from the equal_computations assertion above
950
958
assert equal_computations ([z1_logp_combined ], [z2_logp_combined ])
951
959
952
- np .testing .assert_almost_equal (0.69049938 , z1_logp_combined .eval ({z_vv : - 10 , i_vv : 0 }))
953
- np .testing .assert_almost_equal (0.69049938 , z2_logp_combined .eval ({z_vv : - 10 , i_vv : 0 }))
960
+ np .testing .assert_almost_equal (0.69049938 , z1_logp_combined .eval ({z_vv : - 10 , i_vv : 1 }))
961
+ np .testing .assert_almost_equal (0.69049938 , z2_logp_combined .eval ({z_vv : - 10 , i_vv : 1 }))
954
962
955
963
956
964
def test_ifelse_mixture_one_component ():
0 commit comments