Skip to content
This repository was archived by the owner on Oct 31, 2025. It is now read-only.

Conversation

@copybara-service
Copy link

@copybara-service copybara-service bot commented Sep 25, 2025

[trax] Explicitly set jax_pmap_shmap_merge=False.

trainer._multi_device_update_fn uses jax.pmap and when jax_pmap_shmap_merge=True, jax.pmap requires inputs be explicitly sharded as the underlying jax.jit expects.

This would need to be fixed if jax_pmap_shmap_merge=True.

@copybara-service copybara-service bot changed the title Disable e2e tests when jax_pmap_shmap_merge=True. [trax] Explicitly set jax_pmap_shmap_merge=False. Sep 26, 2025
`trainer._multi_device_update_fn` uses `jax.pmap` and when `jax_pmap_shmap_merge=True`, `jax.pmap` requires inputs be explicitly sharded as the underlying `jax.jit` expects.

This would need to be fixed if `jax_pmap_shmap_merge=True`.

PiperOrigin-RevId: 811810947
@copybara-service copybara-service bot merged commit 31022d6 into master Sep 26, 2025
@copybara-service copybara-service bot deleted the test_811399636 branch September 26, 2025 14:37
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants