
Commit abd922b

[docs] unet type hints (huggingface#7134)
update
1 parent fa633ed commit abd922b

1 file changed: +10 −8 lines changed

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 10 additions & 8 deletions
@@ -204,7 +204,7 @@ def __init__(
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
         resnet_skip_time_act: bool = False,
-        resnet_out_scale_factor: int = 1.0,
+        resnet_out_scale_factor: float = 1.0,
         time_embedding_type: str = "positional",
         time_embedding_dim: Optional[int] = None,
         time_embedding_act_fn: Optional[str] = None,
@@ -217,7 +217,7 @@ def __init__(
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
         cross_attention_norm: Optional[str] = None,
-        addition_embed_type_num_heads=64,
+        addition_embed_type_num_heads: int = 64,
     ):
         super().__init__()

@@ -485,9 +485,9 @@ def _check_config(
         up_block_types: Tuple[str],
         only_cross_attention: Union[bool, Tuple[bool]],
         block_out_channels: Tuple[int],
-        layers_per_block: [int, Tuple[int]],
+        layers_per_block: Union[int, Tuple[int]],
         cross_attention_dim: Union[int, Tuple[int]],
-        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]],
+        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
         reverse_transformer_layers_per_block: bool,
         attention_head_dim: int,
         num_attention_heads: Optional[Union[int, Tuple[int]]],
@@ -762,7 +762,7 @@ def set_default_attn_processor(self):

         self.set_attn_processor(processor)

-    def set_attention_slice(self, slice_size):
+    def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
         r"""
         Enable sliced attention computation.

@@ -831,7 +831,7 @@ def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value

-    def enable_freeu(self, s1, s2, b1, b2):
+    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

         The suffixes after the scaling factors represent the stage blocks where they are being applied.
@@ -953,7 +953,7 @@ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Ten
         return class_emb

     def get_aug_embed(
-        self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict
+        self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
     ) -> Optional[torch.Tensor]:
         aug_emb = None
         if self.config.addition_embed_type == "text":
@@ -1004,7 +1004,9 @@ def get_aug_embed(
             aug_emb = self.add_embedding(image_embs, hint)
         return aug_emb

-    def process_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor, added_cond_kwargs) -> torch.Tensor:
+    def process_encoder_hidden_states(
+        self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
+    ) -> torch.Tensor:
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
         elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
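
For reference, a minimal usage sketch (not part of this commit) showing the kinds of calls the corrected annotations describe. The checkpoint id and the slicing/FreeU values below are illustrative placeholders, not values taken from this diff:

import torch
from diffusers import UNet2DConditionModel

# Load a UNet2DConditionModel; the checkpoint id is an illustrative placeholder.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)

# slice_size may be the "auto" preset, a single int, or a per-layer list of ints,
# which is what the new Union[str, int, List[int]] annotation spells out.
unet.set_attention_slice("auto")

# The FreeU scaling factors are plain floats, matching the new annotations.
# These particular values are placeholders, not recommendations from the commit.
unet.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)

The remaining changes (e.g. resnet_out_scale_factor: float, addition_embed_type_num_heads: int) only correct the annotations; default values and runtime behavior are unchanged.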

0 commit comments
