
Commit b25843e

unCLIP docs (huggingface#1754)
* [unCLIP docs] markdown
* [unCLIP docs] UnCLIPPipeline
1 parent 830a9d1 commit b25843e

6 files changed, +114 -0 lines changed

docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -122,6 +122,8 @@
       title: "Stochastic Karras VE"
     - local: api/pipelines/dance_diffusion
       title: "Dance Diffusion"
+    - local: api/pipelines/unclip
+      title: "UnCLIP"
     - local: api/pipelines/versatile_diffusion
       title: "Versatile Diffusion"
     - local: api/pipelines/vq_diffusion

docs/source/api/models.mdx

Lines changed: 6 additions & 0 deletions
@@ -58,6 +58,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
 ## Transformer2DModelOutput
 [[autodoc]] models.attention.Transformer2DModelOutput
 
+## PriorTransformer
+[[autodoc]] models.prior_transformer.PriorTransformer
+
+## PriorTransformerOutput
+[[autodoc]] models.prior_transformer.PriorTransformerOutput
+
 ## FlaxModelMixin
 [[autodoc]] FlaxModelMixin

docs/source/api/pipelines/overview.mdx

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ available a colab notebook to directly try them out.
 | [stable_diffusion_2](./stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |
 | [stable_diffusion_safe](./stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
 | [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
+| [unclip](./unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125) | Text-to-Image Generation |
 | [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
 | [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
 | [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
docs/source/api/pipelines/unclip.mdx

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+<!--Copyright 2022 The HuggingFace Team. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# unCLIP
+
+## Overview
+
+[Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125) by Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, Mark Chen
+
+The abstract of the paper is the following:
+
+Contrastive models like CLIP have been shown to learn robust representations of images that capture both semantics and style. To leverage these representations for image generation, we propose a two-stage model: a prior that generates a CLIP image embedding given a text caption, and a decoder that generates an image conditioned on the image embedding. We show that explicitly generating image representations improves image diversity with minimal loss in photorealism and caption similarity. Our decoders conditioned on image representations can also produce variations of an image that preserve both its semantics and style, while varying the non-essential details absent from the image representation. Moreover, the joint embedding space of CLIP enables language-guided image manipulations in a zero-shot fashion. We use diffusion models for the decoder and experiment with both autoregressive and diffusion models for the prior, finding that the latter are computationally more efficient and produce higher-quality samples.
+
+The unCLIP model in diffusers comes from kakaobrain's karlo and the original codebase can be found [here](https://github.com/kakaobrain/karlo). Additionally, lucidrains has a DALL-E 2 recreation [here](https://github.com/lucidrains/DALLE2-pytorch).
+
+## Available Pipelines:
+
+| Pipeline | Tasks | Colab
+|---|---|:---:|
+| [pipeline_unclip.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/unclip/pipeline_unclip.py) | *Text-to-Image Generation* | - |
+
+
+## UnCLIPPipeline
+[[autodoc]] pipelines.unclip.pipeline_unclip.UnCLIPPipeline
+    - __call__
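
The new page documents the two-stage pipeline but stops short of a usage snippet. Below is a minimal sketch of invoking `UnCLIPPipeline`; the `kakaobrain/karlo-v1-alpha` checkpoint name is an assumption inferred from the karlo reference in the overview, and the CUDA device and half precision are optional.

```python
import torch
from diffusers import UnCLIPPipeline

# Checkpoint id is an assumption based on the karlo reference above;
# substitute whichever unCLIP weights you actually use.
pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "a photograph of a red panda eating bamboo"
# The prior, decoder, and super resolution stages all run inside this one call.
image = pipe(prompt).images[0]
image.save("red_panda.png")
```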

docs/source/index.mdx

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ available a colab notebook to directly try them out.
 | [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |
 | [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
 | [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
+| [unclip](./api/pipelines/unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125) | Text-to-Image Generation |
 | [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
 | [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
 | [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |

src/diffusers/pipelines/unclip/pipeline_unclip.py

Lines changed: 73 additions & 0 deletions
@@ -31,6 +31,35 @@
 
 
 class UnCLIPPipeline(DiffusionPipeline):
+    """
+    Pipeline for text-to-image generation using unCLIP.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        text_encoder ([`CLIPTextModelWithProjection`]):
+            Frozen text-encoder.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        prior ([`PriorTransformer`]):
+            The canonical unCLIP prior to approximate the image embedding from the text embedding.
+        decoder ([`UNet2DConditionModel`]):
+            The decoder to invert the image embedding into an image.
+        super_res_first ([`UNet2DModel`]):
+            Super resolution unet. Used in all but the last step of the super resolution diffusion process.
+        super_res_last ([`UNet2DModel`]):
+            Super resolution unet. Used in the last step of the super resolution diffusion process.
+        prior_scheduler ([`UnCLIPScheduler`]):
+            Scheduler used in the prior denoising process. Just a modified DDPMScheduler.
+        decoder_scheduler ([`UnCLIPScheduler`]):
+            Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
+        super_res_scheduler ([`UnCLIPScheduler`]):
+            Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
+
+    """
+
 prior: PriorTransformer
 decoder: UNet2DConditionModel
 text_proj: UnCLIPTextProjModel
@@ -173,6 +202,50 @@ def __call__(
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
     ):
+        """
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            prior_num_inference_steps (`int`, *optional*, defaults to 25):
+                The number of denoising steps for the prior. More denoising steps usually lead to a higher quality
+                image at the expense of slower inference.
+            decoder_num_inference_steps (`int`, *optional*, defaults to 25):
+                The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
+                image at the expense of slower inference.
+            super_res_num_inference_steps (`int`, *optional*, defaults to 7):
+                The number of denoising steps for super resolution. More denoising steps usually lead to a higher
+                quality image at the expense of slower inference.
+            generator (`torch.Generator`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*):
+                Pre-generated noisy latents to be used as inputs for the prior.
+            decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
+                Pre-generated noisy latents to be used as inputs for the decoder.
+            super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
+                Pre-generated noisy latents to be used as inputs for the super resolution.
+            prior_guidance_scale (`float`, *optional*, defaults to 4.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2 of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            decoder_guidance_scale (`float`, *optional*, defaults to 4.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2 of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+        """
         if isinstance(prompt, str):
             batch_size = 1
         elif isinstance(prompt, list):
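
A short sketch of how the `__call__` arguments documented above fit together, spelling out the documented defaults; the checkpoint id is again an assumption, not something this commit pins down.

```python
import torch
from diffusers import UnCLIPPipeline

pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha").to("cuda")

# A seeded generator makes all three denoising stages reproducible.
generator = torch.Generator(device="cuda").manual_seed(0)

output = pipe(
    "an oil painting of a lighthouse at dawn",
    num_images_per_prompt=2,
    prior_num_inference_steps=25,     # prior denoising steps (documented default: 25)
    decoder_num_inference_steps=25,   # decoder denoising steps (documented default: 25)
    super_res_num_inference_steps=7,  # super resolution steps (documented default: 7)
    prior_guidance_scale=4.0,         # classifier-free guidance for the prior
    decoder_guidance_scale=4.0,       # classifier-free guidance for the decoder
    generator=generator,
    output_type="pil",
)
images = output.images  # list of PIL.Image.Image when output_type="pil"
```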
