@@ -204,7 +204,7 @@ def __init__(
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
         resnet_skip_time_act: bool = False,
-        resnet_out_scale_factor: int = 1.0,
+        resnet_out_scale_factor: float = 1.0,
         time_embedding_type: str = "positional",
         time_embedding_dim: Optional[int] = None,
         time_embedding_act_fn: Optional[str] = None,
@@ -217,7 +217,7 @@ def __init__(
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
         cross_attention_norm: Optional[str] = None,
-        addition_embed_type_num_heads=64,
+        addition_embed_type_num_heads: int = 64,
     ):
         super().__init__()

@@ -485,9 +485,9 @@ def _check_config(
         up_block_types: Tuple[str],
         only_cross_attention: Union[bool, Tuple[bool]],
         block_out_channels: Tuple[int],
-        layers_per_block: [int, Tuple[int]],
+        layers_per_block: Union[int, Tuple[int]],
         cross_attention_dim: Union[int, Tuple[int]],
-        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]],
+        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
         reverse_transformer_layers_per_block: bool,
         attention_head_dim: int,
         num_attention_heads: Optional[Union[int, Tuple[int]]],
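
For context, a minimal sketch (not part of this commit) of how the per-block tuple forms that `_check_config` now spells out are exercised when constructing the model; the tiny toy config below is an assumption chosen only for illustration:

```python
from diffusers import UNet2DConditionModel

# Hypothetical toy config: layers_per_block passed as Tuple[int] and
# transformer_layers_per_block as Tuple[int], matching the annotations above.
unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    layers_per_block=(1, 2),
    transformer_layers_per_block=(1, 1),
    cross_attention_dim=32,
    attention_head_dim=8,
)
```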
@@ -762,7 +762,7 @@ def set_default_attn_processor(self):

         self.set_attn_processor(processor)

-    def set_attention_slice(self, slice_size):
+    def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
         r"""
         Enable sliced attention computation.

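
A hedged usage sketch (not part of this commit) of the new `"auto"` default; the checkpoint id is only an example:

```python
import torch
from diffusers import UNet2DConditionModel

# Example checkpoint id; any Stable Diffusion UNet works the same way.
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)

# "auto" splits the attention computation into slices to trade speed for memory;
# an int (or a list of ints, one per sliceable layer) pins an explicit slice size.
unet.set_attention_slice("auto")
```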
@@ -831,7 +831,7 @@ def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value

-    def enable_freeu(self, s1, s2, b1, b2):
+    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

         The suffixes after the scaling factors represent the stage blocks where they are being applied.
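
A brief sketch (reusing the `unet` from the previous example) of the FreeU call with the now explicitly float-typed scaling factors; the values below are the ones commonly suggested for SD 1.5 in the FreeU repository:

```python
# Backbone (b1, b2) and skip-connection (s1, s2) scaling factors are plain floats.
unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)

# FreeU can be switched off again afterwards.
unet.disable_freeu()
```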
@@ -953,7 +953,7 @@ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Ten
         return class_emb

     def get_aug_embed(
-        self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict
+        self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
     ) -> Optional[torch.Tensor]:
         aug_emb = None
         if self.config.addition_embed_type == "text":
@@ -1004,7 +1004,9 @@ def get_aug_embed(
             aug_emb = self.add_embedding(image_embs, hint)
         return aug_emb

-    def process_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor, added_cond_kwargs) -> torch.Tensor:
+    def process_encoder_hidden_states(
+        self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
+    ) -> torch.Tensor:
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
         elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
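
To illustrate the `Dict[str, Any]` shape of `added_cond_kwargs`: a rough sketch, assuming an SDXL-style UNet (`addition_embed_type="text_time"`, `cross_attention_dim=2048`) has already been loaded as `unet`; the tensor shapes are illustrative only:

```python
import torch

# For SDXL-style checkpoints, added_cond_kwargs typically carries the pooled
# text embeddings and the micro-conditioning time ids.
added_cond_kwargs = {
    "text_embeds": torch.randn(1, 1280),  # pooled prompt embedding
    "time_ids": torch.randn(1, 6),        # original size / crop coords / target size
}

noise_pred = unet(
    sample=torch.randn(1, 4, 64, 64),
    timestep=10,
    encoder_hidden_states=torch.randn(1, 77, 2048),
    added_cond_kwargs=added_cond_kwargs,
).sample
```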