Skip to content

Commit 5a59f9b

Browse files
duongna21patrickvonplatenpatil-surajpcuenca
authored
Add LDM Super Resolution pipeline (huggingface#1116)
* Add ldm super resolution pipeline * style * fix copies * style * fix doc * Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py Co-authored-by: Suraj Patil <[email protected]> * Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py Co-authored-by: Suraj Patil <[email protected]> * add doc * address comments * address comments * fix doc * minor * add tests * add tests * load text encoder from subfolder * fix test * fix test * style * style * handle mps latents * unfix typo * unfix typo * Update tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py Co-authored-by: Pedro Cuenca <[email protected]> * fix set_timesteps mps * fix set_timesteps mps * Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py Co-authored-by: Suraj Patil <[email protected]> * Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py Co-authored-by: Suraj Patil <[email protected]> * Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py Co-authored-by: Suraj Patil <[email protected]> * Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py Co-authored-by: Suraj Patil <[email protected]> * style * test 64x64 instead of 256x256 Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Suraj Patil <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>
1 parent b93fe08 commit 5a59f9b

File tree

7 files changed

+310
-0
lines changed

7 files changed

+310
-0
lines changed

docs/source/api/pipelines/latent_diffusion.mdx

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,15 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff
3333
| Pipeline | Tasks | Colab
3434
|---|---|:---:|
3535
| [pipeline_latent_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py) | *Text-to-Image Generation* | - |
36+
| [pipeline_latent_diffusion_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py) | *Super Resolution* | - |
3637

3738
## Examples:
3839

3940

4041
## LDMTextToImagePipeline
4142
[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion.LDMTextToImagePipeline
4243
- __call__
44+
45+
## LDMSuperResolutionPipeline
46+
[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion_superresolution.LDMSuperResolutionPipeline
47+
- __call__

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
DDPMPipeline,
3636
KarrasVePipeline,
3737
LDMPipeline,
38+
LDMSuperResolutionPipeline,
3839
PNDMPipeline,
3940
RePaintPipeline,
4041
ScoreSdeVePipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .dance_diffusion import DanceDiffusionPipeline
66
from .ddim import DDIMPipeline
77
from .ddpm import DDPMPipeline
8+
from .latent_diffusion import LDMSuperResolutionPipeline
89
from .latent_diffusion_uncond import LDMPipeline
910
from .pndm import PNDMPipeline
1011
from .repaint import RePaintPipeline

src/diffusers/pipelines/latent_diffusion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# flake8: noqa
22
from ...utils import is_transformers_available
3+
from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline
34

45

56
if is_transformers_available():
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import inspect
2+
from typing import Optional, Tuple, Union
3+
4+
import numpy as np
5+
import torch
6+
import torch.utils.checkpoint
7+
8+
import PIL
9+
10+
from ...models import UNet2DModel, VQModel
11+
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
12+
from ...schedulers import (
13+
DDIMScheduler,
14+
DPMSolverMultistepScheduler,
15+
EulerAncestralDiscreteScheduler,
16+
EulerDiscreteScheduler,
17+
LMSDiscreteScheduler,
18+
PNDMScheduler,
19+
)
20+
21+
22+
def preprocess(image):
    """Turn a PIL image into a batched, channels-first float tensor.

    Each side is rounded down to a multiple of 32 and the image is resized to
    that size with Lanczos resampling. Pixel values are divided by 255 and
    then mapped through ``2 * x - 1`` (i.e. to ``[-1, 1]`` for 8-bit input).

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize`` method. The ``transpose`` below needs a 3-D
            ``(H, W, C)`` array, so single-channel images are not supported.

    Returns:
        A ``torch.Tensor`` of shape ``(1, C, H, W)`` and dtype float32.
    """
    width, height = image.size
    # Drop the remainder so that both sides are integer multiples of 32.
    width -= width % 32
    height -= height % 32
    resized = image.resize((width, height), resample=PIL.Image.LANCZOS)
    pixels = np.array(resized).astype(np.float32) / 255.0
    # Prepend a batch axis, then move channels first: (H, W, C) -> (1, C, H, W).
    batched = pixels[None].transpose(0, 3, 1, 2)
    tensor = torch.from_numpy(batched)
    return 2.0 * tensor - 1.0
30+
31+
32+
class LDMSuperResolutionPipeline(DiffusionPipeline):
    r"""
    A pipeline for image super-resolution using Latent Diffusion.

    This class inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        vqvae ([`VQModel`]):
            Vector-quantized (VQ) VAE Model to encode and decode images to and from latent representations.
        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
            [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`].
    """

    def __init__(
        self,
        vqvae: VQModel,
        unet: UNet2DModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
    ):
        super().__init__()
        self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        init_image: Union[torch.Tensor, PIL.Image.Image],
        batch_size: Optional[int] = 1,
        num_inference_steps: Optional[int] = 100,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        r"""
        Args:
            init_image (`torch.Tensor` or `PIL.Image.Image`):
                Low resolution `Image`, or tensor representing an image batch, to upscale. It is concatenated with
                the noisy latents along the channel dimension at every denoising step. A `PIL.Image.Image` is
                converted with `preprocess`; a tensor is used as given (its first dimension is read as the batch
                size and its last two as height and width; presumably it should already be normalized the way
                `preprocess` does it - confirm with callers).
            batch_size (`int`, *optional*, defaults to 1):
                Kept for interface compatibility. The value passed in is always overwritten: it becomes 1 for a
                `PIL.Image.Image` and `init_image.shape[0]` for a tensor.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. `"pil"` returns
                [PIL](https://pillow.readthedocs.io/en/stable/) `PIL.Image.Image` objects; any other value returns
                the `np.array` unchanged.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
            kwargs:
                Accepted but not used by this method.

        Returns:
            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipeline_utils.ImagePipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.

        Raises:
            ValueError: If `init_image` is neither a `PIL.Image.Image` nor a `torch.Tensor`.
        """

        # Derive the batch size from the input and bring PIL images into tensor form in one place.
        if isinstance(init_image, PIL.Image.Image):
            batch_size = 1
            init_image = preprocess(init_image)
        elif isinstance(init_image, torch.Tensor):
            batch_size = init_image.shape[0]
        else:
            raise ValueError(
                f"`init_image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(init_image)}"
            )

        height, width = init_image.shape[-2:]

        # in_channels should be 6: 3 for latents, 3 for low resolution image
        latents_shape = (batch_size, self.unet.in_channels // 2, height, width)
        latents_dtype = next(self.unet.parameters()).dtype

        if self.device.type == "mps":
            # randn does not work reproducibly on mps
            latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype)
            latents = latents.to(self.device)
        else:
            latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)

        init_image = init_image.to(device=self.device, dtype=latents_dtype)

        # set timesteps and move to the correct device
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps_tensor = self.scheduler.timesteps

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in inspect.signature(self.scheduler.step).parameters
        extra_kwargs = {}
        if accepts_eta:
            extra_kwargs["eta"] = eta

        for t in self.progress_bar(timesteps_tensor):
            # concat latents and low resolution image in the channel dimension.
            latents_input = torch.cat([latents, init_image], dim=1)
            latents_input = self.scheduler.scale_model_input(latents_input, t)
            # predict the noise residual
            noise_pred = self.unet(latents_input, t).sample
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample

        # decode the image latents with the VQVAE
        image = self.vqvae.decode(latents).sample
        # map the decoder output from [-1, 1] to [0, 1] and switch to channels-last numpy
        image = torch.clamp(image, -1.0, 1.0)
        image = image / 2 + 0.5
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,21 @@ def from_pretrained(cls, *args, **kwargs):
227227
requires_backends(cls, ["torch"])
228228

229229

230+
class LDMSuperResolutionPipeline(metaclass=DummyObject):
    """Stand-in for `LDMSuperResolutionPipeline` that declares a `torch` backend requirement.

    Construction and both loader classmethods only call `requires_backends`
    with `["torch"]` (presumably raising when torch is missing - confirm
    against `requires_backends`).
    """

    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        """Accept any arguments and check for the `torch` backend."""
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        """Accept any arguments and check for the `torch` backend."""
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Accept any arguments and check for the `torch` backend."""
        requires_backends(cls, ["torch"])
243+
244+
230245
class PNDMPipeline(metaclass=DummyObject):
231246
_backends = ["torch"]
232247

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# coding=utf-8
2+
# Copyright 2022 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import random
17+
import unittest
18+
19+
import numpy as np
20+
import torch
21+
22+
import PIL
23+
from diffusers import DDIMScheduler, LDMSuperResolutionPipeline, UNet2DModel, VQModel
24+
from diffusers.utils import floats_tensor, load_image, slow, torch_device
25+
from diffusers.utils.testing_utils import require_torch
26+
27+
from ...test_pipelines_common import PipelineTesterMixin
28+
29+
30+
torch.backends.cuda.matmul.allow_tf32 = False
31+
32+
33+
class LDMSuperResolutionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast test of `LDMSuperResolutionPipeline` built from small, seeded models."""

    @property
    def dummy_image(self):
        """Return a (1, 3, 32, 32) tensor from `floats_tensor` (rng seeded with 0) on `torch_device`."""
        shape = (1, 3, 32, 32)
        return floats_tensor(shape, rng=random.Random(0)).to(torch_device)

    @property
    def dummy_uncond_unet(self):
        """Return a small `UNet2DModel` with 6 input and 3 output channels, built after seeding torch with 0."""
        torch.manual_seed(0)
        return UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=6,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )

    @property
    def dummy_vq_model(self):
        """Return a small `VQModel` with 3 latent channels, built after seeding torch with 0."""
        torch.manual_seed(0)
        return VQModel(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=3,
        )

    def test_inference_superresolution(self):
        """Run two DDIM steps on the dummy image and compare a 3x3 corner with reference values."""
        unet = self.dummy_uncond_unet
        scheduler = DDIMScheduler()
        vqvae = self.dummy_vq_model

        pipe = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        low_res = self.dummy_image.to(torch_device)

        # Warmup pass when using mps (see #372)
        if torch_device == "mps":
            _ = pipe(low_res, generator=torch.manual_seed(0), num_inference_steps=1, output_type="numpy").images

        output = pipe(low_res, generator=torch.manual_seed(0), num_inference_steps=2, output_type="numpy").images

        # Bottom-right 3x3 patch of the last channel of the first image.
        corner = output[0, -3:, -3:, -1]

        assert output.shape == (1, 64, 64, 3)
        expected = np.array([0.8634, 0.8186, 0.6416, 0.6846, 0.4427, 0.5676, 0.4679, 0.6247, 0.5176])
        # mps gets a looser tolerance than the other devices.
        tolerance = 3e-2 if torch_device == "mps" else 1e-2
        max_abs_diff = np.abs(corner.flatten() - expected).max()
        assert max_abs_diff < tolerance
95+
96+
97+
@slow
@require_torch
class LDMSuperResolutionPipelineIntegrationTests(unittest.TestCase):
    """Slow test that runs the `duongna/ldm-super-resolution` checkpoint on a downloaded image."""

    def test_inference_superresolution(self):
        """Upscale a 64x64 image with 20 steps and compare a 3x3 corner with reference values."""
        image_url = (
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/vq_diffusion/teddy_bear_pool.png"
        )
        low_res = load_image(image_url)
        # Shrink the input to 64x64 before running the pipeline.
        low_res = low_res.resize((64, 64), resample=PIL.Image.LANCZOS)

        pipe = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution", device_map="auto")
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        seeded_generator = torch.Generator(device=torch_device).manual_seed(0)
        output = pipe(low_res, generator=seeded_generator, num_inference_steps=20, output_type="numpy").images

        # Bottom-right 3x3 patch of the last channel of the first image.
        corner = output[0, -3:, -3:, -1]

        assert output.shape == (1, 256, 256, 3)
        expected = np.array([0.7418, 0.7472, 0.7424, 0.7422, 0.7463, 0.726, 0.7382, 0.7248, 0.6828])
        max_abs_diff = np.abs(corner.flatten() - expected).max()
        assert max_abs_diff < 1e-2

0 commit comments

Comments
 (0)