
Commit f242eba

Fix lpw stable diffusion pipeline compatibility (huggingface#1622)
1 parent 3faf204 commit f242eba

File tree

2 files changed: +254 -119 lines changed


examples/community/lpw_stable_diffusion.py

Lines changed: 120 additions & 53 deletions
@@ -5,14 +5,37 @@
 import numpy as np
 import torch
 
+import diffusers
 import PIL
 from diffusers import SchedulerMixin, StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
-from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
+from diffusers.utils import deprecate, logging
+from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 
+try:
+    from diffusers.utils import PIL_INTERPOLATION
+except ImportError:
+    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+        PIL_INTERPOLATION = {
+            "linear": PIL.Image.Resampling.BILINEAR,
+            "bilinear": PIL.Image.Resampling.BILINEAR,
+            "bicubic": PIL.Image.Resampling.BICUBIC,
+            "lanczos": PIL.Image.Resampling.LANCZOS,
+            "nearest": PIL.Image.Resampling.NEAREST,
+        }
+    else:
+        PIL_INTERPOLATION = {
+            "linear": PIL.Image.LINEAR,
+            "bilinear": PIL.Image.BILINEAR,
+            "bicubic": PIL.Image.BICUBIC,
+            "lanczos": PIL.Image.LANCZOS,
+            "nearest": PIL.Image.NEAREST,
+        }
+# ------------------------------------------------------------------------------
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 re_attention = re.compile(
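
The try/except block added above keeps the community pipeline importable on diffusers releases that do not yet export PIL_INTERPOLATION, rebuilding the mapping locally; the inner version check is needed because Pillow 9.1.0 moved the resampling filters under PIL.Image.Resampling. A minimal sketch of the same fallback pattern in isolation (the file name and target size are illustrative, not from the commit):

# Sketch: pick a Lanczos filter constant that works on both old and new Pillow.
import PIL
import PIL.Image
from packaging import version

if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    LANCZOS = PIL.Image.Resampling.LANCZOS  # Pillow >= 9.1.0 namespaced the filters
else:
    LANCZOS = PIL.Image.LANCZOS  # older Pillow exposes them on PIL.Image directly

# "input.png" and the 512x512 target are placeholder values.
image = PIL.Image.open("input.png").resize((512, 512), resample=LANCZOS)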
@@ -404,27 +427,75 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
 
-    def __init__(
-        self,
-        vae: AutoencoderKL,
-        text_encoder: CLIPTextModel,
-        tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
-        scheduler: SchedulerMixin,
-        safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
-        requires_safety_checker: bool = True,
-    ):
-        super().__init__(
-            vae=vae,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=safety_checker,
-            feature_extractor=feature_extractor,
-            requires_safety_checker=requires_safety_checker,
-        )
+    if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
+
+        def __init__(
+            self,
+            vae: AutoencoderKL,
+            text_encoder: CLIPTextModel,
+            tokenizer: CLIPTokenizer,
+            unet: UNet2DConditionModel,
+            scheduler: SchedulerMixin,
+            safety_checker: StableDiffusionSafetyChecker,
+            feature_extractor: CLIPFeatureExtractor,
+            requires_safety_checker: bool = True,
+        ):
+            super().__init__(
+                vae=vae,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                unet=unet,
+                scheduler=scheduler,
+                safety_checker=safety_checker,
+                feature_extractor=feature_extractor,
+                requires_safety_checker=requires_safety_checker,
+            )
+            self.__init__additional__()
+
+    else:
+
+        def __init__(
+            self,
+            vae: AutoencoderKL,
+            text_encoder: CLIPTextModel,
+            tokenizer: CLIPTokenizer,
+            unet: UNet2DConditionModel,
+            scheduler: SchedulerMixin,
+            safety_checker: StableDiffusionSafetyChecker,
+            feature_extractor: CLIPFeatureExtractor,
+        ):
+            super().__init__(
+                vae=vae,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                unet=unet,
+                scheduler=scheduler,
+                safety_checker=safety_checker,
+                feature_extractor=feature_extractor,
+            )
+            self.__init__additional__()
+
+    def __init__additional__(self):
+        if not hasattr(self, "vae_scale_factor"):
+            setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
+
+    @property
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
 
     def _encode_prompt(
         self,
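
The version gate exists because StableDiffusionPipeline only gained the requires_safety_checker argument in diffusers 0.9.0, so the class selects the matching __init__ signature at import time, and __init__additional__ backfills vae_scale_factor on older releases that never set it. A hedged usage sketch for loading this community pipeline (the model id, dtype, and prompt are illustrative, not from the commit):

# Sketch: loading the long-prompt-weighting community pipeline.
import torch
from diffusers import DiffusionPipeline

# "runwayml/stable-diffusion-v1-5" is a placeholder model id.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

# Weighted prompt syntax such as "(word:1.2)" is what re_attention above parses.
image = pipe(prompt="(a fantasy landscape:1.2), highly detailed").images[0]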
@@ -752,37 +823,33 @@ def __call__(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 8. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-                # perform guidance
-                if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                if mask is not None:
-                    # masking
-                    init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
-                    latents = (init_latents_proper * mask) + (latents * (1 - mask))
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                    if i % callback_steps == 0:
-                        if callback is not None:
-                            callback(i, t, latents)
-                        if is_cancelled_callback is not None and is_cancelled_callback():
-                            return None
+        for i, t in enumerate(self.progress_bar(timesteps)):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+            # predict the noise residual
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+            if mask is not None:
+                # masking
+                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+                latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+            # call the callback, if provided
+            if i % callback_steps == 0:
+                if callback is not None:
+                    callback(i, t, latents)
+                if is_cancelled_callback is not None and is_cancelled_callback():
+                    return None
 
         # 9. Post-processing
         image = self.decode_latents(latents)
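
With the warmup-step bookkeeping removed, the loop advances the progress bar once per timestep and consults the callbacks every callback_steps iterations; when is_cancelled_callback returns True, __call__ bails out with None. A sketch of wiring a cancellation flag to that hook, reusing pipe from the sketch above (the Event, step count, and prompt are illustrative):

# Sketch: cooperative cancellation through is_cancelled_callback.
import threading

cancel_event = threading.Event()  # another thread may set this to stop generation

def on_step(step, timestep, latents):
    print(f"denoising step {step}")

result = pipe(
    prompt="a photo of an astronaut riding a horse",
    num_inference_steps=50,
    callback=on_step,
    callback_steps=5,
    is_cancelled_callback=cancel_event.is_set,  # polled every callback_steps steps
)
if result is None:  # the loop above returns None once the flag is set
    print("generation cancelled")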
