Join with constant axis doesn't get canonicalized #1505

Closed
@ricardoV94

Description

Because a constant axis is not canonicalized, equivalent joins don't get merged:

import pytensor
import pytensor.tensor as pt
from pytensor.graph.rewriting import rewrite_graph

x, y = pt.matrices("xy")
out1 = pt.join(-1, x, y)
out2 = pt.join(1, x, y)
pytensor.dprint(rewrite_graph([out1, out2], include=("fast_run",), exclude=("inplace",)))
# Join [id A]
#  ├─ -1 [id B]
#  ├─ x [id C]
#  └─ y [id D]
# Join [id E]
#  ├─ 1 [id F]
#  ├─ x [id C]
#  └─ y [id D]

We should convert a negative constant axis to its non-negative equivalent during canonicalization, so that both joins above become the same node and can be merged.
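To illustrate the intended normalization, here is a minimal plain-Python sketch. It is not PyTensor's rewrite, and `normalize_axis` is a hypothetical helper name used only for illustration; an actual canonicalization rewrite would presumably apply the same arithmetic to the constant axis input of `Join`, using the number of dimensions of the joined tensors.

```python
def normalize_axis(axis: int, ndim: int) -> int:
    """Map an axis in [-ndim, ndim) to its non-negative equivalent.

    Hypothetical helper for illustration only, not part of PyTensor.
    """
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for {ndim} dimensions")
    # Python's modulo returns a non-negative result for a positive ndim,
    # so -1 maps to ndim - 1, -2 to ndim - 2, and so on.
    return axis % ndim


# For the matrices in the example (ndim == 2), both spellings agree.
assert normalize_axis(-1, 2) == normalize_axis(1, 2) == 1
```

With both axes normalized to `1`, the two `Join` nodes would have identical inputs, which is the condition under which equivalent nodes can be merged.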
