Closed
Description
`Join` with a negative constant axis is not canonicalized to the equivalent positive axis. This means equivalent joins don't get merged:
import pytensor
import pytensor.tensor as pt
from pytensor.graph.rewriting import rewrite_graph
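x, y = pt.matrices("xy")
# For matrices, axis=-1 and axis=1 both refer to the last dimension
out1 = pt.join(-1, x, y)
out2 = pt.join(1, x, y)
# A single Join would be expected after rewriting, but both survive:
pytensor.dprint(rewrite_graph([out1, out2], include=("fast_run",), exclude=("inplace",)))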
# Join [id A]
# ├─ -1 [id B]
# ├─ x [id C]
# └─ y [id D]
# Join [id E]
# ├─ 1 [id F]
# ├─ x [id C]
# └─ y [id D]
We should convert a negative constant axis to its positive equivalent during canonicalization, so that both joins end up identical and can be merged.
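For illustration, the normalization itself could look roughly like the sketch below. This is plain Python, not PyTensor's actual rewrite: `normalize_axis` is a hypothetical helper name, and it assumes the axis is a known constant and the number of dimensions of the inputs is known statically.

```python
def normalize_axis(axis: int, ndim: int) -> int:
    """Map a constant axis in [-ndim, ndim) to its non-negative equivalent.

    Hypothetical helper for illustration only; not part of PyTensor.
    """
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for {ndim} dimensions")
    # Negative axes count from the end, so shift them by ndim.
    return axis + ndim if axis < 0 else axis


# The two joins in the example act on matrices (ndim == 2),
# so both axes normalize to the same value.
assert normalize_axis(-1, 2) == normalize_axis(1, 2) == 1
```

With the axis normalized this way, the two `Join` nodes would have identical inputs, which is what the merge step needs in order to treat them as the same computation.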