Skip to content

Commit 864ecb3

Browse files
committed
Fix bug in switch mixture logp
The True and False branches were being mixed up
1 parent 8b5f437 commit 864ecb3

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

pymc/logprob/mixture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def find_measurable_switch_mixture(fgraph, node):
344344
old_mixture_rv.broadcastable,
345345
)
346346
new_mixture_rv = mix_op.make_node(
347-
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + components)
347+
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + components[::-1])
348348
).default_output()
349349

350350
if pytensor.config.compute_test_value != "off":

tests/logprob/test_mixture.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -917,7 +917,12 @@ def test_switch_mixture():
917917
i_vv = I_rv.clone()
918918
i_vv.name = "i"
919919

920+
# When I_rv == True, X_rv flows through otherwise Y_rv does
920921
Z1_rv = pt.switch(I_rv, X_rv, Y_rv)
922+
923+
assert Z1_rv.eval({I_rv: 0}) > 5
924+
assert Z1_rv.eval({I_rv: 1}) < -5
925+
921926
z_vv = Z1_rv.clone()
922927
z_vv.name = "z1"
923928

@@ -935,7 +940,10 @@ def test_switch_mixture():
935940

936941
# building the identical graph but with a stack to check that mixture computations are identical
937942

938-
Z2_rv = pt.stack((X_rv, Y_rv))[I_rv]
943+
Z2_rv = pt.stack((Y_rv, X_rv))[I_rv]
944+
945+
assert Z2_rv.eval({I_rv: 0}) > 5
946+
assert Z2_rv.eval({I_rv: 1}) < -5
939947

940948
fgraph2, _, _ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv})
941949

@@ -949,8 +957,8 @@ def test_switch_mixture():
949957
# below should follow immediately from the equal_computations assertion above
950958
assert equal_computations([z1_logp_combined], [z2_logp_combined])
951959

952-
np.testing.assert_almost_equal(0.69049938, z1_logp_combined.eval({z_vv: -10, i_vv: 0}))
953-
np.testing.assert_almost_equal(0.69049938, z2_logp_combined.eval({z_vv: -10, i_vv: 0}))
960+
np.testing.assert_almost_equal(0.69049938, z1_logp_combined.eval({z_vv: -10, i_vv: 1}))
961+
np.testing.assert_almost_equal(0.69049938, z2_logp_combined.eval({z_vv: -10, i_vv: 1}))
954962

955963

956964
def test_ifelse_mixture_one_component():

0 commit comments

Comments
 (0)