Skip to content

replaces numpy sqrt method with pytensor equivalent #6405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 20, 2022

Conversation

morganstrom
Copy link
Contributor

@morganstrom morganstrom commented Dec 20, 2022

What is this PR about?
When using JAX on a Google hosted VM with TPU support, the float64 dtype is not supported. Unfortunately, numpy uses this as default and aggressively casts arrays to float64. This caused an issue where my model wouldn't compile at all. After debugging, I located the issue to the zerosum-normal transformation, where np.sqrt is used. In this PR I'm replacing numpy with the pytensor equivalent, which fixes the issue.

Checklist

Major / Breaking Changes

  • ...

New features

  • ...

Bugfixes

  • Replace zerosumnormal's numpy operators with their pytensor counterparts to make them run on JAX

Documentation

  • ...

Maintenance

  • ...

@ricardoV94
Copy link
Member

Would be good to check for other constants in other transforms/logp/logcdf methods

@codecov
Copy link

codecov bot commented Dec 20, 2022

Codecov Report

Merging #6405 (a674948) into main (f231d13) will increase coverage by 8.68%.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6405      +/-   ##
==========================================
+ Coverage   86.05%   94.74%   +8.68%     
==========================================
  Files         148      148              
  Lines       27645    27645              
==========================================
+ Hits        23791    26193    +2402     
+ Misses       3854     1452    -2402     
Impacted Files Coverage Δ
pymc/distributions/transforms.py 99.36% <100.00%> (ø)
pymc/sampling/parallel.py 88.42% <0.00%> (+1.05%) ⬆️
pymc/logprob/cumsum.py 100.00% <0.00%> (+3.12%) ⬆️
pymc/tests/logprob/utils.py 50.00% <0.00%> (+3.65%) ⬆️
pymc/logprob/rewriting.py 97.05% <0.00%> (+5.88%) ⬆️
pymc/logprob/abstract.py 97.56% <0.00%> (+6.09%) ⬆️
pymc/logprob/utils.py 100.00% <0.00%> (+13.79%) ⬆️
pymc/logprob/joint_logprob.py 97.01% <0.00%> (+19.40%) ⬆️
pymc/logprob/tensor.py 82.40% <0.00%> (+24.00%) ⬆️
pymc/logprob/transforms.py 96.42% <0.00%> (+28.63%) ⬆️
... and 16 more

@morganstrom
Copy link
Contributor Author

Would be good to check for other constants in other transforms/logp/logcdf methods

Do you need me to do this to merge @ricardoV94? Seems like it could take some time to go through all of the places where a numpy function is used..

@ricardoV94
Copy link
Member

No, it doesn't need to be done now, but it would be great 😃

@twiecki
Copy link
Member

twiecki commented Dec 20, 2022

We can certainly handle that in other PRs.

@twiecki twiecki merged commit 982e4c4 into pymc-devs:main Dec 20, 2022
@twiecki
Copy link
Member

twiecki commented Dec 20, 2022

Thanks for the contribution @morganstrom!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants