|
| 1 | +import inspect |
| 2 | +from typing import List, Optional, Union |
| 3 | + |
| 4 | +import torch |
| 5 | + |
| 6 | +from tqdm.auto import tqdm |
| 7 | +from transformers import CLIPTextModel, CLIPTokenizer |
| 8 | + |
| 9 | +from ...models import AutoencoderKL, UNet2DConditionModel |
| 10 | +from ...pipeline_utils import DiffusionPipeline |
| 11 | +from ...schedulers import DDIMScheduler, PNDMScheduler |
| 12 | + |
| 13 | + |
| 14 | +class StableDiffusionPipeline(DiffusionPipeline): |
| 15 | + def __init__( |
| 16 | + self, |
| 17 | + vae: AutoencoderKL, |
| 18 | + text_encoder: CLIPTextModel, |
| 19 | + tokenizer: CLIPTokenizer, |
| 20 | + unet: UNet2DConditionModel, |
| 21 | + scheduler: Union[DDIMScheduler, PNDMScheduler], |
| 22 | + ): |
| 23 | + super().__init__() |
| 24 | + scheduler = scheduler.set_format("pt") |
| 25 | + self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler) |
| 26 | + |
| 27 | + @torch.no_grad() |
| 28 | + def __call__( |
| 29 | + self, |
| 30 | + prompt: Union[str, List[str]], |
| 31 | + num_inference_steps: Optional[int] = 50, |
| 32 | + guidance_scale: Optional[float] = 1.0, |
| 33 | + eta: Optional[float] = 0.0, |
| 34 | + generator: Optional[torch.Generator] = None, |
| 35 | + torch_device: Optional[Union[str, torch.device]] = None, |
| 36 | + output_type: Optional[str] = "pil", |
| 37 | + ): |
| 38 | + if torch_device is None: |
| 39 | + torch_device = "cuda" if torch.cuda.is_available() else "cpu" |
| 40 | + |
| 41 | + if isinstance(prompt, str): |
| 42 | + batch_size = 1 |
| 43 | + elif isinstance(prompt, list): |
| 44 | + batch_size = len(prompt) |
| 45 | + else: |
| 46 | + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") |
| 47 | + |
| 48 | + self.unet.to(torch_device) |
| 49 | + self.vae.to(torch_device) |
| 50 | + self.text_encoder.to(torch_device) |
| 51 | + |
| 52 | + # get prompt text embeddings |
| 53 | + text_input = self.tokenizer(prompt, padding=True, truncation=True, return_tensors="pt") |
| 54 | + text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0] |
| 55 | + |
| 56 | + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) |
| 57 | + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` |
| 58 | + # corresponds to doing no classifier free guidance. |
| 59 | + do_classifier_free_guidance = guidance_scale > 1.0 |
| 60 | + # get unconditional embeddings for classifier free guidance |
| 61 | + if do_classifier_free_guidance: |
| 62 | + max_length = text_input.input_ids.shape[-1] |
| 63 | + uncond_input = self.tokenizer( |
| 64 | + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" |
| 65 | + ) |
| 66 | + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0] |
| 67 | + |
| 68 | + # For classifier free guidance, we need to do two forward passes. |
| 69 | + # Here we concatenate the unconditional and text embeddings into a single batch |
| 70 | + # to avoid doing two forward passes |
| 71 | + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) |
| 72 | + |
| 73 | + # get the intial random noise |
| 74 | + latents = torch.randn( |
| 75 | + (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), |
| 76 | + generator=generator, |
| 77 | + ) |
| 78 | + latents = latents.to(torch_device) |
| 79 | + |
| 80 | + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature |
| 81 | + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. |
| 82 | + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 |
| 83 | + # and should be between [0, 1] |
| 84 | + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) |
| 85 | + extra_kwargs = {} |
| 86 | + if accepts_eta: |
| 87 | + extra_kwargs["eta"] = eta |
| 88 | + |
| 89 | + self.scheduler.set_timesteps(num_inference_steps) |
| 90 | + |
| 91 | + for t in tqdm(self.scheduler.timesteps): |
| 92 | + # expand the latents if we are doing classifier free guidance |
| 93 | + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
| 94 | + |
| 95 | + # predict the noise residual |
| 96 | + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] |
| 97 | + |
| 98 | + # perform guidance |
| 99 | + if do_classifier_free_guidance: |
| 100 | + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
| 101 | + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
| 102 | + |
| 103 | + # compute the previous noisy sample x_t -> x_t-1 |
| 104 | + latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"] |
| 105 | + |
| 106 | + # scale and decode the image latents with vae |
| 107 | + latents = 1 / 0.18215 * latents |
| 108 | + image = self.vae.decode(latents) |
| 109 | + |
| 110 | + image = (image / 2 + 0.5).clamp(0, 1) |
| 111 | + image = image.cpu().permute(0, 2, 3, 1).numpy() |
| 112 | + if output_type == "pil": |
| 113 | + image = self.numpy_to_pil(image) |
| 114 | + |
| 115 | + return {"sample": image} |
0 commit comments