Cast ZeroSumNormal shape operations to config.floatX #6889
Conversation
Codecov Report
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #6889      +/-   ##
==========================================
- Coverage   92.05%   91.69%   -0.37%
==========================================
  Files          96      100       +4
  Lines       16446    16851     +405
==========================================
+ Hits        15140    15451     +311
- Misses       1306     1400      +94
Thanks, just a small suggestion to make the test faster
with pytensor.config.change_flags(floatX="float32", warn_float64="raise"):
    with pm.Model():
        pm.ZeroSumNormal("b", sigma=1, shape=(2,))
        pm.sample(1, chains=1, tune=1)
Should be enough to call model.logp()
I do not think it is enough, because logp does not call ZeroSumTransform.backward. pm.sample is a way to test both ZeroSumTransform.backward and logp. I added a comment about this interaction in 7906292 (#6889).
Note that the new test is mostly a non-regression test for #6886.
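For reference, a self-contained version of that sampling-based check might look like the sketch below. It relies on pytensor's warn_float64="raise" flag, which turns any creation of a float64 tensor into an error, so an accidental upcast fails the test immediately.

import pymc as pm
import pytensor

# Sketch of the non-regression check: warn_float64="raise" makes pytensor
# error out whenever a float64 tensor is created, so an upcast anywhere in
# the zero-sum machinery fails loudly.
with pytensor.config.change_flags(floatX="float32", warn_float64="raise"):
    with pm.Model():
        pm.ZeroSumNormal("b", sigma=1, shape=(2,))
        # Sampling produces draws on the unconstrained space and maps them
        # back through ZeroSumTransform.backward before evaluating logp.
        pm.sample(1, chains=1, tune=1)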
Ah I see how m.logp works. Thanks for the suggestion.
I suggested just calling model.logp().
Okay, I updated the test to use pm.logp.
with pytensor.config.change_flags(floatX="float32", warn_float64="raise"):
    with pm.Model():
        zsn = pm.ZeroSumNormal("b", sigma=1, shape=(2,))
        pm.logp(zsn, value=np.zeros((2,)))
Sorry I wasn't clear. I meant you should call model.logp().
import pymc as pm
import pytensor

with pytensor.config.change_flags(floatX="float32", warn_float64="raise"):
    with pm.Model() as m:
        pm.ZeroSumNormal("b", sigma=1, shape=(2,))
        m.logp()
That should call both backward and forward
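As a rough extra check (a sketch, not part of this PR), one could also assert that the resulting logp graph stays in float32:

import pymc as pm
import pytensor

# Sketch: under floatX="float32", building the model logp goes through the
# zero-sum transform, so the dtype of the logp graph exposes any accidental
# upcast to float64.
with pytensor.config.change_flags(floatX="float32"):
    with pm.Model() as m:
        pm.ZeroSumNormal("b", sigma=1, shape=(2,))
    assert m.logp().dtype == "float32"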
Thanks!!!
Thanks @thomasjpfan!
What is this PR about?
This PR prevents ZeroSumNormal from upcasting to float64 when pytensor is configured with floatX = float32.
Checklist
Bugfixes
📚 Documentation preview 📚: https://pymc--6889.org.readthedocs.build/en/6889/
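To illustrate the kind of upcast this change guards against, here is a generic pytensor sketch (not the actual ZeroSumNormal implementation): static shape entries are int64, so feeding them directly into floating-point ops such as sqrt promotes the result to float64 unless the shape is first cast to config.floatX.

import pytensor
import pytensor.tensor as pt

with pytensor.config.change_flags(floatX="float32"):
    x = pt.vector("x")  # float32 under floatX="float32"
    n = x.shape[-1]     # shape entries are int64
    print(pt.sqrt(n).dtype)                                 # float64: the int64 input upcasts the result
    print(pt.sqrt(n.astype(pytensor.config.floatX)).dtype)  # float32: cast the shape to floatX first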