@@ -73,18 +73,19 @@ def test_dummy_all_tpus(self):
7373
7474 # shard inputs and rng
7575 params = replicate (params )
76- prng_seed = jax .random .split (prng_seed , 8 )
76+ prng_seed = jax .random .split (prng_seed , num_samples )
7777 prompt_ids = shard (prompt_ids )
7878
7979 images = p_sample (prompt_ids , params , prng_seed , num_inference_steps ).images
8080
81- assert images .shape == (8 , 1 , 128 , 128 , 3 )
82- assert np .abs (np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 3.1111548 ) < 1e-3
83- assert np .abs (np .abs (images , dtype = np .float32 ).sum () - 199746.95 ) < 5e-1
81+ assert images .shape == (num_samples , 1 , 128 , 128 , 3 )
82+ if jax .device_count () == 8 :
83+ assert np .abs (np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 3.1111548 ) < 1e-3
84+ assert np .abs (np .abs (images , dtype = np .float32 ).sum () - 199746.95 ) < 5e-1
8485
8586 images_pil = pipeline .numpy_to_pil (np .asarray (images .reshape ((num_samples ,) + images .shape [- 3 :])))
8687
87- assert len (images_pil ) == 8
88+ assert len (images_pil ) == num_samples
8889
8990 def test_stable_diffusion_v1_4 (self ):
9091 pipeline , params = FlaxStableDiffusionPipeline .from_pretrained (
@@ -107,14 +108,15 @@ def test_stable_diffusion_v1_4(self):
107108
108109 # shard inputs and rng
109110 params = replicate (params )
110- prng_seed = jax .random .split (prng_seed , 8 )
111+ prng_seed = jax .random .split (prng_seed , num_samples )
111112 prompt_ids = shard (prompt_ids )
112113
113114 images = p_sample (prompt_ids , params , prng_seed , num_inference_steps ).images
114115
115- assert images .shape == (8 , 1 , 512 , 512 , 3 )
116- assert np .abs ((np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 0.05652401 )) < 1e-3
117- assert np .abs ((np .abs (images , dtype = np .float32 ).sum () - 2383808.2 )) < 5e-1
116+ assert images .shape == (num_samples , 1 , 512 , 512 , 3 )
117+ if jax .device_count () == 8 :
118+ assert np .abs ((np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 0.05652401 )) < 1e-3
119+ assert np .abs ((np .abs (images , dtype = np .float32 ).sum () - 2383808.2 )) < 5e-1
118120
119121 def test_stable_diffusion_v1_4_bfloat_16 (self ):
120122 pipeline , params = FlaxStableDiffusionPipeline .from_pretrained (
@@ -137,14 +139,15 @@ def test_stable_diffusion_v1_4_bfloat_16(self):
137139
138140 # shard inputs and rng
139141 params = replicate (params )
140- prng_seed = jax .random .split (prng_seed , 8 )
142+ prng_seed = jax .random .split (prng_seed , num_samples )
141143 prompt_ids = shard (prompt_ids )
142144
143145 images = p_sample (prompt_ids , params , prng_seed , num_inference_steps ).images
144146
145- assert images .shape == (8 , 1 , 512 , 512 , 3 )
146- assert np .abs ((np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 0.06652832 )) < 1e-3
147- assert np .abs ((np .abs (images , dtype = np .float32 ).sum () - 2384849.8 )) < 5e-1
147+ assert images .shape == (num_samples , 1 , 512 , 512 , 3 )
148+ if jax .device_count () == 8 :
149+ assert np .abs ((np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 0.06652832 )) < 1e-3
150+ assert np .abs ((np .abs (images , dtype = np .float32 ).sum () - 2384849.8 )) < 5e-1
148151
149152 def test_stable_diffusion_v1_4_bfloat_16_with_safety (self ):
150153 pipeline , params = FlaxStableDiffusionPipeline .from_pretrained (
@@ -165,14 +168,15 @@ def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
165168
166169 # shard inputs and rng
167170 params = replicate (params )
168- prng_seed = jax .random .split (prng_seed , 8 )
171+ prng_seed = jax .random .split (prng_seed , num_samples )
169172 prompt_ids = shard (prompt_ids )
170173
171174 images = pipeline (prompt_ids , params , prng_seed , num_inference_steps , jit = True ).images
172175
173- assert images .shape == (8 , 1 , 512 , 512 , 3 )
174- assert np .abs ((np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 0.06652832 )) < 1e-3
175- assert np .abs ((np .abs (images , dtype = np .float32 ).sum () - 2384849.8 )) < 5e-1
176+ assert images .shape == (num_samples , 1 , 512 , 512 , 3 )
177+ if jax .device_count () == 8 :
178+ assert np .abs ((np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 0.06652832 )) < 1e-3
179+ assert np .abs ((np .abs (images , dtype = np .float32 ).sum () - 2384849.8 )) < 5e-1
176180
177181 def test_stable_diffusion_v1_4_bfloat_16_ddim (self ):
178182 scheduler = FlaxDDIMScheduler (
@@ -210,11 +214,12 @@ def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
210214
211215 # shard inputs and rng
212216 params = replicate (params )
213- prng_seed = jax .random .split (prng_seed , 8 )
217+ prng_seed = jax .random .split (prng_seed , num_samples )
214218 prompt_ids = shard (prompt_ids )
215219
216220 images = p_sample (prompt_ids , params , prng_seed , num_inference_steps ).images
217221
218- assert images .shape == (8 , 1 , 512 , 512 , 3 )
219- assert np .abs ((np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 0.045043945 )) < 1e-3
220- assert np .abs ((np .abs (images , dtype = np .float32 ).sum () - 2347693.5 )) < 5e-1
222+ assert images .shape == (num_samples , 1 , 512 , 512 , 3 )
223+ if jax .device_count () == 8 :
224+ assert np .abs ((np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 0.045043945 )) < 1e-3
225+ assert np .abs ((np .abs (images , dtype = np .float32 ).sum () - 2347693.5 )) < 5e-1
0 commit comments