|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 | import copy |
| 16 | +import importlib |
16 | 17 | import os |
17 | 18 | import tempfile |
18 | 19 | import time |
|
24 | 25 | import torch.nn.functional as F |
25 | 26 | from huggingface_hub import hf_hub_download |
26 | 27 | from huggingface_hub.repocard import RepoCard |
| 28 | +from packaging import version |
27 | 29 | from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer |
28 | 30 |
|
29 | 31 | from diffusers import ( |
@@ -1983,10 +1985,26 @@ def test_sdxl_1_0_fuse_unfuse_all(self): |
1983 | 1985 | fused_te_2_state_dict = pipe.text_encoder_2.state_dict() |
1984 | 1986 | unet_state_dict = pipe.unet.state_dict() |
1985 | 1987 |
|
| 1988 | + peft_ge_070 = version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0") |
| 1989 | + |
| 1990 | + def remap_key(key, sd): |
| 1991 | + # some keys have moved around for PEFT >= 0.7.0, but they should still be loaded correctly |
| 1992 | + if (key in sd) or (not peft_ge_070): |
| 1993 | + return key |
| 1994 | + |
| 1995 | + # instead of linear.weight, we now have linear.base_layer.weight, etc. |
| 1996 | + if key.endswith(".weight"): |
| 1997 | + key = key[:-7] + ".base_layer.weight" |
| 1998 | + elif key.endswith(".bias"): |
| 1999 | + key = key[:-5] + ".base_layer.bias" |
| 2000 | + return key |
| 2001 | + |
1986 | 2002 | for key, value in text_encoder_1_sd.items(): |
| 2003 | + key = remap_key(key, fused_te_state_dict) |
1987 | 2004 | self.assertTrue(torch.allclose(fused_te_state_dict[key], value)) |
1988 | 2005 |
|
1989 | 2006 | for key, value in text_encoder_2_sd.items(): |
| 2007 | + key = remap_key(key, fused_te_2_state_dict) |
1990 | 2008 | self.assertTrue(torch.allclose(fused_te_2_state_dict[key], value)) |
1991 | 2009 |
|
1992 | 2010 | for key, value in unet_state_dict.items(): |
|
0 commit comments