@@ -72,7 +72,7 @@ class PeftLoraLoaderMixinTests:
7272 unet_kwargs = None
7373 vae_kwargs = None
7474
75- def get_dummy_components(self, scheduler_cls=None):
75+ def get_dummy_components(self, scheduler_cls=None, use_dora=False):
7676 scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
7777 rank = 4
7878
@@ -96,10 +96,15 @@ def get_dummy_components(self, scheduler_cls=None):
9696 lora_alpha=rank,
9797 target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
9898 init_lora_weights=False,
99+ use_dora=use_dora,
99100 )
100101
101102 unet_lora_config = LoraConfig(
102- r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
103+ r=rank,
104+ lora_alpha=rank,
105+ target_modules=["to_q", "to_k", "to_v", "to_out.0"],
106+ init_lora_weights=False,
107+ use_dora=use_dora,
103108 )
104109
105110 if self.has_two_text_encoders:
@@ -1074,6 +1079,37 @@ def test_simple_inference_with_text_lora_unet_fused_multi(self):
10741079 "Fused lora should not change the output" ,
10751080 )
10761081
1082+ @require_peft_version_greater(peft_version="0.9.0")
1083+ def test_simple_inference_with_dora(self):
1084+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
1085+ components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls, use_dora=True)
1086+ pipe = self.pipeline_class(**components)
1087+ pipe = pipe.to(torch_device)
1088+ pipe.set_progress_bar_config(disable=None)
1089+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
1090+
1091+ output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
1092+ self.assertTrue(output_no_dora_lora.shape == (1, 64, 64, 3))
1093+
1094+ pipe.text_encoder.add_adapter(text_lora_config)
1095+ pipe.unet.add_adapter(unet_lora_config)
1096+
1097+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
1098+ self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
1099+
1100+ if self.has_two_text_encoders:
1101+ pipe.text_encoder_2.add_adapter(text_lora_config)
1102+ self.assertTrue(
1103+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
1104+ )
1105+
1106+ output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
1107+
1108+ self.assertFalse(
1109+ np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
1110+ "DoRA lora should change the output",
1111+ )
1112+
10771113 @unittest.skip("This is failing for now - need to investigate")
10781114 def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
10791115 """
0 commit comments