Skip to content

Commit 852dc76

Browse files
Support higher dimension LoRAs (huggingface#4625)
* Support higher dimension LoRAs * add: tests * fix: assertion values. --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 064f150 commit 852dc76

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

src/diffusers/models/lora.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@ class LoRALinearLayer(nn.Module):
2222
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
2323
super().__init__()
2424

25-
if rank > min(in_features, out_features):
26-
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
27-
2825
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
2926
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
3027
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
@@ -54,9 +51,6 @@ def __init__(
5451
):
5552
super().__init__()
5653

57-
if rank > min(in_features, out_features):
58-
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
59-
6054
self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
6155
# according to the official kohya_ss trainer kernel_size are always fixed for the up layer
6256
# # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129

tests/models/test_lora_layers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,25 @@ def test_a1111(self):
779779

780780
self.assertTrue(np.allclose(images, expected, atol=1e-4))
781781

782+
def test_kohya_sd_v15_with_higher_dimensions(self):
783+
generator = torch.Generator().manual_seed(0)
784+
785+
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to(
786+
torch_device
787+
)
788+
lora_model_id = "hf-internal-testing/urushisato-lora"
789+
lora_filename = "urushisato_v15.safetensors"
790+
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
791+
792+
images = pipe(
793+
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
794+
).images
795+
796+
images = images[0, -3:, -3:, -1].flatten()
797+
expected = np.array([0.7165, 0.6616, 0.5833, 0.7504, 0.6718, 0.587, 0.6871, 0.6361, 0.5694])
798+
799+
self.assertTrue(np.allclose(images, expected, atol=1e-4))
800+
782801
def test_vanilla_funetuning(self):
783802
generator = torch.Generator().manual_seed(0)
784803

0 commit comments

Comments
 (0)