@@ -1496,7 +1496,7 @@ def prepare_init_args_and_inputs_for_common(self):
14961496 inputs_dict = self .dummy_input
14971497 return init_dict , inputs_dict
14981498
1499- def test_lora_processors (self ):
1499+ def test_lora_at_different_scales (self ):
15001500 # enable deterministic behavior for gradient checkpointing
15011501 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
15021502
@@ -1514,9 +1514,6 @@ def test_lora_processors(self):
15141514 model .load_attn_procs (lora_params )
15151515 model .to (torch_device )
15161516
1517- # test that attn processors can be set to itself
1518- model .set_attn_processor (model .attn_processors )
1519-
15201517 with torch .no_grad ():
15211518 sample2 = model (** inputs_dict , cross_attention_kwargs = {"scale" : 0.0 }).sample
15221519 sample3 = model (** inputs_dict , cross_attention_kwargs = {"scale" : 0.5 }).sample
@@ -1595,7 +1592,7 @@ def test_lora_xformers_on_off(self, expected_max_diff=6e-4):
15951592
15961593
15971594@deprecate_after_peft_backend
1598- class UNet3DConditionModelTests (unittest .TestCase ):
1595+ class UNet3DConditionLoRAModelTests (unittest .TestCase ):
15991596 model_class = UNet3DConditionModel
16001597 main_input_name = "sample"
16011598
@@ -1638,7 +1635,7 @@ def prepare_init_args_and_inputs_for_common(self):
16381635 inputs_dict = self .dummy_input
16391636 return init_dict , inputs_dict
16401637
1641- def test_lora_processors (self ):
1638+ def test_lora_at_different_scales (self ):
16421639 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
16431640
16441641 init_dict ["attention_head_dim" ] = 8
@@ -1655,9 +1652,6 @@ def test_lora_processors(self):
16551652 model .load_attn_procs (unet_lora_params )
16561653 model .to (torch_device )
16571654
1658- # test that attn processors can be set to itself
1659- model .set_attn_processor (model .attn_processors )
1660-
16611655 with torch .no_grad ():
16621656 sample2 = model (** inputs_dict , cross_attention_kwargs = {"scale" : 0.0 }).sample
16631657 sample3 = model (** inputs_dict , cross_attention_kwargs = {"scale" : 0.5 }).sample
0 commit comments