Fix 1497 - Change tolerance used to decide whether a constant is one in rewrite functions #1526
base: main
Conversation
to allow rewrites that would otherwise fail when the new and old dtype differ. Example: `np.array(1., "float64") - sigmoid(x)` cannot be rewritten as `sigmoid(-x)` (where x is an fmatrix) because the type would change. This commit allows an automatic cast to be added so the expression is rewritten as `cast(sigmoid(-x), "float64")`. Relevant tests added.
…tain dtype like MyType in the tests
…ion isclose, which uses 10 ULPs by default
pytensor/graph/rewriting/basic.py (Outdated)
if self.allow_cast and ret.owner.outputs[0].type.dtype != out_dtype:
    ret = pytensor.tensor.basic.cast(ret, out_dtype)
Not all types have a dtype; we should check that it's a TensorType before even trying to access dtype and do stuff with it. I would perhaps write it like this:

The whole logic is weird though, with the if ret.owner: why do we care about the type of outputs we're not replacing? It's actually dangerous to try to replace only one of them without the user's consent. Since this is WIP, I would change it to if len(node.outputs) != 1: return False before we try to unify.
Then here we just have to worry about the final else branch below:
[old_out] = node.outputs
if not old_out.type.is_super(ret.type):
    if not (
        self.allow_cast
        and isinstance(old_out.type, TensorType)
        and isinstance(ret.type, TensorType)
    ):
        return False
    # Try to cast
    ret = ret.astype(old_out.type.dtype)
    if not old_out.type.is_super(ret.type):
        return False
I am happy to replace it as you suggest, but I am not sure how to fit it in with the rest. This is the current code:
if ret.owner:
    if not (
        len(node.outputs) == len(ret.owner.outputs)
        and all(
            o.type.is_super(new_o.type)
            for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True)
        )
    ):
        return False
else:
    # ret is just an input variable
    assert len(node.outputs) == 1
    if not node.outputs[0].type.is_super(ret.type):
        return False
    (np.array(1.0, "float32") - sigmoid(xd), sigmoid(-xd)),
    (np.array([[1.0]], "float64") - sigmoid(xd), sigmoid(-xd)),
]:
    f = pytensor.function([x, xd], out, m, on_unused_input="ignore")
If you are not evaluating f, just rewrite it with rewrite_graph, possibly including ("canonicalize", "stabilize", "specialize"), or whatever is needed.
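Something like the following minimal sketch, assuming rewrite_graph is importable from pytensor.graph and reusing the out/expected variables from the test loop above (not the PR's actual test code):

    from pytensor.graph import rewrite_graph
    from pytensor.graph.basic import equal_computations

    # Apply the rewrite passes directly instead of compiling a function
    rewritten = rewrite_graph(out, include=("canonicalize", "stabilize", "specialize"))
    assert equal_computations([rewritten], [expected])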
), "Expression:\n{}rewritten as:\n{}expected:\n{}".format( | ||
*( | ||
pytensor.dprint(expr, print_type=True, file="str") | ||
for expr in (out, f_outs, expected) | ||
) | ||
) |
You can do pytensor.dprint(tuple[Variable]). If you want to print the rewritten and expected graphs, which I often do while writing these sorts of tests, we could add an assert_equal_computations helper that does that. That way it's reusable and doesn't make each test as verbose as this.
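A rough sketch of what such a helper might look like (the assert_equal_computations name comes from the suggestion above and is not existing API; it builds on equal_computations from pytensor.graph.basic):

    import pytensor
    from pytensor.graph.basic import Variable, equal_computations


    def assert_equal_computations(rewritten, expected):
        """Assert two graphs compute the same thing; dprint both on failure."""
        if isinstance(rewritten, Variable):
            rewritten, expected = [rewritten], [expected]
        assert equal_computations(rewritten, expected), (
            "Rewritten:\n"
            + pytensor.dprint(rewritten, print_type=True, file="str")
            + "Expected:\n"
            + pytensor.dprint(expected, print_type=True, file="str")
        )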
out_dtype = node.outputs[0].type.dtype
if self.allow_cast and ret.owner.outputs[0].type.dtype != out_dtype:
    ret = pytensor.tensor.basic.cast(ret, out_dtype)
if self.allow_cast:
I was reviewing commit by commit; I see you changed this later. Anyway, my original comment still stands.
Generally, feel free to squash commits and force-push when iterating on a PR so the git history stays clean.
Description
The previous tolerance used within a rewrite to decide whether a constant is one (or minus one) is too large. For example, `c - sigmoid(x)` is rewritten as `sigmoid(-x)` even when $c = 1 - p$ with p equal to 1 in 10000. Many rewrites currently use `np.isclose` and `np.allclose` with the default tolerances (rtol=1e-05, atol=1e-08), which are unnecessarily large (and independent of the data type of the constant being compared).

This PR implements a function `isclose`, used within all rewrites in place of `np.isclose` and `np.allclose`. The new function uses a much smaller tolerance by default, namely 10 units in the last place (ULPs). This tolerance is dtype dependent, so it is stricter for a float64 than for a float32. See #1497 for a back-of-the-envelope justification for choosing 10 ULPs.

This PR also implements `allow_cast` in PatternNodeRewriter to allow rewrites that would otherwise fail when the new and old dtype differ. For example, a rewrite attempt for `np.array(1., "float64") - sigmoid(x)` (where x is an `fmatrix`) currently fails because in the rewritten expression `sigmoid(-x)` the type would change. This PR allows an automatic cast to be added, so the expression is rewritten as `cast(sigmoid(-x), "float64")`.

Relevant tests added.
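For illustration, a simplified sketch of a ULP-based check (not the PR's actual implementation), built on np.spacing, which gives the distance to the next representable float in a given dtype:

    import numpy as np


    def isclose(x, ref, num_ulps=10):
        # Sketch: tolerance is num_ulps times the spacing (1 ULP) of the reference
        # value, measured in the dtype of the data being compared
        x = np.asarray(x)
        ref = np.asarray(ref, dtype=x.dtype)
        tol = num_ulps * np.abs(np.spacing(ref))
        return np.all(np.abs(x - ref) <= tol)

With a tolerance of 10 ULPs, a constant of 1.0 allows a deviation of roughly 1.2e-6 in float32 but only about 2.2e-15 in float64, so 1 - 1e-4 is no longer treated as one in either dtype.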
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1526.org.readthedocs.build/en/1526/