|
5 | 5 | import numpy as np |
6 | 6 | import torch |
7 | 7 |
|
| 8 | +import diffusers |
8 | 9 | import PIL |
9 | 10 | from diffusers import SchedulerMixin, StableDiffusionPipeline |
10 | 11 | from diffusers.models import AutoencoderKL, UNet2DConditionModel |
11 | 12 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker |
12 | | -from diffusers.utils import PIL_INTERPOLATION, deprecate, logging |
| 13 | +from diffusers.utils import deprecate, logging |
| 14 | +from packaging import version |
13 | 15 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer |
14 | 16 |
|
15 | 17 |
|
try:
    from diffusers.utils import PIL_INTERPOLATION
except ImportError:
    # Older diffusers releases do not export PIL_INTERPOLATION, so rebuild
    # the same mapping locally. Pillow 9.1.0 moved the resampling constants
    # into the PIL.Image.Resampling enum, hence the version switch.
    _pil_version = version.parse(version.parse(PIL.__version__).base_version)
    if _pil_version >= version.parse("9.1.0"):
        PIL_INTERPOLATION = {
            "linear": PIL.Image.Resampling.BILINEAR,
            "bilinear": PIL.Image.Resampling.BILINEAR,
            "bicubic": PIL.Image.Resampling.BICUBIC,
            "lanczos": PIL.Image.Resampling.LANCZOS,
            "nearest": PIL.Image.Resampling.NEAREST,
        }
    else:
        PIL_INTERPOLATION = {
            "linear": PIL.Image.LINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
            "nearest": PIL.Image.NEAREST,
        }
