@@ -872,7 +872,9 @@ def main(args):
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )

-        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        module = lora_attn_processor_class(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
+        )
         unet_lora_attn_procs[name] = module
         unet_lora_parameters.extend(module.parameters())

@@ -882,7 +884,7 @@ def main(args):
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32)
+        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank)

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
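The two hunks above thread a configurable LoRA rank (args.rank, presumably wired to a new command-line option) through both the UNet attention processors and the text-encoder monkey patch. As a rough illustration of what that argument controls, here is a minimal sketch, not part of the patch, that builds a single LoRA attention processor; hidden_size=320 and cross_attention_dim=768 are assumed values for the first down-block of a Stable Diffusion 1.5-style UNet:

# Sketch only: instantiate one LoRA attention processor with an explicit rank.
# The rank sets the LoRA bottleneck dimension; a larger rank means more
# trainable parameters per attention block.
import torch.nn.functional as F
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0

lora_attn_processor_class = (
    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
module = lora_attn_processor_class(hidden_size=320, cross_attention_dim=768, rank=8)
print(sum(p.numel() for p in module.parameters()))  # parameter count grows with rank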
@@ -1364,7 +1366,7 @@ def compute_text_embeddings(prompt):
         pipeline = pipeline.to(accelerator.device)

         # load attention processors
-        pipeline.load_lora_weights(args.output_dir)
+        pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin")

         # run inference
         images = []
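The last hunk names the weight file passed to load_lora_weights explicitly, so validation inference picks up the pytorch_lora_weights.bin file serialized by the training script. A minimal stand-alone usage sketch follows; the base model ID and output path are placeholders, not taken from the patch:

# Sketch only: load the trained LoRA weights into a fresh pipeline for inference.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # assumed base model
)
pipeline = pipeline.to("cuda")

# Matches the change above: pass weight_name so the loader reads the .bin file.
pipeline.load_lora_weights("path/to/output_dir", weight_name="pytorch_lora_weights.bin")

image = pipeline("a photo of sks dog", num_inference_steps=25).images[0]
image.save("lora_sample.png")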