Commit 5782e03

Stable diffusion pipeline (huggingface#168)

* add stable diffusion pipeline
* get rid of multiple if/else
* batch_size is unused
* add type hints
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
* fix some bugs

Co-authored-by: Patrick von Platen <[email protected]>

1 parent 92b6dbb · commit 5782e03

10 files changed (+187, -21 lines)

src/diffusers/__init__.py (2 additions, 1 deletion)

@@ -31,6 +31,7 @@
 
 
 if is_transformers_available():
-    from .pipelines import LDMTextToImagePipeline
+    from .pipelines import LDMTextToImagePipeline, StableDiffusionPipeline
+
 else:
     from .utils.dummy_transformers_objects import *

src/diffusers/pipelines/__init__.py (1 addition, 0 deletions)

@@ -9,3 +9,4 @@
 
 if is_transformers_available():
     from .latent_diffusion import LDMTextToImagePipeline
+    from .stable_diffusion import StableDiffusionPipeline

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py (4 additions, 3 deletions)

@@ -62,9 +62,10 @@ def __call__(
 
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_kwrags = {}
+
+        extra_kwargs = {}
         if accepts_eta:
-            extra_kwrags["eta"] = eta
+            extra_kwargs["eta"] = eta
 
         for t in tqdm(self.scheduler.timesteps):
             if guidance_scale == 1.0:
@@ -86,7 +87,7 @@ def __call__(
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwrags)["prev_sample"]
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents

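Besides fixing the `extra_kwrags` typo, this hunk shows the `accepts_eta` pattern worth calling out: `eta` is a DDIM-only argument, so the pipeline inspects the scheduler's `step` signature instead of branching on the scheduler type. A minimal standalone sketch of the pattern (the helper name `build_step_kwargs` is illustrative, not part of this diff):

    import inspect

    def build_step_kwargs(step_fn, eta: float) -> dict:
        # forward `eta` only when the scheduler's `step` accepts it,
        # e.g. DDIMScheduler.step takes `eta` while PNDMScheduler.step does not
        extra_kwargs = {}
        if "eta" in inspect.signature(step_fn).parameters:
            extra_kwargs["eta"] = eta
        return extra_kwargs
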
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py (4 additions, 3 deletions)

@@ -35,15 +35,16 @@ def __call__(
 
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_kwrags = {}
+
+        extra_kwargs = {}
         if accepts_eta:
-            extra_kwrags["eta"] = eta
+            extra_kwargs["eta"] = eta
 
         for t in tqdm(self.scheduler.timesteps):
            # predict the noise residual
            noise_prediction = self.unet(latents, t)["sample"]
            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwrags)["prev_sample"]
+            latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs)["prev_sample"]
 
         # decode the image latents with the VAE
         image = self.vqvae.decode(latents)

src/diffusers/pipelines/stable_diffusion/__init__.py (5 additions, 0 deletions)

@@ -0,0 +1,5 @@
+from ...utils import is_transformers_available
+
+
+if is_transformers_available():
+    from .pipeline_stable_diffusion import StableDiffusionPipeline

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py (115 additions, 0 deletions)

@@ -0,0 +1,115 @@
+import inspect
+from typing import List, Optional, Union
+
+import torch
+
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, PNDMScheduler
+
+
+class StableDiffusionPipeline(DiffusionPipeline):
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: Union[DDIMScheduler, PNDMScheduler],
+    ):
+        super().__init__()
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 1.0,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        torch_device: Optional[Union[str, torch.device]] = None,
+        output_type: Optional[str] = "pil",
+    ):
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        if isinstance(prompt, str):
+            batch_size = 1
+        elif isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        self.unet.to(torch_device)
+        self.vae.to(torch_device)
+        self.text_encoder.to(torch_device)
+
+        # get prompt text embeddings
+        text_input = self.tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
+        text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]
+
+        # here `guidance_scale` is defined analogous to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance:
+            max_length = text_input.input_ids.shape[-1]
+            uncond_input = self.tokenizer(
+                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+            )
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        # get the initial random noise
+        latents = torch.randn(
+            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            generator=generator,
+        )
+        latents = latents.to(torch_device)
+
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler; it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_kwargs = {}
+        if accepts_eta:
+            extra_kwargs["eta"] = eta
+
+        self.scheduler.set_timesteps(num_inference_steps)
+
+        for t in tqdm(self.scheduler.timesteps):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+            # predict the noise residual
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]
+
+        # scale and decode the image latents with vae
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents)
+
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+        return {"sample": image}

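The sampling loop above implements classifier-free guidance: for `guidance_scale` w > 1, the unconditional and text-conditional noise predictions are combined as `noise_pred_uncond + w * (noise_pred_text - noise_pred_uncond)`, analogous to equation (2) of the Imagen paper, while the default `guidance_scale = 1.0` skips the unconditional pass entirely. A minimal usage sketch of the new pipeline, assuming the `CompVis/stable-diffusion-v1-1-diffusers` checkpoint used in the tests below is available:

    import torch

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")

    generator = torch.manual_seed(0)
    # guidance_scale > 1.0 switches classifier-free guidance on
    images = pipe(
        "A painting of a squirrel eating a burger",
        num_inference_steps=50,
        guidance_scale=6.0,
        generator=generator,
    )["sample"]  # default output_type="pil" returns a list of PIL images
    images[0].save("squirrel.png")
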
src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py (5 additions, 4 deletions)

@@ -10,11 +10,12 @@
 
 class KarrasVePipeline(DiffusionPipeline):
     """
-    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2].
-    Use Algorithm 2 and the VE column of Table 1 from [1] for reference.
+    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
+    the VE column of Table 1 from [1] for reference.
 
-    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
-    [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
+    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+    https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+    differential equations." https://arxiv.org/abs/2011.13456
     """
 
     unet: UNet2DModel

src/diffusers/schedulers/scheduling_karras_ve.py (10 additions, 10 deletions)

@@ -24,11 +24,12 @@
 
 class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
     """
-    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2].
-    Use Algorithm 2 and the VE column of Table 1 from [1] for reference.
+    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
+    the VE column of Table 1 from [1] for reference.
 
-    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
-    [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
+    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+    https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+    differential equations." https://arxiv.org/abs/2011.13456
     """
 
     @register_to_config
@@ -43,10 +44,9 @@ def __init__(
         tensor_format="pt",
     ):
         """
-        For more details on the parameters, see the original paper's Appendix E.:
-        "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364.
-        The grid search values used to find the optimal {s_noise, s_churn, s_min, s_max} for a specific model
-        are described in Table 5 of the paper.
+        For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
+        Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
+        optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
 
         Args:
             sigma_min (`float`): minimum noise magnitude
@@ -81,8 +81,8 @@ def set_timesteps(self, num_inference_steps):
 
     def add_noise_to_input(self, sample, sigma, generator=None):
         """
-        Explicit Langevin-like "churn" step of adding noise to the sample according to
-        a factor gamma_i ≥ 0 to reach a higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
+        Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
+        higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
         """
         if self.s_min <= sigma <= self.s_max:
             gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)

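The reflowed docstring describes the "churn" step compactly: raise the noise level from sigma to sigma_hat = sigma + gamma * sigma, with gamma clamped at sqrt(2) - 1 as in the `gamma = min(...)` line above, then add fresh noise to cover the variance gap. Only the first line of the method body appears in this hunk, so the following standalone sketch of the rest is an illustration of what the docstring implies, not the committed code:

    import torch

    def churn(sample, sigma, s_churn, num_inference_steps, s_noise=1.0, generator=None):
        # gamma_i >= 0 controls how much the noise level is temporarily raised
        gamma = min(s_churn / num_inference_steps, 2**0.5 - 1)
        sigma_hat = sigma + gamma * sigma
        # add noise so the sample's variance matches the new level:
        # sigma**2 + (sigma_hat**2 - sigma**2) = sigma_hat**2
        eps = s_noise * torch.randn(sample.shape, generator=generator)
        return sample + ((sigma_hat**2 - sigma**2) ** 0.5) * eps, sigma_hat
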
src/diffusers/utils/dummy_transformers_objects.py (7 additions, 0 deletions)

@@ -8,3 +8,10 @@ class LDMTextToImagePipeline(metaclass=DummyObject):
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["transformers"])
+
+
+class StableDiffusionPipeline(metaclass=DummyObject):
+    _backends = ["transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["transformers"])

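The dummy class mirrors the pattern already used for LDMTextToImagePipeline: `from diffusers import StableDiffusionPipeline` keeps working without `transformers` installed, and the missing dependency only surfaces on instantiation. A sketch of the expected behavior (the exact error message is up to `requires_backends`):

    # without transformers installed, the import resolves to the dummy class above
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline()  # raises, pointing at the missing "transformers" backend
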
tests/test_modeling_utils.py (34 additions, 0 deletions)

@@ -45,6 +45,8 @@
 from diffusers.testing_utils import floats_tensor, slow, torch_device
 from diffusers.training_utils import EMAModel
 
+from ..src.diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
+
 
 torch.backends.cuda.matmul.allow_tf32 = False
 
@@ -839,6 +841,38 @@ def test_ldm_text2img_fast(self):
         expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
+    @slow
+    def test_stable_diffusion(self):
+        ldm = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.manual_seed(0)
+        image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
+            "sample"
+        ]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        # TODO: update the expected_slice
+        assert image.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    @slow
+    def test_stable_diffusion_fast(self):
+        ldm = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.manual_seed(0)
+        image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        # TODO: update the expected_slice
+        assert image.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     @slow
     def test_score_sde_ve_pipeline(self):
         model_id = "google/ncsnpp-church-256"

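Both new tests are decorated with `@slow`, so they are skipped in a default test run; assuming the usual gating in `diffusers.testing_utils`, they execute only when the RUN_SLOW environment variable is set, e.g. `RUN_SLOW=1 python -m pytest tests/test_modeling_utils.py -k stable_diffusion`. Note that the tests import the pipeline through a source-tree relative path (`..src.diffusers...`) rather than the installed `diffusers` package, and both `expected_slice` arrays carry a TODO, so the pixel-value checks are placeholders pending a reference run.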