replaces numpy sqrt method with pytensor equivalent #6405
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
What is this PR about?
When using JAX on a Google hosted VM with TPU support, the float64 dtype is not supported. Unfortunately, numpy uses this as default and aggressively casts arrays to float64. This caused an issue where my model wouldn't compile at all. After debugging, I located the issue to the zerosum-normal transformation, where
np.sqrt
is used. In this PR I'm replacing numpy with the pytensor equivalent, which fixes the issue.Checklist
Major / Breaking Changes
New features
Bugfixes
zerosumnormal
's numpy operators with their pytensor counterparts to make them run on JAXDocumentation
Maintenance