[Torch.compile] Fixes torch compile graph break #4315
Conversation
@Chillee to me this looks like something that should be fixed in downstream PyTorch. This is more of a quick'n'dirty fix. To reproduce: `pip install diffusers==0.19.0 transformers` and then run:

```python
#!/usr/bin/env python3
from diffusers import StableDiffusionPipeline
import torch

path = "runwayml/stable-diffusion-v1-5"
run_compile = True  # Set True / False

pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
pipe = pipe.to("cuda:0")
pipe.unet.to(memory_format=torch.channels_last)

if run_compile:
    print("Run torch compile")
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

prompt = "ghibli style, a fantasy landscape with castles"

for _ in range(3):
    images = pipe(prompt=prompt).images
```

which gives a torch.compile graph break related to the changes done in this PR. This PR fixes the problem, but it's not clear why the previous code did not work out of the box for PyTorch.
The documentation is not available anymore as the PR was closed or merged.
* fix torch compile * Fix all * make style
It looks like this fix will break the seamless-texture solution of setting the Conv2d `__init__` padding_mode to "circular". In lora.py's forward(), when lora_layer is None, F.conv2d() does not handle padding the way super().forward() does.
Yes, F.conv2d() causes the seamless-texture problem: it cannot take the "padding_mode" parameter. @patrickvonplaten
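To illustrate the regression being reported (this is a minimal sketch, not the actual diffusers code): F.conv2d has no padding_mode argument, while nn.Conv2d implements non-zero padding modes by first calling F.pad and then convolving with padding=0. Replacing super().forward() with a plain F.conv2d call therefore silently falls back to zero padding:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode="circular")
x = torch.randn(1, 3, 16, 16)

# What nn.Conv2d.forward does internally for padding_mode="circular":
# pad explicitly, then convolve with padding=0.
x_padded = F.pad(x, (1, 1, 1, 1), mode="circular")
out_functional = F.conv2d(x_padded, conv.weight, conv.bias, stride=1, padding=0)

out_module = conv(x)
print(torch.allclose(out_module, out_functional))  # the two paths match

# Naively calling F.conv2d with padding=1 applies zero padding instead,
# which is what breaks seamless textures at the image borders.
out_naive = F.conv2d(x, conv.weight, conv.bias, stride=1, padding=1)
print(torch.allclose(out_module, out_naive))
```

So a functional rewrite that wants to preserve padding_mode would need the explicit F.pad step rather than passing the padding to F.conv2d.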
Aiii ok, we can probably revert this change once Torch 2.1 is out, as this has been natively fixed in PyTorch now: pytorch/pytorch#106402 Can you maybe open a new issue here so that we can track it?
For some reason one cannot subclass nn.Conv2d and then call super().forward(...) in it with torch.compile. See issue: #4305
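The workaround pattern can be sketched as follows (the class name here is illustrative, not the actual diffusers class): instead of calling super().forward(...) inside the nn.Conv2d subclass, which was reported to break torch.compile's graph, the convolution is invoked through the functional API with the module's own parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LoRACompatibleConvSketch(nn.Conv2d):
    """Illustrative subclass showing the functional-call workaround."""

    def forward(self, x):
        # Instead of: return super().forward(x)
        # call F.conv2d directly so torch.compile does not graph-break.
        return F.conv2d(
            x, self.weight, self.bias, self.stride,
            self.padding, self.dilation, self.groups,
        )

torch.manual_seed(0)
conv = LoRACompatibleConvSketch(3, 8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 16, 16)

# For the default padding_mode="zeros", the functional path is
# numerically identical to nn.Conv2d.forward on the same parameters.
ref = nn.Conv2d(3, 8, kernel_size=3, padding=1)
ref.load_state_dict(conv.state_dict())
print(torch.allclose(conv(x), ref(x)))
```

Note the equivalence above only holds for padding_mode="zeros", which is exactly why the "circular" seamless-texture case discussed earlier regresses.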