Skip to content

Commit af27943

Browse files
authored
Flax tests: don't hardcode number of devices (huggingface#1175)
Flax tests: don't hardcode number of devices. This makes it possible to test on CPU/GPU. However, expected slices are only checked when there are 8 devices.
1 parent 4969f46 commit af27943

File tree

1 file changed

+26
-21
lines changed

1 file changed

+26
-21
lines changed

tests/test_pipelines_flax.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -73,18 +73,19 @@ def test_dummy_all_tpus(self):
7373

7474
# shard inputs and rng
7575
params = replicate(params)
76-
prng_seed = jax.random.split(prng_seed, 8)
76+
prng_seed = jax.random.split(prng_seed, num_samples)
7777
prompt_ids = shard(prompt_ids)
7878

7979
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
8080

81-
assert images.shape == (8, 1, 128, 128, 3)
82-
assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 3.1111548) < 1e-3
83-
assert np.abs(np.abs(images, dtype=np.float32).sum() - 199746.95) < 5e-1
81+
assert images.shape == (num_samples, 1, 128, 128, 3)
82+
if jax.device_count() == 8:
83+
assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 3.1111548) < 1e-3
84+
assert np.abs(np.abs(images, dtype=np.float32).sum() - 199746.95) < 5e-1
8485

8586
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
8687

87-
assert len(images_pil) == 8
88+
assert len(images_pil) == num_samples
8889

8990
def test_stable_diffusion_v1_4(self):
9091
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
@@ -107,14 +108,15 @@ def test_stable_diffusion_v1_4(self):
107108

108109
# shard inputs and rng
109110
params = replicate(params)
110-
prng_seed = jax.random.split(prng_seed, 8)
111+
prng_seed = jax.random.split(prng_seed, num_samples)
111112
prompt_ids = shard(prompt_ids)
112113

113114
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
114115

115-
assert images.shape == (8, 1, 512, 512, 3)
116-
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-3
117-
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 5e-1
116+
assert images.shape == (num_samples, 1, 512, 512, 3)
117+
if jax.device_count() == 8:
118+
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-3
119+
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 5e-1
118120

119121
def test_stable_diffusion_v1_4_bfloat_16(self):
120122
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
@@ -137,14 +139,15 @@ def test_stable_diffusion_v1_4_bfloat_16(self):
137139

138140
# shard inputs and rng
139141
params = replicate(params)
140-
prng_seed = jax.random.split(prng_seed, 8)
142+
prng_seed = jax.random.split(prng_seed, num_samples)
141143
prompt_ids = shard(prompt_ids)
142144

143145
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
144146

145-
assert images.shape == (8, 1, 512, 512, 3)
146-
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
147-
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1
147+
assert images.shape == (num_samples, 1, 512, 512, 3)
148+
if jax.device_count() == 8:
149+
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
150+
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1
148151

149152
def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
150153
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
@@ -165,14 +168,15 @@ def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
165168

166169
# shard inputs and rng
167170
params = replicate(params)
168-
prng_seed = jax.random.split(prng_seed, 8)
171+
prng_seed = jax.random.split(prng_seed, num_samples)
169172
prompt_ids = shard(prompt_ids)
170173

171174
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
172175

173-
assert images.shape == (8, 1, 512, 512, 3)
174-
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
175-
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1
176+
assert images.shape == (num_samples, 1, 512, 512, 3)
177+
if jax.device_count() == 8:
178+
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
179+
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1
176180

177181
def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
178182
scheduler = FlaxDDIMScheduler(
@@ -210,11 +214,12 @@ def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
210214

211215
# shard inputs and rng
212216
params = replicate(params)
213-
prng_seed = jax.random.split(prng_seed, 8)
217+
prng_seed = jax.random.split(prng_seed, num_samples)
214218
prompt_ids = shard(prompt_ids)
215219

216220
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
217221

218-
assert images.shape == (8, 1, 512, 512, 3)
219-
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 1e-3
220-
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1
222+
assert images.shape == (num_samples, 1, 512, 512, 3)
223+
if jax.device_count() == 8:
224+
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 1e-3
225+
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1

0 commit comments

Comments
 (0)