
Commit 2bbd532

pcuenca and patil-suraj authored

Docs: short section on changing the scheduler in Flax (huggingface#2181)

* Short doc on changing the scheduler in Flax.
* Apply fix from @patil-suraj

Co-authored-by: Suraj Patil <[email protected]>

1 parent 68ef066 commit 2bbd532

File tree: 1 file changed (+52, -0)

docs/source/en/using-diffusers/schedulers.mdx
@@ -176,6 +176,7 @@ image
  <br>
</p>

If you are a JAX/Flax user, please check [this section](#changing-the-scheduler-in-flax) instead.

## Compare schedulers

@@ -260,3 +261,54 @@ image

As you can see, most images look very similar and are arguably of very similar quality. Which scheduler to choose often depends on the specific use case, so a good approach is to run several different schedulers and compare the results.
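
Not part of the committed doc, but as a minimal sketch of that comparison loop for the PyTorch pipeline used earlier in this guide (the checkpoint, seed, filenames, and scheduler choices here are illustrative assumptions):

```python
import torch
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultistepScheduler

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"

# Swap schedulers in place and reseed before each run so the outputs are comparable
for scheduler_cls in (EulerDiscreteScheduler, DPMSolverMultistepScheduler):
    pipeline.scheduler = scheduler_cls.from_config(pipeline.scheduler.config)
    generator = torch.Generator(device="cuda").manual_seed(8)
    image = pipeline(prompt, generator=generator).images[0]
    image.save(f"astronaut_{scheduler_cls.__name__}.png")
```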

## Changing the Scheduler in Flax

If you are a JAX/Flax user, you can also change the default pipeline scheduler. This is a complete example of how to run inference using the Flax Stable Diffusion pipeline and the super-fast [DPM-Solver++ scheduler](../api/schedulers/multistep_dpm_solver):

```python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler

model_id = "runwayml/stable-diffusion-v1-5"

# Flax schedulers are stateless: `from_pretrained` returns the scheduler and its state separately
scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
    model_id,
    subfolder="scheduler"
)
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    model_id,
    scheduler=scheduler,
    revision="bf16",
    dtype=jax.numpy.bfloat16,
)
# Make the pipeline params use the new scheduler's state
params["scheduler"] = scheduler_state

# Generate 1 image per parallel device (8 on TPUv2-8 or TPUv3-8)
prompt = "a photo of an astronaut riding a horse on mars"
num_samples = jax.device_count()
prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)

prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 25

# shard inputs and rng across devices
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```
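
For reference (not part of the committed doc), `numpy_to_pil` returns a list of standard PIL images, so they can be saved or displayed directly. A minimal sketch, assuming the `images` list and filenames above:

```python
# Save each generated image to disk; `images` comes from the example above
for i, image in enumerate(images):
    image.save(f"astronaut_dpm_solver_{i}.png")
```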

<Tip warning={true}>

The following Flax schedulers are _not yet compatible_ with the Flax Stable Diffusion Pipeline:

- `FlaxLMSDiscreteScheduler`
- `FlaxDDPMScheduler`

</Tip>
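
Other compatible Flax schedulers, such as `FlaxDDIMScheduler`, can be loaded the same way. This is a minimal sketch (not part of the committed doc) that reuses the pattern from the example above; the checkpoint and variable names are illustrative:

```python
import jax
from diffusers import FlaxStableDiffusionPipeline, FlaxDDIMScheduler

model_id = "runwayml/stable-diffusion-v1-5"

# Load a different compatible scheduler together with its state
ddim_scheduler, ddim_state = FlaxDDIMScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
)

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    model_id,
    scheduler=ddim_scheduler,
    revision="bf16",
    dtype=jax.numpy.bfloat16,
)
# Point the pipeline params at the DDIM scheduler state, as in the example above
params["scheduler"] = ddim_state
```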
