2828 import jax .numpy as jnp
2929 from flax .jax_utils import replicate
3030 from flax .training .common_utils import shard
31- from jax import pmap
3231
3332 from diffusers import FlaxDDIMScheduler , FlaxDiffusionPipeline , FlaxStableDiffusionPipeline
3433
@@ -70,14 +69,12 @@ def test_dummy_all_tpus(self):
7069 prompt = num_samples * [prompt ]
7170 prompt_ids = pipeline .prepare_inputs (prompt )
7271
73- p_sample = pmap (pipeline .__call__ , static_broadcasted_argnums = (3 ,))
74-
7572 # shard inputs and rng
7673 params = replicate (params )
7774 prng_seed = jax .random .split (prng_seed , num_samples )
7875 prompt_ids = shard (prompt_ids )
7976
80- images = p_sample (prompt_ids , params , prng_seed , num_inference_steps ).images
77+ images = pipeline (prompt_ids , params , prng_seed , num_inference_steps , jit = True ).images
8178
8279 assert images .shape == (num_samples , 1 , 64 , 64 , 3 )
8380 if jax .device_count () == 8 :
@@ -105,14 +102,12 @@ def test_stable_diffusion_v1_4(self):
105102 prompt = num_samples * [prompt ]
106103 prompt_ids = pipeline .prepare_inputs (prompt )
107104
108- p_sample = pmap (pipeline .__call__ , static_broadcasted_argnums = (3 ,))
109-
110105 # shard inputs and rng
111106 params = replicate (params )
112107 prng_seed = jax .random .split (prng_seed , num_samples )
113108 prompt_ids = shard (prompt_ids )
114109
115- images = p_sample (prompt_ids , params , prng_seed , num_inference_steps ).images
110+ images = pipeline (prompt_ids , params , prng_seed , num_inference_steps , jit = True ).images
116111
117112 assert images .shape == (num_samples , 1 , 512 , 512 , 3 )
118113 if jax .device_count () == 8 :
@@ -136,14 +131,12 @@ def test_stable_diffusion_v1_4_bfloat_16(self):
136131 prompt = num_samples * [prompt ]
137132 prompt_ids = pipeline .prepare_inputs (prompt )
138133
139- p_sample = pmap (pipeline .__call__ , static_broadcasted_argnums = (3 ,))
140-
141134 # shard inputs and rng
142135 params = replicate (params )
143136 prng_seed = jax .random .split (prng_seed , num_samples )
144137 prompt_ids = shard (prompt_ids )
145138
146- images = p_sample (prompt_ids , params , prng_seed , num_inference_steps ).images
139+ images = pipeline (prompt_ids , params , prng_seed , num_inference_steps , jit = True ).images
147140
148141 assert images .shape == (num_samples , 1 , 512 , 512 , 3 )
149142 if jax .device_count () == 8 :
@@ -211,14 +204,12 @@ def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
211204 prompt = num_samples * [prompt ]
212205 prompt_ids = pipeline .prepare_inputs (prompt )
213206
214- p_sample = pmap (pipeline .__call__ , static_broadcasted_argnums = (3 ,))
215-
216207 # shard inputs and rng
217208 params = replicate (params )
218209 prng_seed = jax .random .split (prng_seed , num_samples )
219210 prompt_ids = shard (prompt_ids )
220211
221- images = p_sample (prompt_ids , params , prng_seed , num_inference_steps ).images
212+ images = pipeline (prompt_ids , params , prng_seed , num_inference_steps , jit = True ).images
222213
223214 assert images .shape == (num_samples , 1 , 512 , 512 , 3 )
224215 if jax .device_count () == 8 :
0 commit comments