Skip to content

Commit 045157a

Browse files
authored
Fix Flax usage comments (huggingface#1211)
* Fix Flax usage comments (they didn't work). * Spell out dtype * make style
1 parent a09d475 commit 045157a

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

src/diffusers/pipeline_flax_utils.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -268,18 +268,27 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
268268
>>> from diffusers import FlaxDiffusionPipeline
269269
270270
>>> # Download pipeline from huggingface.co and cache.
271-
>>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
272-
273-
>>> # Download pipeline that requires an authorization token
274-
>>> # For more information on access tokens, please refer to this section
275-
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
276-
>>> pipeline = FlaxDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
277-
278-
>>> # Download pipeline, but overwrite scheduler
279-
>>> from diffusers import LMSDiscreteScheduler
280-
281-
>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
282-
>>> pipeline = FlaxDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
271+
>>> # Requires to be logged in to Hugging Face hub,
272+
>>> # see more in [the documentation](https://huggingface.co/docs/hub/security-tokens)
273+
>>> pipeline, params = FlaxDiffusionPipeline.from_pretrained(
274+
... "runwayml/stable-diffusion-v1-5",
275+
... revision="bf16",
276+
... dtype=jnp.bfloat16,
277+
... )
278+
279+
>>> # Download pipeline, but use a different scheduler
280+
>>> from diffusers import FlaxDPMSolverMultistepScheduler
281+
282+
>>> model_id = "runwayml/stable-diffusion-v1-5"
283+
>>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_config(
284+
... model_id,
285+
... subfolder="scheduler",
286+
... )
287+
288+
>>> dpm_pipe, dpm_params = FlaxStableDiffusionPipeline.from_pretrained(
289+
... model_id, revision="bf16", dtype=jnp.bfloat16, scheduler=dpmpp
290+
... )
291+
>>> dpm_params["scheduler"] = dpmpp_state
283292
```
284293
"""
285294
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)

0 commit comments

Comments
 (0)