
[BUG] fixes in Kandinsky pipeline #11080


Merged · 8 commits · Apr 21, 2025
7 changes: 6 additions & 1 deletion src/diffusers/image_processor.py
@@ -116,6 +116,7 @@ def __init__(
vae_scale_factor: int = 8,
vae_latent_channels: int = 4,
resample: str = "lanczos",
reducing_gap: int = None,
do_normalize: bool = True,
do_binarize: bool = False,
do_convert_rgb: bool = False,
@@ -498,7 +499,11 @@ def resize(
raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
if isinstance(image, PIL.Image.Image):
if resize_mode == "default":
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
image = image.resize(
(width, height),
resample=PIL_INTERPOLATION[self.config.resample],
reducing_gap=self.config.reducing_gap,
)
elif resize_mode == "fill":
image = self._resize_and_fill(image, width, height)
elif resize_mode == "crop":
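For reference, `reducing_gap` is PIL's two-step resize optimization (`Image.resize(..., reducing_gap=...)`): the image is first reduced in integer steps and then resampled, which is faster for large downscales. Exposing it on `VaeImageProcessor` lets the Kandinsky pipelines keep the exact resize behaviour of the old `prepare_image` helper (bicubic resampling with `reducing_gap=1`). A minimal sketch of the new option, with a placeholder input path:

import PIL.Image

from diffusers.image_processor import VaeImageProcessor

# Mirrors the processor the Kandinsky pipelines construct below:
# bicubic resampling with PIL's reducing_gap=1 optimization enabled.
processor = VaeImageProcessor(vae_scale_factor=8, resample="bicubic", reducing_gap=1)

init_image = PIL.Image.open("input.png")  # placeholder path
# preprocess() resizes via Image.resize(..., reducing_gap=1) and normalizes to [-1, 1].
tensor = processor.preprocess(init_image, height=512, width=512)
print(tensor.shape)  # torch.Size([1, 3, 512, 512])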
33 changes: 13 additions & 20 deletions src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
@@ -14,14 +14,13 @@

from typing import Callable, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from PIL import Image
from transformers import (
XLMRobertaTokenizer,
)

from ...image_processor import VaeImageProcessor
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler
from ...utils import (
@@ -95,15 +94,6 @@ def get_new_h_w(h, w, scale_factor=8):
return new_h * scale_factor, new_w * scale_factor


def prepare_image(pil_image, w=512, h=512):
pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
arr = np.array(pil_image.convert("RGB"))
arr = arr.astype(np.float32) / 127.5 - 1
arr = np.transpose(arr, [2, 0, 1])
image = torch.from_numpy(arr).unsqueeze(0)
return image


class KandinskyImg2ImgPipeline(DiffusionPipeline):
"""
Pipeline for image-to-image generation using Kandinsky
@@ -143,7 +133,16 @@ def __init__(
scheduler=scheduler,
movq=movq,
)
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
self.movq_scale_factor = (
2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
)
movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
self.image_processor = VaeImageProcessor(
vae_scale_factor=self.movq_scale_factor,
vae_latent_channels=movq_latent_channels,
resample="bicubic",
reducing_gap=1,
)

def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
@@ -417,7 +416,7 @@ def __call__(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
)

image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
image = image.to(dtype=prompt_embeds.dtype, device=device)

latents = self.movq.encode(image)["latents"]
@@ -498,13 +497,7 @@ def __call__(
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

if output_type == "pil":
image = self.numpy_to_pil(image)
image = self.image_processor.postprocess(image, output_type)

if not return_dict:
return (image,)
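The constructor change above derives the image-processor settings from the MoVQ (`VQModel`) config instead of the removed module-level `prepare_image` helper, falling back to a scale factor of 8 and 4 latent channels when `movq` is not loaded. A small worked sketch of that arithmetic, assuming a typical Kandinsky MoVQ config:

block_out_channels = (128, 256, 256, 512)  # assumed MoVQ config with four feature levels
latent_channels = 4

# Each level after the first halves the spatial resolution, so the overall
# downscaling factor is 2 ** (4 - 1) = 8: a 512x512 image encodes to a 64x64x4 latent.
movq_scale_factor = 2 ** (len(block_out_channels) - 1)
assert movq_scale_factor == 8
assert 512 // movq_scale_factor == 64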
src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
@@ -14,11 +14,10 @@

from typing import Callable, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from PIL import Image

from ...image_processor import VaeImageProcessor
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
from ...utils import (
@@ -105,27 +104,6 @@
"""


# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2
if height % scale_factor**2 != 0:
new_height += 1
new_width = width // scale_factor**2
if width % scale_factor**2 != 0:
new_width += 1
return new_height * scale_factor, new_width * scale_factor


# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
def prepare_image(pil_image, w=512, h=512):
pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
arr = np.array(pil_image.convert("RGB"))
arr = arr.astype(np.float32) / 127.5 - 1
arr = np.transpose(arr, [2, 0, 1])
image = torch.from_numpy(arr).unsqueeze(0)
return image


class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
"""
Pipeline for image-to-image generation using Kandinsky
@@ -157,7 +135,14 @@ def __init__(
scheduler=scheduler,
movq=movq,
)
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
self.image_processor = VaeImageProcessor(
vae_scale_factor=movq_scale_factor,
vae_latent_channels=movq_latent_channels,
resample="bicubic",
reducing_gap=1,
)

# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
@@ -316,15 +301,14 @@ def __call__(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
)

image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
image = image.to(dtype=image_embeds.dtype, device=device)

latents = self.movq.encode(image)["latents"]
latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
latents = self.prepare_latents(
latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
)
@@ -379,13 +363,7 @@ def __call__(
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

if output_type == "pil":
image = self.numpy_to_pil(image)
image = self.image_processor.postprocess(image, output_type)

if not return_dict:
return (image,)
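As in the plain Kandinsky img2img pipeline, the hand-rolled output block (`image * 0.5 + 0.5`, clamp, permute, `numpy_to_pil`) is replaced by `VaeImageProcessor.postprocess`, which performs the same denormalization and conversion. A short sketch of the equivalence on a dummy decoded tensor:

import torch

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(resample="bicubic", reducing_gap=1)

# Dummy "decoded MoVQ sample" in [-1, 1], shaped (batch, channels, height, width).
decoded = torch.rand(1, 3, 64, 64) * 2 - 1

# postprocess() denormalizes to [0, 1], clamps, moves channels last, and converts
# to the requested output type ("pt", "np", or "pil").
np_images = processor.postprocess(decoded, output_type="np")
pil_images = processor.postprocess(decoded, output_type="pil")
assert np_images.shape == (1, 64, 64, 3)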
src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
@@ -14,11 +14,10 @@

from typing import Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from PIL import Image

from ...image_processor import VaeImageProcessor
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
from ...utils import deprecate, is_torch_xla_available, logging
@@ -76,27 +75,6 @@
"""


# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2
if height % scale_factor**2 != 0:
new_height += 1
new_width = width // scale_factor**2
if width % scale_factor**2 != 0:
new_width += 1
return new_height * scale_factor, new_width * scale_factor


# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
def prepare_image(pil_image, w=512, h=512):
pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
arr = np.array(pil_image.convert("RGB"))
arr = arr.astype(np.float32) / 127.5 - 1
arr = np.transpose(arr, [2, 0, 1])
image = torch.from_numpy(arr).unsqueeze(0)
return image


class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
"""
Pipeline for image-to-image generation using Kandinsky
@@ -129,7 +107,14 @@ def __init__(
scheduler=scheduler,
movq=movq,
)
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
self.image_processor = VaeImageProcessor(
vae_scale_factor=movq_scale_factor,
vae_latent_channels=movq_latent_channels,
resample="bicubic",
reducing_gap=1,
)

# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
@@ -319,15 +304,14 @@ def __call__(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
)

image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
image = image.to(dtype=image_embeds.dtype, device=device)

latents = self.movq.encode(image)["latents"]
latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
latents = self.prepare_latents(
latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
)
@@ -383,21 +367,9 @@ def __call__(
if XLA_AVAILABLE:
xm.mark_step()

if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}"
)

if not output_type == "latent":
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

if output_type == "pil":
image = self.numpy_to_pil(image)
image = self.image_processor.postprocess(image, output_type)
else:
image = latents

43 changes: 11 additions & 32 deletions src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
@@ -1,12 +1,12 @@
import inspect
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import PIL
import PIL.Image
import torch
from transformers import T5EncoderModel, T5Tokenizer

from ...image_processor import VaeImageProcessor
from ...loaders import StableDiffusionLoraLoaderMixin
from ...models import Kandinsky3UNet, VQModel
from ...schedulers import DDPMScheduler
@@ -53,24 +53,6 @@
"""


def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2
if height % scale_factor**2 != 0:
new_height += 1
new_width = width // scale_factor**2
if width % scale_factor**2 != 0:
new_width += 1
return new_height * scale_factor, new_width * scale_factor


def prepare_image(pil_image):
arr = np.array(pil_image.convert("RGB"))
arr = arr.astype(np.float32) / 127.5 - 1
arr = np.transpose(arr, [2, 0, 1])
image = torch.from_numpy(arr).unsqueeze(0)
return image


class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
model_cpu_offload_seq = "text_encoder->movq->unet->movq"
_callback_tensor_inputs = [
@@ -94,6 +76,14 @@ def __init__(
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq
)
movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
self.image_processor = VaeImageProcessor(
vae_scale_factor=movq_scale_factor,
vae_latent_channels=movq_latent_channels,
resample="bicubic",
reducing_gap=1,
)

def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
@@ -566,7 +556,7 @@ def __call__(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
)

image = torch.cat([prepare_image(i) for i in image], dim=0)
image = torch.cat([self.image_processor.preprocess(i) for i in image], dim=0)
image = image.to(dtype=prompt_embeds.dtype, device=device)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -630,20 +620,9 @@ def __call__(
xm.mark_step()

# post-processing
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
)
if not output_type == "latent":
image = self.movq.decode(latents, force_not_quantize=True)["sample"]

if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

if output_type == "pil":
image = self.numpy_to_pil(image)
image = self.image_processor.postprocess(image, output_type)
else:
image = latents

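Since preprocessing and postprocessing are internal to the pipelines, the user-facing API is unchanged. An illustrative end-to-end call against one of the affected pipelines (the checkpoint id is a Kandinsky 2.2 decoder; the image URL is a placeholder):

import torch

from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image

pipe = AutoPipelineForImage2Image.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
).to("cuda")

init_image = load_image("https://example.com/sketch.png")  # placeholder URL
# Resizing and normalization are now handled by the pipeline's VaeImageProcessor
# (bicubic, reducing_gap=1) rather than the removed prepare_image helper.
image = pipe("A fantasy landscape, watercolor", image=init_image, strength=0.75).images[0]
image.save("out.png")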