@@ -436,6 +436,12 @@ def parse_args(input_args=None):
436436 default = None ,
437437 help = "The optional `class_label` conditioning to pass to the unet, available values are `timesteps`." ,
438438 )
439+ parser .add_argument (
440+ "--rank" ,
441+ type = int ,
442+ default = 4 ,
443+ help = ("The dimension of the LoRA update matrices." ),
444+ )
439445
440446 if input_args is not None :
441447 args = parser .parse_args (input_args )
@@ -845,7 +851,9 @@ def main(args):
845851 LoRAAttnProcessor2_0 if hasattr (F , "scaled_dot_product_attention" ) else LoRAAttnProcessor
846852 )
847853 unet_lora_attn_procs [name ] = lora_attn_processor_class (
848- hidden_size = hidden_size , cross_attention_dim = cross_attention_dim
854+ hidden_size = hidden_size ,
855+ cross_attention_dim = cross_attention_dim ,
856+ rank = args .rank ,
849857 )
850858
851859 unet .set_attn_processor (unet_lora_attn_procs )
@@ -860,7 +868,9 @@ def main(args):
860868 for name , module in text_encoder .named_modules ():
861869 if name .endswith (TEXT_ENCODER_ATTN_MODULE ):
862870 text_lora_attn_procs [name ] = LoRAAttnProcessor (
863- hidden_size = module .out_proj .out_features , cross_attention_dim = None
871+ hidden_size = module .out_proj .out_features ,
872+ cross_attention_dim = None ,
873+ rank = args .rank ,
864874 )
865875 text_encoder_lora_layers = AttnProcsLayers (text_lora_attn_procs )
866876 temp_pipeline = DiffusionPipeline .from_pretrained (
0 commit comments