@@ -284,6 +284,53 @@ output = pipeline(
284284output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[- 3 :])))
285285```
286286
287+ Diffusers also has a Text-guided inpainting pipeline with Flax/Jax
288+
289+ ``` python
290+ import jax
291+ import numpy as np
292+ from flax.jax_utils import replicate
293+ from flax.training.common_utils import shard
294+ import PIL
295+ import requests
296+ from io import BytesIO
297+
298+
299+ from diffusers import FlaxStableDiffusionInpaintPipeline
300+
301+ def download_image (url ):
302+ response = requests.get(url)
303+ return PIL .Image.open(BytesIO(response.content)).convert(" RGB" )
304+ img_url = " https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
305+ mask_url = " https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
306+
307+ init_image = download_image(img_url).resize((512 , 512 ))
308+ mask_image = download_image(mask_url).resize((512 , 512 ))
309+
310+ pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained(" xvjiarui/stable-diffusion-2-inpainting" )
311+
312+ prompt = " Face of a yellow cat, high resolution, sitting on a park bench"
313+ prng_seed = jax.random.PRNGKey(0 )
314+ num_inference_steps = 50
315+
316+ num_samples = jax.device_count()
317+ prompt = num_samples * [prompt]
318+ init_image = num_samples * [init_image]
319+ mask_image = num_samples * [mask_image]
320+ prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(prompt, init_image, mask_image)
321+
322+
323+ # shard inputs and rng
324+ params = replicate(params)
325+ prng_seed = jax.random.split(prng_seed, jax.device_count())
326+ prompt_ids = shard(prompt_ids)
327+ processed_masked_images = shard(processed_masked_images)
328+ processed_masks = shard(processed_masks)
329+
330+ images = pipeline(prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit = True ).images
331+ images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[- 3 :])))
332+ ```
333+
287334### Image-to-Image text-guided generation with Stable Diffusion
288335
289336The ` StableDiffusionImg2ImgPipeline ` lets you pass a text prompt and an initial image to condition the generation of new images.
0 commit comments