Skip to content

Commit a062e47

Browse files
yiyixuxuyiyixuxupatrickvonplaten
authored
add flax pipelines to api doc + doc string examples (huggingface#2600)
* add api doc for flax pipeline + doc string examples * make style --------- Co-authored-by: yiyixuxu <[email protected]> Co-authored-by: Patrick von Platen <[email protected]>
1 parent 75f1210 commit a062e47

File tree

6 files changed

+174
-5
lines changed

6 files changed

+174
-5
lines changed

docs/source/en/api/pipelines/stable_diffusion/img2img.mdx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,8 @@ proposed by Chenlin Meng, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan
2929
- enable_attention_slicing
3030
- disable_attention_slicing
3131
- enable_xformers_memory_efficient_attention
32-
- disable_xformers_memory_efficient_attention
32+
- disable_xformers_memory_efficient_attention
33+
34+
[[autodoc]] FlaxStableDiffusionImg2ImgPipeline
35+
- all
36+
- __call__

docs/source/en/api/pipelines/stable_diffusion/inpaint.mdx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,8 @@ Available checkpoints are:
3030
- enable_attention_slicing
3131
- disable_attention_slicing
3232
- enable_xformers_memory_efficient_attention
33-
- disable_xformers_memory_efficient_attention
33+
- disable_xformers_memory_efficient_attention
34+
35+
[[autodoc]] FlaxStableDiffusionInpaintPipeline
36+
- all
37+
- __call__

docs/source/en/api/pipelines/stable_diffusion/text2img.mdx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,7 @@ Available Checkpoints are:
3939
- disable_xformers_memory_efficient_attention
4040
- enable_vae_tiling
4141
- disable_vae_tiling
42+
43+
[[autodoc]] FlaxStableDiffusionPipeline
44+
- all
45+
- __call__

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from flax.training.common_utils import shard
2525
from packaging import version
2626
from PIL import Image
27+
2728
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
2829

2930
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
@@ -33,7 +34,7 @@
3334
FlaxLMSDiscreteScheduler,
3435
FlaxPNDMScheduler,
3536
)
36-
from ...utils import deprecate, logging
37+
from ...utils import deprecate, logging, replace_example_docstring
3738
from ..pipeline_flax_utils import FlaxDiffusionPipeline
3839
from . import FlaxStableDiffusionPipelineOutput
3940
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
@@ -44,6 +45,39 @@
4445
# Set to True to use python for loop instead of jax.fori_loop for easier debugging
4546
DEBUG = False
4647

48+
EXAMPLE_DOC_STRING = """
49+
Examples:
50+
```py
51+
>>> import jax
52+
>>> import numpy as np
53+
>>> from flax.jax_utils import replicate
54+
>>> from flax.training.common_utils import shard
55+
56+
>>> from diffusers import FlaxStableDiffusionPipeline
57+
58+
>>> pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
59+
... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jax.numpy.bfloat16
60+
... )
61+
62+
>>> prompt = "a photo of an astronaut riding a horse on mars"
63+
64+
>>> prng_seed = jax.random.PRNGKey(0)
65+
>>> num_inference_steps = 50
66+
67+
>>> num_samples = jax.device_count()
68+
>>> prompt = num_samples * [prompt]
69+
>>> prompt_ids = pipeline.prepare_inputs(prompt)
70+
# shard inputs and rng
71+
72+
>>> params = replicate(params)
73+
>>> prng_seed = jax.random.split(prng_seed, jax.device_count())
74+
>>> prompt_ids = shard(prompt_ids)
75+
76+
>>> images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
77+
>>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
78+
```
79+
"""
80+
4781

