@@ -63,7 +63,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         ...     "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
         ... ).to("cuda")

-        >>> pipe("horse", generator=torch.manual_seed(0)).images
+        >>> image = pipe("horse", generator=torch.manual_seed(0)).images[0]
+        >>> image
         ```
     """

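Taken together with the context lines in this hunk, the fixed docstring example corresponds to a usage like the sketch below. The import lines and the `openai/consistency-decoder` checkpoint id are not visible in this hunk and are assumed from the surrounding file:

```python
import torch

from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline

# Assumed setup from the elided docstring lines above this hunk.
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
).to("cuda")

# The change above indexes .images[0] to get a single PIL image instead of the list.
image = pipe("horse", generator=torch.manual_seed(0)).images[0]
image.save("horse.png")
```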
@@ -72,6 +73,7 @@ def __init__(
         self,
         scaling_factor: float = 0.18215,
         latent_channels: int = 4,
+        sample_size: int = 32,
         encoder_act_fn: str = "silu",
         encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
         encoder_double_z: bool = True,
@@ -153,6 +155,16 @@ def __init__(
         self.use_slicing = False
         self.use_tiling = False

+        # only relevant if vae tiling is enabled
+        self.tile_sample_min_size = self.config.sample_size
+        sample_size = (
+            self.config.sample_size[0]
+            if isinstance(self.config.sample_size, (list, tuple))
+            else self.config.sample_size
+        )
+        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_overlap_factor = 0.25
+
     # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
     def enable_tiling(self, use_tiling: bool = True):
         r"""
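A standalone sanity check of the tiling math added above, assuming the `sample_size=32` default from this diff; the four-entry `block_out_channels` tuple is illustrative, since only its length matters here:

```python
# Minimal sketch of the tile-size computation introduced in this hunk.
sample_size = 32
block_out_channels = (128, 256, 512, 512)  # illustrative; only len() is used

tile_sample_min_size = sample_size  # smallest tile size in pixel space
# Each down block past the first halves the resolution: 2 ** (4 - 1) == 8.
tile_latent_min_size = int(sample_size / (2 ** (len(block_out_channels) - 1)))
tile_overlap_factor = 0.25  # fraction of each tile blended into its neighbor

print(tile_sample_min_size, tile_latent_min_size)  # 32 4
```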
@@ -272,7 +284,7 @@ def encode(
         Args:
             x (`torch.FloatTensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether to return a [`~models.consistecy_decoder_vae.ConsistencyDecoderOoutput`] instead of a plain
+                Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a plain
                 tuple.

         Returns:
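For reference, a hedged sketch of what the corrected class reference implies for callers. This assumes a constructed `vae` and that `ConsistencyDecoderVAEOutput` exposes the encoder posterior as `latent_dist`, as diffusers-style VAE outputs do elsewhere:

```python
import torch

x = torch.randn(1, 3, 256, 256, dtype=torch.float16, device="cuda")  # illustrative batch

out = vae.encode(x)  # return_dict=True -> ConsistencyDecoderVAEOutput
latents = out.latent_dist.sample(generator=torch.manual_seed(0))

(posterior,) = vae.encode(x, return_dict=False)  # plain tuple instead
```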
@@ -305,6 +317,19 @@ def decode(
305317 return_dict : bool = True ,
306318 num_inference_steps : int = 2 ,
307319 ) -> Union [DecoderOutput , Tuple [torch .FloatTensor ]]:
320+ """
321+ Decodes the input latent vector `z` using the consistency decoder VAE model.
322+
323+ Args:
324+ z (torch.FloatTensor): The input latent vector.
325+ generator (Optional[torch.Generator]): The random number generator. Default is None.
326+ return_dict (bool): Whether to return the output as a dictionary. Default is True.
327+ num_inference_steps (int): The number of inference steps. Default is 2.
328+
329+ Returns:
330+ Union[DecoderOutput, Tuple[torch.FloatTensor]]: The decoded output.
331+
332+ """
308333 z = (z * self .config .scaling_factor - self .means ) / self .stds
309334
310335 scale_factor = 2 ** (len (self .config .block_out_channels ) - 1 )
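And a usage sketch for `decode` with hypothetical latents; the rescaling and normalization in the context lines above happen inside the call:

```python
import torch

# Hypothetical latents, e.g. produced by a diffusion loop; with the defaults,
# spatial size is the pixel size divided by the scale_factor shown above.
z = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")

decoded = vae.decode(z, generator=torch.manual_seed(0), num_inference_steps=2)
image = decoded.sample  # DecoderOutput.sample holds the decoded image tensor
```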
@@ -345,7 +370,9 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b

-    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput:
+    def tiled_encode(
+        self, x: torch.FloatTensor, return_dict: bool = True
+    ) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
         r"""Encode a batch of images using a tiled encoder.

         When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
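The `blend_h` context above is the column-wise crossfade used when stitching horizontally adjacent tiles back together. A standalone sketch of the same math, written as a plain function rather than the class method:

```python
import torch

def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Linearly ramp from the right edge of tile `a` into the left edge of tile `b`.
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for x in range(blend_extent):
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
    return b

left = torch.zeros(1, 1, 2, 4)  # tile whose right edge is all zeros
right = torch.ones(1, 1, 2, 4)  # tile whose left edge is all ones
print(blend_h(left, right, blend_extent=4)[0, 0, 0])
# tensor([0.0000, 0.2500, 0.5000, 0.7500])
```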