Skip to content

Commit 92b98f5

Browse files
committed
Deny broadcasting of the measurable input
1 parent dfb05b6 commit 92b98f5

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

pymc/logprob/binary.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def find_measurable_comparisons(
5555
[measurable_var] = measurable_inputs
5656
measurable_var_idx = node.inputs.index(measurable_var)
5757

58+
# deny broadcasting of the measurable input
59+
if measurable_var.type.broadcastable != node.outputs[0].type.broadcastable:
60+
return None
61+
5862
# Check that the other input is not potentially measurable, in which case this rewrite
5963
# would be invalid
6064
const = node.inputs[(measurable_var_idx + 1) % 2]

tests/logprob/test_binary.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,15 @@ def test_potentially_measurable_operand():
147147
match="Logprob method not implemented",
148148
):
149149
logp(y_rv, y_vv).eval({y_vv: y_vv_test})
150+
151+
152+
def test_comparison_invalid_broadcast():
153+
x_rv = pt.random.normal(0.5, 1, size=(3,))
154+
155+
const = np.array([[0.1], [0.2], [-0.1]])
156+
y_rv_invalid = pt.gt(x_rv, const)
157+
158+
y_vv_invalid = y_rv_invalid.clone()
159+
160+
with pytest.raises(NotImplementedError, match="Logprob method not implemented for"):
161+
logp(y_rv_invalid, y_vv_invalid)

0 commit comments

Comments
 (0)