@yiyixuxu (Collaborator) commented on Sep 5, 2024

Fixes #9366.
This is a known issue that occurs when the first timestep in the schedule is a duplicated one, so I added set_begin_index so we don't rely on the search method to find the first timestep (see #6728 for more details).
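To make the failure mode concrete, here is a minimal, self-contained sketch; the timestep values below are made up, only the duplication pattern of a 2nd-order scheduler matters:

import torch

# hypothetical schedule from a 2nd-order scheduler: every timestep except the
# highest appears twice (once per derivative evaluation)
timesteps = torch.tensor([999, 759, 759, 519, 519, 279, 279])

# img2img with strength < 1 slices off leading steps, so denoising can start
# at a duplicated value such as 759
first_timestep = 759

# a search-based lookup cannot tell which occurrence is the true starting
# index; picking the wrong one desynchronizes the scheduler's step counter
matches = (timesteps == first_timestep).nonzero()
print(matches)  # tensor([[1], [2]]) -- ambiguous

# with this PR the pipeline registers the index explicitly instead, roughly:
#   self.scheduler.set_begin_index(t_start * self.scheduler.order)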

I also refactored the get_timesteps method for SDXL img2img so it's a little easier to reason about.
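As a rough, hedged sketch of the refactored shape (a paraphrase, not the verbatim diff; the hasattr guard keeps schedulers without set_begin_index working):

def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
    if denoising_start is None:
        # strength decides how many of the leading steps are skipped
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(t_start * self.scheduler.order)
        return timesteps, num_inference_steps - t_start

    # denoising_start overrides strength: keep only timesteps below the cutoff
    discrete_timestep_cutoff = int(
        round(
            self.scheduler.config.num_train_timesteps
            - (denoising_start * self.scheduler.config.num_train_timesteps)
        )
    )
    num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
    if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
        # avoid cutting a 2nd-order scheduler between its two derivative steps
        num_inference_steps = num_inference_steps + 1
    t_start = len(self.scheduler.timesteps) - num_inference_steps
    timesteps = self.scheduler.timesteps[t_start:]
    if hasattr(self.scheduler, "set_begin_index"):
        self.scheduler.set_begin_index(t_start)
    return timesteps, num_inference_steps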

Below is a simple unit test for get_timesteps; I will run more slow tests for the affected pipelines.

from diffusers import StableDiffusionXLImg2ImgPipeline
import numpy as np
import torch
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import retrieve_timesteps

# this is the original get_timesteps function (before the change from this PR)
def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
    # get the original timestep using init_timestep
    if denoising_start is None:
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
    else:
        t_start = 0

    timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

    # Strength is irrelevant if we directly request a timestep to start at;
    # that is, strength is determined by the denoising_start instead.
    if denoising_start is not None:
        discrete_timestep_cutoff = int(
            round(
                self.scheduler.config.num_train_timesteps
                - (denoising_start * self.scheduler.config.num_train_timesteps)
            )
        )
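        # e.g. with num_train_timesteps = 1000 and denoising_start = 0.8, the
        # cutoff is round(1000 - 800) = 200, so only timesteps below 200 are kept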

        num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
        if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
            # if the scheduler is a 2nd order scheduler we might have to do +1
            # because `num_inference_steps` might be even given that every timestep
            # (except the highest one) is duplicated. If `num_inference_steps` is even it would
            # mean that we cut the timesteps in the middle of the denoising step
            # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
            # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
            num_inference_steps = num_inference_steps + 1

        # because t_n+1 >= t_n, we slice the timesteps starting from the end
        timesteps = timesteps[-num_inference_steps:]
        return timesteps, num_inference_steps

    return timesteps, num_inference_steps - t_start

# this is the new get_timesteps function (after the change from this PR)
get_timesteps_new = StableDiffusionXLImg2ImgPipeline.get_timesteps


pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16"
)

# testing when denoising_start is None

for num_inference_steps in range(10, 100, 10):
    for strength in np.arange(0, 1, 0.1):

        self = pipeline
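        # `self` is bound to the pipeline so both get_timesteps versions can be
        # called as plain functions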
        device="cuda"
        timesteps = None
        sigmas = None
        denoising_start = None
    
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )
        print(f" test for num_inference_steps: {num_inference_steps}, strength: {strength}, denoising_start: {denoising_start}")
        timesteps_1, num_inference_steps_1 = get_timesteps(self, num_inference_steps, strength, device, denoising_start)
        timesteps_2, num_inference_steps_2 = get_timesteps_new(self, num_inference_steps, strength, device, denoising_start)
        print(f"timesteps(original): {timesteps_1}, num_inference_steps(original): {num_inference_steps_1}")
        print(f"timesteps(testin g): {timesteps_2}, num_inference_steps(testing): {num_inference_steps_2}")
        assert (timesteps_1 - timesteps_2).abs().sum() == 0
        assert num_inference_steps_1 == num_inference_steps_1
        print("-"*100)

# testing when denoising_start is not None

for num_inference_steps in range(10, 100, 10):
    for strength in np.arange(0, 1, 0.1):
        for denoising_start in np.arange(0, 1, 0.1):

            self = pipeline
            device="cuda"
            timesteps = None
            sigmas = None
        
            timesteps, num_inference_steps = retrieve_timesteps(
                self.scheduler, num_inference_steps, device, timesteps, sigmas
            )
            print(f" test for num_inference_steps: {num_inference_steps}, strength: {strength}, denoising_start: {denoising_start}")
            timesteps_1, num_inference_steps_1 = get_timesteps(self, num_inference_steps, strength, device, denoising_start)
            timesteps_2, num_inference_steps_2 = get_timesteps_new(self, num_inference_steps, strength, device, denoising_start)
            print(f"timesteps(original): {timesteps_1}, num_inference_steps(original): {num_inference_steps_1}")
            print(f"timesteps(testin g): {timesteps_2}, num_inference_steps(testing): {num_inference_steps_2}")
            assert (timesteps_1 - timesteps_2).abs().sum() == 0
            assert num_inference_steps_1 == num_inference_steps_1
            print("-"*100)

@HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu (Collaborator, Author) commented on Sep 5, 2024

Slow test 1: tested for denoising_start = 0.8 and None; outputs match main across all scheduler configurations (script below).

# branch = "main" # or "add-begin-index-sdxlimg2img"
branch = "add-begin-index-sdxlimg2img"
# output_type ="pil" # or "pt"  
output_type = "pt"

steps = 20
seed = 0
denoising_start = 0.8 # or None
# denoising_start = None


import torch
from diffusers import DiffusionPipeline
from diffusers import DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler
import os

from diffusers.utils import make_image_grid

import gc

def flush():
    """Wipes off memory."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


# define all scheduler configs we want to test
config_min = {"final_sigmas_type":"sigma_min"}
config_min_euler = {"final_sigmas_type":"sigma_min", "euler_at_final": True }
config_zero = {"final_sigmas_type":"zero"}

schedulers = {
    "DPMPP_2M": {
        "min": (DPMSolverMultistepScheduler, config_min),
        "min_euler": (DPMSolverMultistepScheduler, config_min_euler),
        "zero": (DPMSolverMultistepScheduler, config_zero),
     },
     "DPMPP_2M_K": {
        "min": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, **config_min}),
        "min_euler": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, **config_min_euler}),
        "zero": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, **config_zero}),
     },
     "DPMPP_2M_SDE": {
        "min": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", **config_min}),
        "min_euler": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", **config_min_euler}),
        "zero": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", **config_zero}),
     },
     "DPMPP_2M_SDE_K": {
        "min": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True, **config_min}),
        "min_euler": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True, **config_min_euler}),
        "zero": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++", **config_zero}),
     },
     "DPMPP": {
        "min": (DPMSolverSinglestepScheduler, config_min),
        "min_euler": (DPMSolverSinglestepScheduler, config_min_euler),
        "zero": (DPMSolverSinglestepScheduler, config_zero),
     },
     "DPMPP_K": {
        "min": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, **config_min}),
        "min_euler": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, **config_min_euler}),
        "zero": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, **config_zero}),
     },
}


# define save directory
save_dir = './yiyi_test_3_output'

if not os.path.exists(save_dir):
    os.mkdir(save_dir)


# load base model and create latent

base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

prompt = "A majestic lion jumping from a big stone at night"

generator = torch.Generator(device='cuda').manual_seed(seed)
image = base(
    prompt=prompt,
    num_inference_steps=40,
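    # hand off at denoising_start: the base stops where the refiner will resume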
    denoising_end=denoising_start,
    generator=generator,
    output_type="latent",
).images


# load refiner pipe
model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"
pipe = DiffusionPipeline.from_pretrained(
    model_id,
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

del base
flush()


params = {
    "prompt": [prompt],
    "num_inference_steps": steps,
    "guidance_scale": 7,
    "image": image,
    "denoising_start":denoising_start,
}
for scheduler_name in schedulers.keys():
    scheduler_configs = schedulers[scheduler_name]
    for scheduler_config_name in scheduler_configs.keys():
        generator = torch.Generator(device='cuda').manual_seed(seed)
        scheduler = scheduler_configs[scheduler_config_name][0].from_pretrained(
            model_id,
            subfolder="scheduler",
            **scheduler_configs[scheduler_config_name][1],
        )
        pipe.scheduler = scheduler

        img = pipe(**params, generator=generator, output_type=output_type).images[0]
        if output_type == "pt":
            if branch == "main":
                torch.save(img, os.path.join(save_dir, f"{branch}_{scheduler_name}_{scheduler_config_name}_{steps}_{denoising_start}.pt"))
            else:
                img_expected = torch.load(os.path.join(save_dir, f"main_{scheduler_name}_{scheduler_config_name}_{steps}_{denoising_start}.pt"))
                assert (img - img_expected).abs().max() < 1e-3, f"Image mismatch for {scheduler_name}_{scheduler_config_name}_{steps}_{denoising_start}"
        else:
            img.save(os.path.join(save_dir, f"{branch}_{scheduler_name}_{scheduler_config_name}_{steps}_{denoising_start}.png"))

@yiyixuxu (Collaborator, Author) commented on Sep 5, 2024

Slow test for all affected pipelines:

branch = "main" # or "add-begin-index-sdxlimg2img"
# branch = "add-begin-index-sdxlimg2img"



import torch
from diffusers import AutoPipelineForImage2Image, StableDiffusionXLInpaintPipeline, StableDiffusionXLControlNetInpaintPipeline
from diffusers.utils import load_image
import gc
import os

def flush():
    """Wipes off memory."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


# define save directory
save_dir = './yiyi_test_5_output'

if not os.path.exists(save_dir):
    os.mkdir(save_dir)

# test1: sdxl img2img
model_name = "sdxl_img2img"
print(f" Running test for {model_name}")
pipe = AutoPipelineForImage2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"

init_image = load_image(url).convert("RGB")
prompt = "a photo of an astronaut riding a horse on mars"
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe(prompt, image=init_image, generator=generator).images[0]
image.save(f"{save_dir}/{branch}_{model_name}.png")

del pipe
flush()

# test2: sdxl inpaint
model_name = "sdxl_inpaint"

pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
pipe.to("cuda")

img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

init_image = load_image(img_url).convert("RGB")
mask_image = load_image(mask_url).convert("RGB")

prompt = "A majestic tiger sitting on a bench"
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe(
    prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80, generator=generator
).images[0]

image.save(f"{save_dir}/{branch}_{model_name}.png")

del pipe
flush()


# test3: kolors img2img
model_name = "kolors_img2img"
pipe = AutoPipelineForImage2Image.from_pretrained(
    "Kwai-Kolors/Kolors-diffusers", variant="fp16", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
url = (
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/bunny_source.png"
)

init_image = load_image(url)
prompt = "high quality image of a capybara wearing sunglasses. In the background of the image there are trees, poles, grass and other objects. At the bottom of the object there is the road., 8k, highly detailed."

generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe(prompt, image=init_image, generator=generator).images[0]

image.save(f"{save_dir}/{branch}_{model_name}.png")

del pipe
flush()


# test4: controlnet sdxl inpaint
from diffusers import ControlNetModel, DDIMScheduler
from PIL import Image
import numpy as np
import cv2

model_name = "controlnet_sdxl_inpaint"
init_image = load_image(
    "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
)
init_image = init_image.resize((1024, 1024))

generator = torch.Generator(device="cpu").manual_seed(1)

mask_image = load_image(
    "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
)
mask_image = mask_image.resize((1024, 1024))


def make_canny_condition(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    image = Image.fromarray(image)
    return image


control_image = make_canny_condition(init_image)

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
)

pipe.enable_model_cpu_offload()

generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe(
    "a handsome man with ray-ban sunglasses",
    num_inference_steps=20,
    generator=generator,
    eta=1.0,
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
).images[0]

image.save(f"{save_dir}/{branch}_{model_name}.png")

del pipe
flush()


# test5: pag img2img
model_name = "pag_img2img"
from diffusers import AutoPipelineForImage2Image

pipe = AutoPipelineForImage2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    enable_pag=True,
)
pipe = pipe.to("cuda")
url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"

init_image = load_image(url).convert("RGB")
prompt = "a photo of an astronaut riding a horse on mars"
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe(prompt, image=init_image, pag_scale=0.3, generator=generator).images[0]

image.save(f"{save_dir}/{branch}_{model_name}.png")


# test6: pag inpaint
model_name = "pag_inpaint"
from diffusers import AutoPipelineForInpainting

pipe = AutoPipelineForInpainting.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    enable_pag=True,
)
pipe.to("cuda")

img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

init_image = load_image(img_url).convert("RGB")
mask_image = load_image(mask_url).convert("RGB")

prompt = "A majestic tiger sitting on a bench"
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe(
    prompt=prompt,
    image=init_image,
    mask_image=mask_image,
    num_inference_steps=50,
    strength=0.80,
    pag_scale=0.3,
    generator=generator,
).images[0]

image.save(f"{save_dir}/{branch}_{model_name}.png")

del pipe
flush()

Output comparisons (main vs. this PR); the images themselves are omitted here, only the saved filenames remain:

- ControlNet SDXL inpaint: main_controlnet_sdxl_inpaint vs. add-begin-index-sdxlimg2img_controlnet_sdxl_inpaint
- Kolors img2img: main_kolors_img2img vs. add-begin-index-sdxlimg2img_kolors_img2img
- PAG img2img: main_pag_img2img vs. add-begin-index-sdxlimg2img_pag_img2img
- PAG inpaint: main_pag_inpaint vs. add-begin-index-sdxlimg2img_pag_inpaint
- SDXL img2img: main_sdxl_img2img vs. add-begin-index-sdxlimg2img_sdxl_img2img
- SDXL inpaint: main_sdxl_inpaint vs. add-begin-index-sdxlimg2img_sdxl_inpaint

@yiyixuxu (Collaborator, Author) commented on Sep 5, 2024

cc @tolgacangoz in case you're interested in giving this a review! (no worries if not)

@yiyixuxu yiyixuxu merged commit 485b8bb into main Sep 9, 2024
@yiyixuxu yiyixuxu deleted the add-begin-index-sdxlimg2img branch September 9, 2024 16:38
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* refator + add begin_index

* add kolors img2img to doc
Closes: DPMSolverMultistepScheduler with AutoPipelineForImage2Image fails at specific combinations of step counts and strength (#9366)