@@ -431,6 +431,10 @@ def norm_encoder_hidden_states(self, encoder_hidden_states):
 
 
 class AttnProcessor:
+    r"""
+    Default processor for performing attention-related computations.
+    """
+
     def __call__(
         self,
         attn: Attention,
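As a hedged illustration (not part of this diff), the sketch below shows how a processor object like this is typically attached to a model. It assumes the `diffusers` `UNet2DConditionModel.set_attn_processor` API; the model id is only an example.

import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

# Load only the UNet of a Stable Diffusion checkpoint (example model id).
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)

# Every Attention block delegates its forward pass to a processor; passing a single
# instance applies it to all attention layers in the model.
unet.set_attn_processor(AttnProcessor())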
@@ -516,6 +520,18 @@ def forward(self, hidden_states):
 
 
 class LoRAAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism.
+
+    Args:
+        hidden_size (`int`, *optional*):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+    """
+
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()
 
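A hedged sketch of the per-layer setup these arguments imply, following the pattern used in the diffusers LoRA training scripts; `unet` is assumed to be a loaded `UNet2DConditionModel`.

from diffusers.models.attention_processor import LoRAAttnProcessor

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # Self-attention layers (attn1) have no text conditioning, so no cross_attention_dim.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRAAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=4
    )

unet.set_attn_processor(lora_attn_procs)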
@@ -580,6 +596,24 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class CustomDiffusionAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing attention for the Custom Diffusion method.
+
+    Args:
+        train_kv (`bool`, defaults to `True`):
+            Whether to newly train the key and value matrices corresponding to the text features.
+        train_q_out (`bool`, defaults to `True`):
+            Whether to newly train the query matrices corresponding to the latent image features.
+        hidden_size (`int`, *optional*, defaults to `None`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        out_bias (`bool`, defaults to `True`):
+            Whether to include a bias parameter in the output projection trained when `train_q_out` is enabled.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+    """
+
     def __init__(
         self,
         train_kv=True,
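A hedged sketch (not from this commit) of how these arguments are usually combined: Custom Diffusion trains new key/value projections on the cross-attention (`attn2`) layers while leaving self-attention alone. `unet` and the per-layer sizes are assumptions based on the standard UNet config.

from diffusers.models.attention_processor import AttnProcessor, CustomDiffusionAttnProcessor

custom_diffusion_attn_procs = {}
for name in unet.attn_processors.keys():
    # Per-layer hidden size, resolved from the UNet config (same scheme as the LoRA sketch above).
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        hidden_size = list(reversed(unet.config.block_out_channels))[int(name[len("up_blocks.")])]
    else:
        hidden_size = unet.config.block_out_channels[int(name[len("down_blocks.")])]

    if name.endswith("attn2.processor"):
        # Cross-attention layer: train new key/value projections against the text features.
        custom_diffusion_attn_procs[name] = CustomDiffusionAttnProcessor(
            train_kv=True,
            train_q_out=False,
            hidden_size=hidden_size,
            cross_attention_dim=unet.config.cross_attention_dim,
        )
    else:
        # Self-attention layers keep the default processor.
        custom_diffusion_attn_procs[name] = AttnProcessor()

unet.set_attn_processor(custom_diffusion_attn_procs)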
@@ -658,6 +692,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class AttnAddedKVProcessor:
+    r"""
+    Processor for performing attention-related computations with extra learnable key and value matrices for the text
+    encoder.
+    """
+
     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         residual = hidden_states
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
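A minimal hedged sketch: this processor relies on the extra `add_k_proj`/`add_v_proj` layers, so it only applies to models whose attention blocks were built with `added_kv_proj_dim` (for example UnCLIP-style UNets). `model` below is assumed to be such a model.

from diffusers.models.attention_processor import AttnAddedKVProcessor

# `model` is assumed to be a UNet whose Attention blocks carry added key/value projections.
model.set_attn_processor(AttnAddedKVProcessor())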
@@ -707,6 +746,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class AttnAddedKVProcessor2_0:
+    r"""
+    Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
+    learnable key and value matrices for the text encoder.
+    """
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
@@ -765,6 +809,19 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class LoRAAttnAddedKVProcessor(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
+    encoder.
+
+    Args:
+        hidden_size (`int`, *optional*):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+    """
+
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()
 
@@ -832,6 +889,17 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class XFormersAttnProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
+
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set this to `None` and allow xFormers to choose the
+            best operator.
+    """
+
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op
 
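A hedged sketch of the two usual ways to enable this processor. `enable_xformers_memory_efficient_attention` is the pipeline-level helper in diffusers; pinning a specific operator such as `MemoryEfficientAttentionFlashAttentionOp` is optional.

import xformers.ops
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import XFormersAttnProcessor

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Option 1: the high-level switch; leaves attention_op=None so xFormers picks the kernel.
pipe.enable_xformers_memory_efficient_attention()

# Option 2: set the processor directly and pin the operator explicitly.
pipe.unet.set_attn_processor(
    XFormersAttnProcessor(attention_op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp)
)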
@@ -905,6 +973,10 @@ def __call__(
 
 
 class AttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
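A hedged note in code form: on PyTorch 2.0 and newer, diffusers already selects this processor by default, so the explicit assignment below is only needed when switching back from another processor; `unet` is assumed to be a loaded model.

import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor2_0

# `scaled_dot_product_attention` only exists on PyTorch >= 2.0, mirroring the check above.
if hasattr(F, "scaled_dot_product_attention"):
    unet.set_attn_processor(AttnProcessor2_0())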
@@ -983,6 +1055,23 @@ def __call__(
 
 
 class LoRAXFormersAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
+
+    Args:
+        hidden_size (`int`, *optional*):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set this to `None` and allow xFormers to choose the
+            best operator.
+    """
+
     def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
         super().__init__()
 
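A hedged sketch showing how the extra `attention_op` argument combines with the LoRA arguments. The sizes below (1280, 768) are assumptions matching a Stable Diffusion 1.x mid-block and text encoder; note that `cross_attention_dim` is positional in this class.

from diffusers.models.attention_processor import LoRAXFormersAttnProcessor

processor = LoRAXFormersAttnProcessor(
    hidden_size=1280,          # assumed mid-block channel count of an SD 1.x UNet
    cross_attention_dim=768,   # assumed text-encoder hidden size of SD 1.x
    rank=4,
    attention_op=None,         # let xFormers choose the best kernel
)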
@@ -1049,6 +1138,28 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class CustomDiffusionXFormersAttnProcessor(nn.Module):
+    r"""
+    Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
+
+    Args:
+        train_kv (`bool`, defaults to `True`):
+            Whether to newly train the key and value matrices corresponding to the text features.
+        train_q_out (`bool`, defaults to `True`):
+            Whether to newly train the query matrices corresponding to the latent image features.
+        hidden_size (`int`, *optional*, defaults to `None`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        out_bias (`bool`, defaults to `True`):
+            Whether to include a bias parameter in the output projection trained when `train_q_out` is enabled.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set this to `None` and allow xFormers to choose the
+            best operator.
+    """
+
     def __init__(
         self,
         train_kv=True,
@@ -1134,6 +1245,15 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class SlicedAttnProcessor:
+    r"""
+    Processor for implementing sliced attention.
+
+    Args:
+        slice_size (`int`, *optional*):
+            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
+            `attention_head_dim` must be a multiple of `slice_size`.
+    """
+
     def __init__(self, slice_size):
         self.slice_size = slice_size
 
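A hedged sketch of the memory/speed trade-off described above: attention is computed in slices rather than all at once, and `slice_size` must divide the attention head dimension. `unet` and `pipe` are assumed objects.

from diffusers.models.attention_processor import SlicedAttnProcessor

# Compute attention one slice at a time; slice_size=1 is the most memory-frugal setting.
unet.set_attn_processor(SlicedAttnProcessor(slice_size=1))

# Pipelines expose the same idea through a convenience method that picks the slice size:
# pipe.enable_attention_slicing()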
@@ -1206,6 +1326,15 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
 
 class SlicedAttnAddedKVProcessor:
+    r"""
+    Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
+
+    Args:
+        slice_size (`int`, *optional*):
+            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
+            `attention_head_dim` must be a multiple of `slice_size`.
+    """
+
     def __init__(self, slice_size):
         self.slice_size = slice_size
 