Skip to content

Commit d8bc1a4

Browse files
[Torch.compile] Fixes torch compile graph break (huggingface#4315)
* fix torch compile * Fix all * make style
1 parent 80c10d8 commit d8bc1a4

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/diffusers/models/lora.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from typing import Optional
1616

17+
import torch.nn.functional as F
1718
from torch import nn
1819

1920

@@ -91,7 +92,9 @@ def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
9192

9293
def forward(self, x):
9394
if self.lora_layer is None:
94-
return super().forward(x)
95+
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break
96+
# see: https://github.com/huggingface/diffusers/pull/4315
97+
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
9598
else:
9699
return super().forward(x) + self.lora_layer(x)
97100

0 commit comments

Comments
 (0)