Skip to content

Commit 1b42732

Browse files
authored
PIL-ify the pipeline outputs (huggingface#111)
1 parent 9e9d2db commit 1b42732

File tree

7 files changed

+49
-10
lines changed

7 files changed

+49
-10
lines changed

src/diffusers/pipeline_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Optional, Union
2020

2121
from huggingface_hub import snapshot_download
22+
from PIL import Image
2223

2324
from .configuration_utils import ConfigMixin
2425
from .utils import DIFFUSERS_CACHE, logging
@@ -189,3 +190,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
189190
# 5. Instantiate the pipeline
190191
model = pipeline_class(**init_kwargs)
191192
return model
193+
194+
@staticmethod
195+
def numpy_to_pil(images):
196+
"""
197+
Convert a numpy image or a batch of images to a PIL image.
198+
"""
199+
if images.ndim == 3:
200+
images = images[None, ...]
201+
images = (images * 255).round().astype("uint8")
202+
pil_images = [Image.fromarray(image) for image in images]
203+
204+
return pil_images

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, unet, scheduler):
2828
self.register_modules(unet=unet, scheduler=scheduler)
2929

3030
@torch.no_grad()
31-
def __call__(self, batch_size=1, generator=None, torch_device=None):
31+
def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="numpy"):
3232
if torch_device is None:
3333
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
3434

@@ -56,5 +56,7 @@ def __call__(self, batch_size=1, generator=None, torch_device=None):
5656

5757
image = (image / 2 + 0.5).clamp(0, 1)
5858
image = image.cpu().permute(0, 2, 3, 1).numpy()
59+
if output_type == "pil":
60+
image = self.numpy_to_pil(image)
5961

6062
return {"sample": image}

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __call__(
3030
eta=0.0,
3131
guidance_scale=1.0,
3232
num_inference_steps=50,
33+
output_type="numpy",
3334
):
3435
# eta corresponds to η in paper and should be between [0, 1]
3536

@@ -86,6 +87,8 @@ def __call__(
8687

8788
image = (image / 2 + 0.5).clamp(0, 1)
8889
image = image.cpu().permute(0, 2, 3, 1).numpy()
90+
if output_type == "pil":
91+
image = self.numpy_to_pil(image)
8992

9093
return {"sample": image}
9194

src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,7 @@ def __init__(self, vqvae, unet, scheduler):
1313

1414
@torch.no_grad()
1515
def __call__(
16-
self,
17-
batch_size=1,
18-
generator=None,
19-
torch_device=None,
20-
eta=0.0,
21-
num_inference_steps=50,
16+
self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="numpy"
2217
):
2318
# eta corresponds to η in paper and should be between [0, 1]
2419

@@ -47,5 +42,7 @@ def __call__(
4742

4843
image = (image / 2 + 0.5).clamp(0, 1)
4944
image = image.cpu().permute(0, 2, 3, 1).numpy()
45+
if output_type == "pil":
46+
image = self.numpy_to_pil(image)
5047

5148
return {"sample": image}

src/diffusers/pipelines/pndm/pipeline_pndm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, unet, scheduler):
2828
self.register_modules(unet=unet, scheduler=scheduler)
2929

3030
@torch.no_grad()
31-
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
31+
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="numpy"):
3232
# For more information on the sampling method you can take a look at Algorithm 2 of
3333
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
3434
if torch_device is None:
@@ -59,5 +59,7 @@ def __call__(self, batch_size=1, generator=None, torch_device=None, num_inferenc
5959

6060
image = (image / 2 + 0.5).clamp(0, 1)
6161
image = image.cpu().permute(0, 2, 3, 1).numpy()
62+
if output_type == "pil":
63+
image = self.numpy_to_pil(image)
6264

6365
return {"sample": image}

src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(self, model, scheduler):
1111
self.register_modules(model=model, scheduler=scheduler)
1212

1313
@torch.no_grad()
14-
def __call__(self, num_inference_steps=2000, generator=None):
14+
def __call__(self, num_inference_steps=2000, generator=None, output_type="numpy"):
1515
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
1616

1717
img_size = self.model.config.image_size
@@ -47,5 +47,7 @@ def __call__(self, num_inference_steps=2000, generator=None):
4747

4848
sample = sample.clamp(0, 1)
4949
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
50+
if output_type == "pil":
51+
sample = self.numpy_to_pil(sample)
5052

5153
return {"sample": sample}

tests/test_modeling_utils.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
import math
1919
import tempfile
2020
import unittest
21-
from atexit import register
2221

2322
import numpy as np
2423
import torch
2524

25+
import PIL
2626
from diffusers import UNetConditionalModel # noqa: F401 TODO(Patrick) - need to write tests with it
2727
from diffusers import (
2828
AutoencoderKL,
@@ -728,6 +728,26 @@ def test_from_pretrained_hub(self):
728728

729729
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
730730

731+
@slow
732+
def test_output_format(self):
733+
model_path = "google/ddpm-cifar10-32"
734+
735+
pipe = DDIMPipeline.from_pretrained(model_path)
736+
737+
generator = torch.manual_seed(0)
738+
images = pipe(generator=generator)["sample"]
739+
assert images.shape == (1, 32, 32, 3)
740+
assert isinstance(images, np.ndarray)
741+
742+
images = pipe(generator=generator, output_type="numpy")["sample"]
743+
assert images.shape == (1, 32, 32, 3)
744+
assert isinstance(images, np.ndarray)
745+
746+
images = pipe(generator=generator, output_type="pil")["sample"]
747+
assert isinstance(images, list)
748+
assert len(images) == 1
749+
assert isinstance(images[0], PIL.Image.Image)
750+
731751
@slow
732752
def test_ddpm_cifar10(self):
733753
model_id = "google/ddpm-cifar10-32"

0 commit comments

Comments
 (0)