Fix 1497 - Change tolerance used to decide whether a constant is one in rewrite functions #1526


Open: lciti wants to merge 4 commits into main

Conversation

@lciti (Contributor) commented Jul 7, 2025

Description

The tolerance previously used within rewrites to decide whether a constant is one (or minus one) is too large.
For example, `c - sigmoid(x)` is rewritten as `sigmoid(-x)` even when $c = 1 - p$ with $p = 10^{-4}$ (1 in 10000).
Many rewrites currently use np.isclose and np.allclose with the default tolerances (rtol=1e-05, atol=1e-08), which are unnecessarily large (and independent of the data type of the constant being compared).
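
To make the scale concrete, here is an illustrative check with NumPy's defaults (my snippet, not from the PR):

import numpy as np

# The default tolerance (atol + rtol*|ref|, about 1e-5 here) treats values
# roughly 1e-6 away from 1.0 as equal, for float32 and float64 alike, even
# though float64 can resolve differences near 2.2e-16.
print(np.isclose(1 - 1e-6, 1.0))              # True
print(np.isclose(np.float64(1 - 1e-6), 1.0))  # also True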

This PR implements a function isclose that is used within all rewrites in place of np.isclose and np.allclose. This new function uses a much smaller default tolerance of 10 units in the last place (ULPs). The tolerance is dtype-dependent, so it is stricter for a float64 than for a float32. See #1497 for a back-of-the-envelope justification for choosing 10 ULPs.
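
As a rough illustration of the idea (a sketch only, with an assumed signature, not the PR's actual implementation):

import numpy as np

def isclose(x, ref_value, num_ulps=10):
    # Derive the tolerance from the float spacing (ULP) of x's dtype, so a
    # float64 constant is checked far more strictly than a float32 one.
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        return bool(np.all(x == ref_value))
    eps = np.finfo(x.dtype).eps  # one ULP at 1.0
    atol = num_ulps * eps * max(abs(ref_value), 1.0)
    return np.allclose(x, ref_value, rtol=0.0, atol=atol)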

This PR also implements allow_cast in PatternNodeRewriter to allow rewrites that would otherwise fail when the new and old dtypes differ. For example, a rewrite attempt for `np.array(1., "float64") - sigmoid(x)` (where x is an fmatrix) currently fails because the rewritten expression `sigmoid(-x)` would change the type. This PR allows an automatic cast to be added, so the expression is rewritten as `cast(sigmoid(-x), "float64")`.
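
The motivating case looks roughly like this (an illustrative snippet, not a test from the PR):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.fmatrix("x")
out = np.array(1.0, "float64") - pt.sigmoid(x)  # output dtype is float64
# Without allow_cast, the 1 - sigmoid(x) -> sigmoid(-x) rewrite must bail out
# because sigmoid(-x) would be float32; with allow_cast the rewriter can
# produce cast(sigmoid(-x), "float64") instead.
f = pytensor.function([x], out)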

Relevant tests added.

Related Issue

#1497

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1526.org.readthedocs.build/en/1526/

Luca Citi added 4 commits on July 7, 2025 at 15:25, including:

- …to allow rewrites that would otherwise fail when the new and old dtype differ. Example: `np.array(1., "float64") - sigmoid(x)` cannot be rewritten as `sigmoid(-x)` (where x is an fmatrix) because the type would change. This commit allows an automatic cast to be added so the expression is rewritten as `cast(sigmoid(-x), "float64")`. Relevant tests added.
- …ion isclose, which uses 10 ULPs by default
@lciti changed the title from "Fix 1497" to "Fix 1497 - Change tolerance used to decide whether a constant is one in rewrite functions" on Jul 7, 2025
Comment on lines 1672 to 1673
if self.allow_cast and ret.owner.outputs[0].type.dtype != out_dtype:
ret = pytensor.tensor.basic.cast(ret, out_dtype)
@ricardoV94 (Member) commented:

Not all types have a dtype; we should check that it's a TensorType before even trying to access dtype and doing stuff with it.

The whole logic is weird though, with the `if ret.owner` branch: why do we care about the type of outputs we're not replacing? It's actually dangerous to try to replace only one of them without the user's consent. Since this is WIP, I would change it to `if len(node.outputs) != 1: return False` before we try to unify.

Then here we just have to worry about the final else branch below. I would perhaps write it like this:

[old_out] = node.outputs
if not old_out.type.is_super(ret.type):
    if not (
        self.allow_cast
        and isinstance(old_out.type, TensorType)
        and isinstance(ret.type, TensorType)
    ):
        return False

    # Try to cast
    ret = ret.astype(old_out.type.dtype)
    if not old_out.type.is_super(ret.type):
        return False

@lciti (Contributor, Author) replied:

I am happy to replace it as you suggest, but I am not sure how to fit it within the rest. This is the current code:

        if ret.owner:
            if not (
                len(node.outputs) == len(ret.owner.outputs)
                and all(
                    o.type.is_super(new_o.type)
                    for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True)
                )
            ):
                return False
        else:
            # ret is just an input variable
            assert len(node.outputs) == 1
            if not node.outputs[0].type.is_super(ret.type):
                return False

(np.array(1.0, "float32") - sigmoid(xd), sigmoid(-xd)),
(np.array([[1.0]], "float64") - sigmoid(xd), sigmoid(-xd)),
]:
f = pytensor.function([x, xd], out, m, on_unused_input="ignore")
@ricardoV94 (Member) commented on the test snippet above:

If you are not evaluating f, just rewrite it with rewrite_graph, possibly including ("canonicalize", "stabilize", "specialize") or whatever rewrite sets are needed.
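
For reference, the suggestion would look something like this (my sketch; rewrite_graph is real, the include names come from the comment above):

from pytensor.graph import rewrite_graph
import pytensor.tensor as pt

x = pt.fmatrix("x")
out = 1.0 - pt.sigmoid(x)  # stand-in for the test expressions above
# Rewrite the output variable directly instead of compiling a function.
rewritten_out = rewrite_graph(
    out, include=("canonicalize", "stabilize", "specialize")
)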

Comment on lines +4144 to +4149
), "Expression:\n{}rewritten as:\n{}expected:\n{}".format(
*(
pytensor.dprint(expr, print_type=True, file="str")
for expr in (out, f_outs, expected)
)
)
@ricardoV94 (Member) commented:

You can do pytensor.dprint(tuple[Variable]), i.e. pass a tuple of Variables directly. If you want to print the rewritten and expected graphs, which I often do while writing these sorts of tests, we could add an assert_equal_computations helper that does that. That way it's reusable and doesn't make each test as verbose as this.
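
A rough sketch of such a helper (the name comes from the comment above; the body is my assumption, built on the existing equal_computations and dprint utilities):

import pytensor
from pytensor.graph.basic import equal_computations

def assert_equal_computations(rewritten, expected):
    # Compare the two graphs structurally; on mismatch, print both with types.
    if not equal_computations(list(rewritten), list(expected)):
        raise AssertionError(
            "Rewritten:\n"
            + pytensor.dprint(rewritten, print_type=True, file="str")
            + "Expected:\n"
            + pytensor.dprint(expected, print_type=True, file="str")
        )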

out_dtype = node.outputs[0].type.dtype
if self.allow_cast and ret.owner.outputs[0].type.dtype != out_dtype:
    ret = pytensor.tensor.basic.cast(ret, out_dtype)
if self.allow_cast:
@ricardoV94 (Member) commented Jul 8, 2025:

I was reviewing commit by commit; I see you changed this afterwards. Anyway, my original comment still stands.

Generally, feel free to squash commits and force-push when iterating on a PR so the git history stays clean.
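
For reference, a typical way to do that (standard git commands, not part of the review):

git rebase -i main           # mark follow-up commits as "squash" or "fixup"
git push --force-with-lease  # safer than a plain --force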

Development

Successfully merging this pull request may close these issues.

BUG: the tolerance used to decide whether a constant is one (or minus one) in rewrite functions may be too large (#1497)