# ------------------------------------------------------------------------------
| 38 | + |
16 | 39 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
17 | 40 |
|
18 | 41 | re_attention = re.compile( |
@@ -404,27 +427,75 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): |
404 | 427 | Model that extracts features from generated images to be used as inputs for the `safety_checker`. |
405 | 428 | """ |
406 | 429 |
|
407 | | - def __init__( |
408 | | - self, |
409 | | - vae: AutoencoderKL, |
410 | | - text_encoder: CLIPTextModel, |
411 | | - tokenizer: CLIPTokenizer, |
412 | | - unet: UNet2DConditionModel, |
413 | | - scheduler: SchedulerMixin, |
414 | | - safety_checker: StableDiffusionSafetyChecker, |
415 | | - feature_extractor: CLIPFeatureExtractor, |
416 | | - requires_safety_checker: bool = True, |
417 | | - ): |
418 | | - super().__init__( |
419 | | - vae=vae, |
420 | | - text_encoder=text_encoder, |
421 | | - tokenizer=tokenizer, |
422 | | - unet=unet, |
423 | | - scheduler=scheduler, |
424 | | - safety_checker=safety_checker, |
425 | | - feature_extractor=feature_extractor, |
426 | | - requires_safety_checker=requires_safety_checker, |
427 | | - ) |
    # diffusers 0.9.0 added the `requires_safety_checker` argument to
    # StableDiffusionPipeline.__init__, so the constructor signature exposed
    # here is chosen at class-creation time from the installed version.
    if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):

        def __init__(
            self,
            vae: AutoencoderKL,
            text_encoder: CLIPTextModel,
            tokenizer: CLIPTokenizer,
            unet: UNet2DConditionModel,
            scheduler: SchedulerMixin,
            safety_checker: StableDiffusionSafetyChecker,
            feature_extractor: CLIPFeatureExtractor,
            requires_safety_checker: bool = True,
        ):
            """Constructor for diffusers >= 0.9.0 (forwards requires_safety_checker)."""
            super().__init__(
                vae=vae,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                unet=unet,
                scheduler=scheduler,
                safety_checker=safety_checker,
                feature_extractor=feature_extractor,
                requires_safety_checker=requires_safety_checker,
            )
            # Shared post-init fix-ups for both version branches.
            self.__init__additional__()

    else:

        def __init__(
            self,
            vae: AutoencoderKL,
            text_encoder: CLIPTextModel,
            tokenizer: CLIPTokenizer,
            unet: UNet2DConditionModel,
            scheduler: SchedulerMixin,
            safety_checker: StableDiffusionSafetyChecker,
            feature_extractor: CLIPFeatureExtractor,
        ):
            """Constructor for diffusers < 0.9.0 (no requires_safety_checker kwarg)."""
            super().__init__(
                vae=vae,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                unet=unet,
                scheduler=scheduler,
                safety_checker=safety_checker,
                feature_extractor=feature_extractor,
            )
            # Shared post-init fix-ups for both version branches.
            self.__init__additional__()
| 477 | + |
| 478 | + def __init__additional__(self): |
| 479 | + if not hasattr(self, "vae_scale_factor"): |
| 480 | + setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1)) |
| 481 | + |
| 482 | + @property |
| 483 | + def _execution_device(self): |
| 484 | + r""" |
| 485 | + Returns the device on which the pipeline's models will be executed. After calling |
| 486 | + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module |
| 487 | + hooks. |
| 488 | + """ |
| 489 | + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): |
| 490 | + return self.device |
| 491 | + for module in self.unet.modules(): |
| 492 | + if ( |
| 493 | + hasattr(module, "_hf_hook") |
| 494 | + and hasattr(module._hf_hook, "execution_device") |
| 495 | + and module._hf_hook.execution_device is not None |
| 496 | + ): |
| 497 | + return torch.device(module._hf_hook.execution_device) |
| 498 | + return self.device |
428 | 499 |
|
429 | 500 | def _encode_prompt( |
430 | 501 | self, |
@@ -752,37 +823,33 @@ def __call__( |
752 | 823 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
753 | 824 |
|
754 | 825 | # 8. Denoising loop |
755 | | - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order |
756 | | - with self.progress_bar(total=num_inference_steps) as progress_bar: |
757 | | - for i, t in enumerate(timesteps): |
758 | | - # expand the latents if we are doing classifier free guidance |
759 | | - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
760 | | - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
761 | | - |
762 | | - # predict the noise residual |
763 | | - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample |
764 | | - |
765 | | - # perform guidance |
766 | | - if do_classifier_free_guidance: |
767 | | - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
768 | | - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
769 | | - |
770 | | - # compute the previous noisy sample x_t -> x_t-1 |
771 | | - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample |
772 | | - |
773 | | - if mask is not None: |
774 | | - # masking |
775 | | - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) |
776 | | - latents = (init_latents_proper * mask) + (latents * (1 - mask)) |
777 | | - |
778 | | - # call the callback, if provided |
779 | | - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
780 | | - progress_bar.update() |
781 | | - if i % callback_steps == 0: |
782 | | - if callback is not None: |
783 | | - callback(i, t, latents) |
784 | | - if is_cancelled_callback is not None and is_cancelled_callback(): |
785 | | - return None |
| 826 | + for i, t in enumerate(self.progress_bar(timesteps)): |
| 827 | + # expand the latents if we are doing classifier free guidance |
| 828 | + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
| 829 | + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
| 830 | + |
| 831 | + # predict the noise residual |
| 832 | + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample |
| 833 | + |
| 834 | + # perform guidance |
| 835 | + if do_classifier_free_guidance: |
| 836 | + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
| 837 | + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
| 838 | + |
| 839 | + # compute the previous noisy sample x_t -> x_t-1 |
| 840 | + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample |
| 841 | + |
| 842 | + if mask is not None: |
| 843 | + # masking |
| 844 | + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) |
| 845 | + latents = (init_latents_proper * mask) + (latents * (1 - mask)) |
| 846 | + |
| 847 | + # call the callback, if provided |
| 848 | + if i % callback_steps == 0: |
| 849 | + if callback is not None: |
| 850 | + callback(i, t, latents) |
| 851 | + if is_cancelled_callback is not None and is_cancelled_callback(): |
| 852 | + return None |
786 | 853 |
|
787 | 854 | # 9. Post-processing |
788 | 855 | image = self.decode_latents(latents) |
|
0 commit comments