Skip to content

Commit e312b23

Browse files
authored
[LoRA] support LyCORIS (huggingface#5102)
* better condition. * debugging * how about now? * how about now? * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * support for lycoris. * style * add: lycoris test * fix from_pretrained call. * fix assertion values.
1 parent 8263cf0 commit e312b23

File tree

3 files changed

+30
-3
lines changed

3 files changed

+30
-3
lines changed

src/diffusers/loaders.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1878,7 +1878,7 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
18781878
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
18791879

18801880
# SDXL specificity.
1881-
if "emb" in diffusers_name:
1881+
if "emb" in diffusers_name and "time" not in diffusers_name:
18821882
pattern = r"\.\d+(?=\D*$)"
18831883
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
18841884
if ".in." in diffusers_name:
@@ -1890,6 +1890,13 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
18901890
if "skip" in diffusers_name:
18911891
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
18921892

1893+
# LyCORIS specificity.
1894+
if "time" in diffusers_name:
1895+
diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
1896+
if "conv.shortcut" in diffusers_name:
1897+
diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
1898+
1899+
# General coverage.
18931900
if "transformer_blocks" in diffusers_name:
18941901
if "attn1" in diffusers_name or "attn2" in diffusers_name:
18951902
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")

src/diffusers/models/embeddings.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from torch import nn
2020

2121
from .activations import get_activation
22+
from .lora import LoRACompatibleLinear
2223

2324

2425
def get_timestep_embedding(
@@ -166,7 +167,7 @@ def __init__(
166167
):
167168
super().__init__()
168169

169-
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
170+
self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)
170171

171172
if cond_proj_dim is not None:
172173
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -179,7 +180,7 @@ def __init__(
179180
time_embed_dim_out = out_dim
180181
else:
181182
time_embed_dim_out = time_embed_dim
182-
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
183+
self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)
183184

184185
if post_act_fn is None:
185186
self.post_act = None

tests/lora/test_lora_layers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1876,6 +1876,25 @@ def test_a1111(self):
18761876

18771877
self.assertTrue(np.allclose(images, expected, atol=1e-3))
18781878

1879+
def test_lycoris(self):
1880+
generator = torch.Generator().manual_seed(0)
1881+
1882+
pipe = StableDiffusionPipeline.from_pretrained(
1883+
"hf-internal-testing/Amixx", safety_checker=None, use_safetensors=True, variant="fp16"
1884+
).to(torch_device)
1885+
lora_model_id = "hf-internal-testing/edgLycorisMugler-light"
1886+
lora_filename = "edgLycorisMugler-light.safetensors"
1887+
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
1888+
1889+
images = pipe(
1890+
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
1891+
).images
1892+
1893+
images = images[0, -3:, -3:, -1].flatten()
1894+
expected = np.array([0.6463, 0.658, 0.599, 0.6542, 0.6512, 0.6213, 0.658, 0.6485, 0.6017])
1895+
1896+
self.assertTrue(np.allclose(images, expected, atol=1e-3))
1897+
18791898
def test_a1111_with_model_cpu_offload(self):
18801899
generator = torch.Generator().manual_seed(0)
18811900

0 commit comments

Comments
 (0)