Skip to content

Commit 43979c2

Browse files
TST Fix LoRA test that fails with PEFT >= 0.7.0 (huggingface#6216)
See huggingface#6185 for context. Co-authored-by: Sayak Paul <[email protected]>
1 parent 9ea6ac1 commit 43979c2

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

tests/lora/test_lora_layers_peft.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import copy
16+
import importlib
1617
import os
1718
import tempfile
1819
import time
@@ -24,6 +25,7 @@
2425
import torch.nn.functional as F
2526
from huggingface_hub import hf_hub_download
2627
from huggingface_hub.repocard import RepoCard
28+
from packaging import version
2729
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
2830

2931
from diffusers import (
@@ -1983,10 +1985,26 @@ def test_sdxl_1_0_fuse_unfuse_all(self):
19831985
fused_te_2_state_dict = pipe.text_encoder_2.state_dict()
19841986
unet_state_dict = pipe.unet.state_dict()
19851987

1988+
peft_ge_070 = version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0")
1989+
1990+
def remap_key(key, sd):
1991+
# some keys have moved around for PEFT >= 0.7.0, but they should still be loaded correctly
1992+
if (key in sd) or (not peft_ge_070):
1993+
return key
1994+
1995+
# instead of linear.weight, we now have linear.base_layer.weight, etc.
1996+
if key.endswith(".weight"):
1997+
key = key[:-7] + ".base_layer.weight"
1998+
elif key.endswith(".bias"):
1999+
key = key[:-5] + ".base_layer.bias"
2000+
return key
2001+
19862002
for key, value in text_encoder_1_sd.items():
2003+
key = remap_key(key, fused_te_state_dict)
19872004
self.assertTrue(torch.allclose(fused_te_state_dict[key], value))
19882005

19892006
for key, value in text_encoder_2_sd.items():
2007+
key = remap_key(key, fused_te_2_state_dict)
19902008
self.assertTrue(torch.allclose(fused_te_2_state_dict[key], value))
19912009

19922010
for key, value in unet_state_dict.items():

0 commit comments

Comments
 (0)