Skip to content

Commit a9190ba

Browse files
dhruvrnaikpcuenca
andauthored
Add Flax stable diffusion img2img pipeline (huggingface#1355)
* add flax img2img pipeline * update pipeline * black format file * remove argg from get_timesteps * update get_timesteps * fix bug: make use of timesteps for for_loop * black file * black, isort, flake8 * update docstring * update readme * update flax img2img readme * update sd pipeline init * Update src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py Co-authored-by: Pedro Cuenca <[email protected]> * Update src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py Co-authored-by: Pedro Cuenca <[email protected]> * update inits * revert change * update var name to image, typo * update readme * return new t_start instead of modified timestep * black format * isort files * update docs * fix-copies * update prng_seed typing Co-authored-by: Pedro Cuenca <[email protected]>
1 parent d07f730 commit a9190ba

File tree

6 files changed

+481
-2
lines changed

6 files changed

+481
-2
lines changed

README.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,55 @@ images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).
235235
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
236236
```
237237

238+
Diffusers also has a Image-to-Image generation pipeline with Flax/Jax
239+
```python
240+
import jax
241+
import numpy as np
242+
import jax.numpy as jnp
243+
from flax.jax_utils import replicate
244+
from flax.training.common_utils import shard
245+
import requests
246+
from io import BytesIO
247+
from PIL import Image
248+
from diffusers import FlaxStableDiffusionImg2ImgPipeline
249+
250+
def create_key(seed=0):
251+
return jax.random.PRNGKey(seed)
252+
rng = create_key(0)
253+
254+
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
255+
response = requests.get(url)
256+
init_img = Image.open(BytesIO(response.content)).convert("RGB")
257+
init_img = init_img.resize((768, 512))
258+
259+
prompts = "A fantasy landscape, trending on artstation"
260+
261+
pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
262+
"CompVis/stable-diffusion-v1-4", revision="flax",
263+
dtype=jnp.bfloat16,
264+
)
265+
266+
num_samples = jax.device_count()
267+
rng = jax.random.split(rng, jax.device_count())
268+
prompt_ids, processed_image = pipeline.prepare_inputs(prompt=[prompts]*num_samples, image = [init_img]*num_samples)
269+
p_params = replicate(params)
270+
prompt_ids = shard(prompt_ids)
271+
processed_image = shard(processed_image)
272+
273+
output = pipeline(
274+
prompt_ids=prompt_ids,
275+
image=processed_image,
276+
params=p_params,
277+
prng_seed=rng,
278+
strength=0.75,
279+
num_inference_steps=50,
280+
jit=True,
281+
height=512,
282+
width=768).images
283+
284+
output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
285+
```
286+
238287
### Image-to-Image text-guided generation with Stable Diffusion
239288

240289
The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.

src/diffusers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,10 +164,11 @@
164164
FlaxScoreSdeVeScheduler,
165165
)
166166

167+
167168
try:
168169
if not (is_flax_available() and is_transformers_available()):
169170
raise OptionalDependencyNotAvailable()
170171
except OptionalDependencyNotAvailable:
171172
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
172173
else:
173-
from .pipelines import FlaxStableDiffusionPipeline
174+
from .pipelines import FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionPipeline

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,4 +91,4 @@
9191
except OptionalDependencyNotAvailable:
9292
from ..utils.dummy_flax_and_transformers_objects import * # noqa F403
9393
else:
94-
from .stable_diffusion import FlaxStableDiffusionPipeline
94+
from .stable_diffusion import FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionPipeline

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,5 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput):
9898

9999
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
100100
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
101+
from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline
101102
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker

0 commit comments

Comments
 (0)