
Commit 7462156

Authored by yiyixuxu, cene555, patrickvonplaten, and pcuenca
Kandinsky_v22_yiyi (huggingface#3936)
* Kandinsky2_2
* fix init kandinsky2_2
* kandinsky2_2 fix inpainting
* rename pipelines: remove decoder + 2_2 -> V22
* Update scheduling_unclip.py
* remove text_encoder and tokenizer arguments from doc string
* add test for text2img
* add tests for text2img & img2img
* fix
* add test for inpaint
* add prior tests
* style
* copies
* add controlnet test
* style
* add a test for controlnet_img2img
* update prior_emb2emb api to accept image_embedding or image
* add a test for prior_emb2emb
* style
* remove try except
* example
* fix
* add doc string examples to all kandinsky pipelines
* style
* update doc
* style
* add a top about 2.2
* Apply suggestions from code review

  Co-authored-by: Patrick von Platen <[email protected]>

* vae -> movq
* vae -> movq
* style
* fix the #copied from
* remove decoder from file name
* update doc: add a section for kandinsky 2.2
* fix
* fix-copies
* add coped from
* add copies from for prior
* add copies from for prior emb2emb
* copy from for img2img
* copied from for inpaint
* more copied from
* more copies from
* more copies
* remove the yiyi comments
* Apply suggestions from code review
* Self-contained example, pipeline order
* Import prior output instead of redefining.
* Style
* Make VQModel compatible with model offload.
* Fix copies

---------

Co-authored-by: Shahmatov Arseniy <[email protected]>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
1 parent bc9a8ce commit 7462156

29 files changed: +5646 −25 lines changed

docs/source/en/api/pipelines/kandinsky.mdx

Lines changed: 262 additions & 13 deletions
@@ -11,19 +11,12 @@ specific language governing permissions and limitations under the License.
## Overview

-Kandinsky 2.1 inherits best practices from [DALL-E 2](https://arxiv.org/abs/2204.06125) and [Latent Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/latent_diffusion), while introducing some new ideas.
+Kandinsky inherits best practices from [DALL-E 2](https://huggingface.co/papers/2204.06125) and [Latent Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/latent_diffusion), while introducing some new ideas.

It uses [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) for encoding images and text, and a diffusion image prior (mapping) between latent spaces of CLIP modalities. This approach enhances the visual performance of the model and unveils new horizons in blending images and text-guided image manipulation.

-The Kandinsky model is created by [Arseniy Shakhmatov](https://github.com/cene555), [Anton Razzhigaev](https://github.com/razzant), [Aleksandr Nikolich](https://github.com/AlexWortega), [Igor Pavlov](https://github.com/boomb0om), [Andrey Kuznetsov](https://github.com/kuznetsoffandrey) and [Denis Dimitrov](https://github.com/denndimitrov) and the original codebase can be found [here](https://github.com/ai-forever/Kandinsky-2)
+The Kandinsky model is created by [Arseniy Shakhmatov](https://github.com/cene555), [Anton Razzhigaev](https://github.com/razzant), [Aleksandr Nikolich](https://github.com/AlexWortega), [Igor Pavlov](https://github.com/boomb0om), [Andrey Kuznetsov](https://github.com/kuznetsoffandrey) and [Denis Dimitrov](https://github.com/denndimitrov). The original codebase can be found [here](https://github.com/ai-forever/Kandinsky-2).

-## Available Pipelines:
-
-| Pipeline | Tasks |
-|---|---|
-| [pipeline_kandinsky.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py) | *Text-to-Image Generation* |
-| [pipeline_kandinsky_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py) | *Image-Guided Image Generation* |
-| [pipeline_kandinsky_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py) | *Image-Guided Image Generation* |

## Usage example

@@ -135,6 +128,7 @@ prompt = "birds eye view of a quilted paper style alien planet landscape, vibran
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/alienplanet.png)

### Text Guided Image-to-Image Generation

The same Kandinsky model weights can be used for text-guided image-to-image translation. In this case, just make sure to load the weights using the [`KandinskyImg2ImgPipeline`] pipeline.
@@ -283,6 +277,207 @@ image.save("starry_cat.png")
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/starry_cat.png)

### Text-to-Image Generation with ControlNet Conditioning

In the following, we give a simple example of how to use [`KandinskyV22ControlnetPipeline`] to add control to the text-to-image generation with a depth image.

First, let's take an image and extract its depth map.

```python
from diffusers.utils import load_image

img = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/cat.png"
).resize((768, 768))
```
![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/cat.png)

We can use the `depth-estimation` pipeline from transformers to process the image and retrieve its depth map.

```python
import torch
import numpy as np

from transformers import pipeline
from diffusers.utils import load_image


def make_hint(image, depth_estimator):
    image = depth_estimator(image)["depth"]
    image = np.array(image)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    detected_map = torch.from_numpy(image).float() / 255.0
    hint = detected_map.permute(2, 0, 1)
    return hint


depth_estimator = pipeline("depth-estimation")
hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda")
```

Now, we load the prior pipeline and the text-to-image controlnet pipeline.

```python
from diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline

pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
)
pipe_prior = pipe_prior.to("cuda")

pipe = KandinskyV22ControlnetPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
```

We pass the prompt and negative prompt through the prior to generate image embeddings.

```python
prompt = "A robot, 4k photo"

negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"

generator = torch.Generator(device="cuda").manual_seed(43)
image_emb, zero_image_emb = pipe_prior(
    prompt=prompt, negative_prompt=negative_prior_prompt, generator=generator
).to_tuple()
```

Now we can pass the image embeddings and the depth image we extracted to the controlnet pipeline. With Kandinsky 2.2, only prior pipelines accept `prompt` input. You do not need to pass the prompt to the controlnet pipeline.

```python
images = pipe(
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    hint=hint,
    num_inference_steps=50,
    generator=generator,
    height=768,
    width=768,
).images

images[0].save("robot_cat.png")
```

The output image looks as follows:

![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/robot_cat_text2img.png)

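Both pipelines above are moved fully onto the GPU with `.to("cuda")`. Since this commit also makes `VQModel` compatible with model offload, you can likely trade some speed for memory by offloading submodules instead. The following is a minimal sketch, assuming the `accelerate` library is installed; it is an optional substitution for the `.to("cuda")` calls, not part of the documented example.

```python
import torch
from diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline

pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
)
pipe = KandinskyV22ControlnetPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
)

# Keep submodules on the CPU and move each one to the GPU only while it runs,
# instead of calling .to("cuda") on the whole pipeline.
pipe_prior.enable_model_cpu_offload()
pipe.enable_model_cpu_offload()
```
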
### Image-to-Image Generation with ControlNet Conditioning

Kandinsky 2.2 also includes a [`KandinskyV22ControlnetImg2ImgPipeline`] that will allow you to add control to the image generation process with both the image and its depth map. This pipeline works really well with [`KandinskyV22PriorEmb2EmbPipeline`], which generates image embeddings based on both a text prompt and an image.

For our robot cat example, we will pass the prompt and cat image together to the prior pipeline to generate an image embedding. We will then use that image embedding and the depth map of the cat to further control the image generation process.

We can use the same cat image and its depth map from the last example.

```python
import torch
import numpy as np

from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22ControlnetImg2ImgPipeline
from diffusers.utils import load_image
from transformers import pipeline

img = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/cat.png"
).resize((768, 768))


def make_hint(image, depth_estimator):
    image = depth_estimator(image)["depth"]
    image = np.array(image)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    detected_map = torch.from_numpy(image).float() / 255.0
    hint = detected_map.permute(2, 0, 1)
    return hint


depth_estimator = pipeline("depth-estimation")
hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda")

pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
)
pipe_prior = pipe_prior.to("cuda")

pipe = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

prompt = "A robot, 4k photo"
negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"

generator = torch.Generator(device="cuda").manual_seed(43)

# run prior pipeline
img_emb = pipe_prior(prompt=prompt, image=img, strength=0.85, generator=generator)
negative_emb = pipe_prior(prompt=negative_prior_prompt, image=img, strength=1, generator=generator)

# run controlnet img2img pipeline
images = pipe(
    image=img,
    strength=0.5,
    image_embeds=img_emb.image_embeds,
    negative_image_embeds=negative_emb.image_embeds,
    hint=hint,
    num_inference_steps=50,
    generator=generator,
    height=768,
    width=768,
).images

images[0].save("robot_cat.png")
```

Here is the output. Compared with the output from our text-to-image controlnet example, it keeps a lot more of the cat's facial details from the original image while working them into the robot style we asked for.

![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/robot_cat.png)

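According to the commit notes, the emb2emb prior API was also updated to accept a precomputed image embedding in place of an image. The snippet below is a hypothetical sketch of that option, reusing `img_emb.image_embeds` from the example above; the exact tensor shape the `image` argument expects is an assumption, so check the pipeline docstring before relying on it.

```python
# Hypothetical: pass a precomputed CLIP image embedding instead of a PIL image
# (per the commit note "update prior_emb2emb api to accept image_embedding or image").
img_emb_reused = pipe_prior(
    prompt=prompt,
    image=img_emb.image_embeds,  # embedding reused as the starting point
    strength=0.85,
    generator=generator,
)
```
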
## Kandinsky 2.2

The Kandinsky 2.2 release includes robust new text-to-image models that support text-to-image generation, image-to-image generation, image interpolation, and text-guided image inpainting. The general workflow to perform these tasks using Kandinsky 2.2 is the same as in Kandinsky 2.1: first use a prior pipeline to generate image embeddings from your text prompt, then use one of the image decoding pipelines to generate the output image. The only difference is that in Kandinsky 2.2 none of the decoding pipelines accept the `prompt` input; the image generation process is conditioned only on `image_embeds` and `negative_image_embeds`.

Let's look at an example of how to perform text-to-image generation using Kandinsky 2.2.

First, let's create the prior pipeline and text-to-image pipeline with Kandinsky 2.2 checkpoints.

```python
from diffusers import DiffusionPipeline
import torch

pipe_prior = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16)
pipe_prior.to("cuda")

t2i_pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
t2i_pipe.to("cuda")
```

You can then use `pipe_prior` to generate image embeddings.

```python
prompt = "portrait of a women, blue eyes, cinematic"
negative_prompt = "low quality, bad quality"

image_embeds, negative_image_embeds = pipe_prior(prompt, guidance_scale=1.0).to_tuple()
```

Now you can pass these embeddings to the text-to-image pipeline. When using Kandinsky 2.2 you don't need to pass the `prompt` (but you do with the previous version, Kandinsky 2.1).

```python
image = t2i_pipe(
    image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, height=768, width=768
).images[0]
image.save("portrait.png")
```

![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/%20blue%20eyes.png)

We used the text-to-image pipeline as an example, but the same process applies to all decoding pipelines in Kandinsky 2.2. For more information, please refer to our API section for each pipeline.
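As a concrete illustration of that claim, here is a hypothetical sketch of the same two-stage workflow with the inpainting decoder, reusing `image_embeds` and `negative_image_embeds` from the prior call above. The checkpoint name `kandinsky-community/kandinsky-2-2-decoder-inpaint` and the mask convention are assumptions not shown in this diff; only the overall pattern (embeddings in, no `prompt`) comes from the text above.

```python
import numpy as np
import torch
from diffusers import KandinskyV22InpaintPipeline
from diffusers.utils import load_image

# Assumed checkpoint name for the Kandinsky 2.2 inpainting decoder.
inpaint_pipe = KandinskyV22InpaintPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
)
inpaint_pipe.to("cuda")

init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/cat.png"
).resize((768, 768))

# Simple rectangular mask; which value marks the region to repaint is an assumption,
# so check the pipeline docstring for the exact convention.
mask = np.zeros((768, 768), dtype=np.float32)
mask[:250, 250:-250] = 1

# Same pattern as text-to-image: the decoder consumes embeddings only, no prompt.
image = inpaint_pipe(
    image=init_image,
    mask_image=mask,
    image_embeds=image_embeds,
    negative_image_embeds=negative_image_embeds,
    height=768,
    width=768,
).images[0]
```
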
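The overview above also lists image interpolation among the supported tasks. Both prior pipelines expose an `interpolate` method (see the API section below); the sketch that follows mirrors the Kandinsky 2.1 interpolation example and has not been verified against the 2.2 prior, so treat the exact call as an assumption.

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

pipe_prior = DiffusionPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
)
pipe_prior.to("cuda")

img = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/cat.png"
).resize((768, 768))

# Blend a text prompt with an image; the weights control how much each input contributes.
images_texts = ["a robot", img]
weights = [0.5, 0.5]
image_embeds, negative_image_embeds = pipe_prior.interpolate(images_texts, weights).to_tuple()
```

The resulting embeddings can then be passed to any of the decoding pipelines exactly as in the text-to-image example.
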
## Optimization
Running Kandinsky in inference requires running both a first prior pipeline: [`KandinskyPriorPipeline`]
@@ -335,30 +530,84 @@ t2i_pipe.unet = torch.compile(t2i_pipe.unet, mode="reduce-overhead", fullgraph=T
After compilation you should see a very fast inference time. For more information, feel free to have a look at [Our PyTorch 2.0 benchmark](https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0).

## Available Pipelines:

| Pipeline | Tasks |
|---|---|
| [pipeline_kandinsky2_2.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py) | *Text-to-Image Generation* |
| [pipeline_kandinsky.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py) | *Text-to-Image Generation* |
| [pipeline_kandinsky2_2_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py) | *Image-Guided Image Generation* |
| [pipeline_kandinsky_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py) | *Image-Guided Image Generation* |
| [pipeline_kandinsky2_2_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py) | *Image-Guided Image Generation* |
| [pipeline_kandinsky_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py) | *Image-Guided Image Generation* |
| [pipeline_kandinsky2_2_controlnet.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py) | *Image-Guided Image Generation* |
| [pipeline_kandinsky2_2_controlnet_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py) | *Image-Guided Image Generation* |

### KandinskyV22Pipeline

[[autodoc]] KandinskyV22Pipeline
- all
- __call__

### KandinskyV22ControlnetPipeline

[[autodoc]] KandinskyV22ControlnetPipeline
- all
- __call__

### KandinskyV22ControlnetImg2ImgPipeline

[[autodoc]] KandinskyV22ControlnetImg2ImgPipeline
- all
- __call__

### KandinskyV22Img2ImgPipeline

[[autodoc]] KandinskyV22Img2ImgPipeline
- all
- __call__

### KandinskyV22InpaintPipeline

[[autodoc]] KandinskyV22InpaintPipeline
- all
- __call__

### KandinskyV22PriorPipeline

[[autodoc]] KandinskyV22PriorPipeline
- all
- __call__
- interpolate

### KandinskyV22PriorEmb2EmbPipeline

[[autodoc]] KandinskyV22PriorEmb2EmbPipeline
- all
- __call__
- interpolate

-## KandinskyPriorPipeline
+### KandinskyPriorPipeline

[[autodoc]] KandinskyPriorPipeline
- all
- __call__
- interpolate

-## KandinskyPipeline
+### KandinskyPipeline

[[autodoc]] KandinskyPipeline
- all
- __call__

-## KandinskyImg2ImgPipeline
+### KandinskyImg2ImgPipeline

[[autodoc]] KandinskyImg2ImgPipeline
- all
- __call__

-## KandinskyInpaintPipeline
+### KandinskyInpaintPipeline

[[autodoc]] KandinskyInpaintPipeline
- all

src/diffusers/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -139,6 +139,13 @@
KandinskyInpaintPipeline,
KandinskyPipeline,
KandinskyPriorPipeline,
+KandinskyV22ControlnetImg2ImgPipeline,
+KandinskyV22ControlnetPipeline,
+KandinskyV22Img2ImgPipeline,
+KandinskyV22InpaintPipeline,
+KandinskyV22Pipeline,
+KandinskyV22PriorEmb2EmbPipeline,
+KandinskyV22PriorPipeline,
LDMTextToImagePipeline,
PaintByExamplePipeline,
SemanticStableDiffusionPipeline,
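
With these exports in place, the new Kandinsky 2.2 pipelines should be importable directly from the top-level `diffusers` namespace, for example:

```python
# These are exactly the names added to src/diffusers/__init__.py above.
from diffusers import (
    KandinskyV22ControlnetImg2ImgPipeline,
    KandinskyV22ControlnetPipeline,
    KandinskyV22Img2ImgPipeline,
    KandinskyV22InpaintPipeline,
    KandinskyV22Pipeline,
    KandinskyV22PriorEmb2EmbPipeline,
    KandinskyV22PriorPipeline,
)
```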
