Skip to content

Commit 2695ba8

Browse files
Roy Hvaarahawkinsp
andauthored
[JAX] Replace uses of jax.devices("cpu") with jax.local_devices(backend="cpu") (huggingface#5864)
An upcoming change to JAX will include non-local (addressable) CPU devices in jax.devices() when JAX is used multicontroller-style, where there are multiple Python processes. This change preserves the current behavior by replacing uses of jax.devices("cpu"), which previously only returned local devices, with jax.local_devices("cpu"), which will return local devices both now and in the future. This change is always safe (i.e., it should always preserve the previous behavior), but it may sometimes be unnecessary if code is never used in a multicontroller setting. Co-authored-by: Peter Hawkins <[email protected]>
1 parent 3ab9211 commit 2695ba8

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/diffusers/models/modeling_flax_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ def from_pretrained(
437437
# make sure all arrays are stored as jnp.ndarray
438438
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
439439
# https://github.com/google/flax/issues/1261
440-
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
440+
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state)
441441

442442
# flatten dicts
443443
state = flatten_dict(state)

0 commit comments

Comments
 (0)