@@ -21,6 +21,7 @@
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
+from ...image_processor import VaeImageProcessor
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor
@@ -34,6 +35,11 @@
 
 
 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
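For readers following along: the shim above keeps the old module-level `preprocess` working while pointing callers at the pipeline's image processor. A minimal sketch of the migration, assuming `pipe` is an already-loaded `StableDiffusionUpscalePipeline` and the file name is hypothetical:

```python
import PIL.Image

# Hypothetical low-resolution input; a numpy array or torch tensor also works.
low_res_img = PIL.Image.open("low_res.png").convert("RGB")

# Old path: the module-level helper, which now emits a FutureWarning.
# tensor = preprocess(low_res_img)

# New path: the pipeline's VaeImageProcessor converts the image to a float
# tensor, normalizes it to [-1, 1], and resizes it to a multiple of the
# VAE scale factor.
tensor = pipe.image_processor.preprocess(low_res_img)
```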
@@ -125,6 +131,8 @@ def __init__(
             watermarker=watermarker,
             feature_extractor=feature_extractor,
         )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
         self.register_to_config(max_noise_level=max_noise_level)
 
     def enable_sequential_cpu_offload(self, gpu_id=0):
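The scale-factor line encodes how the VAE downsamples: each down block after the first halves the spatial resolution, so the latent is `2 ** (n_blocks - 1)` times smaller than the pixel image. A quick illustration with a hypothetical four-block config:

```python
# Hypothetical VAE config; the real values come from self.vae.config.
block_out_channels = (128, 256, 512, 512)

vae_scale_factor = 2 ** (len(block_out_channels) - 1)  # 2 ** 3 == 8
# A 512x512 pixel image then maps to a 64x64 latent (512 // 8).
assert 512 // vae_scale_factor == 64
```

The `resample="bicubic"` argument overrides the processor's default resampling filter for this pipeline's resizing step.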
@@ -432,14 +440,15 @@ def check_inputs(
         if (
             not isinstance(image, torch.Tensor)
             and not isinstance(image, PIL.Image.Image)
+            and not isinstance(image, np.ndarray)
             and not isinstance(image, list)
         ):
             raise ValueError(
-                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
+                f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}"
             )
 
-        # verify batch size of prompt and image are same if image is a list or tensor
-        if isinstance(image, list) or isinstance(image, torch.Tensor):
+        # verify batch size of prompt and image are same if image is a list or tensor or numpy array
+        if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
             if isinstance(prompt, str):
                 batch_size = 1
             else:
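With the extra `np.ndarray` branch, `check_inputs` accepts the same three container types the widened `__call__` signature below advertises. A small sketch of inputs that should now pass the type check, with hypothetical shapes:

```python
import numpy as np
import PIL.Image
import torch

prompts = ["a photo of a cat", "a photo of a dog"]

# Each of these carries a batch of 2, matching the two prompts.
image_as_tensor = torch.rand(2, 3, 128, 128)            # NCHW float tensor
image_as_array = np.random.rand(2, 128, 128, 3)         # NHWC numpy array
image_as_list = [PIL.Image.new("RGB", (128, 128))] * 2  # list of PIL images
```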
@@ -483,7 +492,14 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         num_inference_steps: int = 75,
         guidance_scale: float = 9.0,
         noise_level: int = 20,
@@ -506,7 +522,7 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
-            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                 `Image`, or tensor representing an image batch which will be upscaled.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -627,7 +643,7 @@ def __call__(
         )
 
         # 4. Preprocess image
-        image = preprocess(image)
+        image = self.image_processor.preprocess(image)
         image = image.to(dtype=prompt_embeds.dtype, device=device)
 
         # 5. set timesteps
@@ -723,25 +739,25 @@ def __call__(
         else:
             latents = latents.float()
 
-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.decode_latents(latents)
-
+        # post-processing
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
-
-            image = self.numpy_to_pil(image)
-
-            # 11. Apply watermark
-            if self.watermarker is not None:
-                image = self.watermarker.apply_watermark(image)
-        elif output_type == "pt":
-            latents = 1 / self.vae.config.scaling_factor * latents
-            image = self.vae.decode(latents).sample
-            has_nsfw_concept = None
         else:
-            image = self.decode_latents(latents)
+            image = latents
             has_nsfw_concept = None
 
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        # 11. Apply watermark
+        if output_type == "pil" and self.watermarker is not None:
+            image = self.watermarker.apply_watermark(image)
+
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()
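Taken together, decoding, safety checking, denormalization, and watermarking now funnel through a single `VaeImageProcessor.postprocess` call, with `output_type="latent"` bypassing the VAE decode (and hence the safety checker and watermark) entirely. An end-to-end sketch; the checkpoint is the public x4 upscaler, and the prompt and placeholder image are illustrative:

```python
import PIL.Image
import torch
from diffusers import StableDiffusionUpscalePipeline

pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
).to("cuda")

low_res_img = PIL.Image.new("RGB", (128, 128))  # placeholder input image

# Default PIL output: decoded, safety-checked, denormalized, watermarked.
upscaled = pipe(prompt="a white cat", image=low_res_img).images[0]

# Latent output: returns the raw latents without decoding or watermarking.
latents = pipe(prompt="a white cat", image=low_res_img, output_type="latent").images
```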