4882
class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
4983
r"""
@@ -272,6 +306,7 @@ def loop_body(step, args):
272306
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
273307
return image
274308

309+
@replace_example_docstring(EXAMPLE_DOC_STRING)
275310
def __call__(
276311
self,
277312
prompt_ids: jnp.array,
@@ -316,6 +351,8 @@ def __call__(
316351
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
317352
a plain tuple.
318353
354+
Examples:
355+
319356
Returns:
320357
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
321358
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from flax.jax_utils import unreplicate
2424
from flax.training.common_utils import shard
2525
from PIL import Image
26+
2627
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
2728

2829
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
@@ -32,7 +33,7 @@
3233
FlaxLMSDiscreteScheduler,
3334
FlaxPNDMScheduler,
3435
)
35-
from ...utils import PIL_INTERPOLATION, logging
36+
from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring
3637
from ..pipeline_flax_utils import FlaxDiffusionPipeline
3738
from . import FlaxStableDiffusionPipelineOutput
3839
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
@@ -43,6 +44,64 @@
4344
# Set to True to use python for loop instead of jax.fori_loop for easier debugging
4445
DEBUG = False
4546

47+
EXAMPLE_DOC_STRING = """
48+
Examples:
49+
```py
50+
>>> import jax
51+
>>> import numpy as np
52+
>>> import jax.numpy as jnp
53+
>>> from flax.jax_utils import replicate
54+
>>> from flax.training.common_utils import shard
55+
>>> import requests
56+
>>> from io import BytesIO
57+
>>> from PIL import Image
58+
>>> from diffusers import FlaxStableDiffusionImg2ImgPipeline
59+
60+
61+
>>> def create_key(seed=0):
62+
... return jax.random.PRNGKey(seed)
63+
64+
65+
>>> rng = create_key(0)
66+
67+
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
68+
>>> response = requests.get(url)
69+
>>> init_img = Image.open(BytesIO(response.content)).convert("RGB")
70+
>>> init_img = init_img.resize((768, 512))
71+
72+
>>> prompts = "A fantasy landscape, trending on artstation"
73+
74+
>>> pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
75+
... "CompVis/stable-diffusion-v1-4",
76+
... revision="flax",
77+
... dtype=jnp.bfloat16,
78+
... )
79+
80+
>>> num_samples = jax.device_count()
81+
>>> rng = jax.random.split(rng, jax.device_count())
82+
>>> prompt_ids, processed_image = pipeline.prepare_inputs(
83+
... prompt=[prompts] * num_samples, image=[init_img] * num_samples
84+
... )
85+
>>> p_params = replicate(params)
86+
>>> prompt_ids = shard(prompt_ids)
87+
>>> processed_image = shard(processed_image)
88+
89+
>>> output = pipeline(
90+
... prompt_ids=prompt_ids,
91+
... image=processed_image,
92+
... params=p_params,
93+
... prng_seed=rng,
94+
... strength=0.75,
95+
... num_inference_steps=50,
96+
... jit=True,
97+
... height=512,
98+
... width=768,
99+
... ).images
100+
101+
>>> output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
102+
```
103+
"""
104+
46105

47106
class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
48107
r"""
@@ -277,6 +336,7 @@ def loop_body(step, args):
277336
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
278337
return image
279338

339+
@replace_example_docstring(EXAMPLE_DOC_STRING)
280340
def __call__(
281341
self,
282342
prompt_ids: jnp.array,
@@ -332,6 +392,8 @@ def __call__(
332392
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
333393
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
334394
395+
Examples:
396+
335397
Returns:
336398
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
337399
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from flax.training.common_utils import shard
2525
from packaging import version
2626
from PIL import Image
27+
2728
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
2829

2930
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
@@ -33,7 +34,7 @@
3334
FlaxLMSDiscreteScheduler,
3435
FlaxPNDMScheduler,
3536
)
36-
from ...utils import PIL_INTERPOLATION, deprecate, logging
37+
from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring
3738
from ..pipeline_flax_utils import FlaxDiffusionPipeline
3839
from . import FlaxStableDiffusionPipelineOutput
3940
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
@@ -44,6 +45,60 @@
4445
# Set to True to use python for loop instead of jax.fori_loop for easier debugging
4546
DEBUG = False
4647

48+
EXAMPLE_DOC_STRING = """
49+
Examples:
50+
```py
51+
>>> import jax
52+
>>> import numpy as np
53+
>>> from flax.jax_utils import replicate
54+
>>> from flax.training.common_utils import shard
55+
>>> import PIL
56+
>>> import requests
57+
>>> from io import BytesIO
58+
>>> from diffusers import FlaxStableDiffusionInpaintPipeline
59+
60+
61+
>>> def download_image(url):
62+
... response = requests.get(url)
63+
... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
64+
65+
66+
>>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
67+
>>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
68+
69+
>>> init_image = download_image(img_url).resize((512, 512))
70+
>>> mask_image = download_image(mask_url).resize((512, 512))
71+
72+
>>> pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained(
73+
... "xvjiarui/stable-diffusion-2-inpainting"
74+
... )
75+
76+
>>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
77+
>>> prng_seed = jax.random.PRNGKey(0)
78+
>>> num_inference_steps = 50
79+
80+
>>> num_samples = jax.device_count()
81+
>>> prompt = num_samples * [prompt]
82+
>>> init_image = num_samples * [init_image]
83+
>>> mask_image = num_samples * [mask_image]
84+
>>> prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(
85+
... prompt, init_image, mask_image
86+
... )
87+
# shard inputs and rng
88+
89+
>>> params = replicate(params)
90+
>>> prng_seed = jax.random.split(prng_seed, jax.device_count())
91+
>>> prompt_ids = shard(prompt_ids)
92+
>>> processed_masked_images = shard(processed_masked_images)
93+
>>> processed_masks = shard(processed_masks)
94+
95+
>>> images = pipeline(
96+
... prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True
97+
... ).images
98+
>>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
99+
```
100+
"""
101+
47102

48103
class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
49104
r"""
@@ -332,6 +387,7 @@ def loop_body(step, args):
332387
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
333388
return image
334389

390+
@replace_example_docstring(EXAMPLE_DOC_STRING)
335391
def __call__(
336392
self,
337393
prompt_ids: jnp.array,
@@ -378,6 +434,8 @@ def __call__(
378434
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
379435
a plain tuple.
380436
437+
Examples:
438+
381439
Returns:
382440
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
383441
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a

0 commit comments

Comments
 (0)