
[feat] allow SDXL pipeline to run with fused QKV projections #6030


Merged: 61 commits into main from sdxl/feat on Dec 6, 2023

Conversation

@sayakpaul (Member) commented on Dec 3, 2023

What does this PR do?

Adds an option to run the SDXL pipeline with fused QKV projections. For self-attention, all three projection matrices are fused horizontally. For cross-attention, the key and value projection matrices are fused.
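For illustration, here is a minimal sketch of what horizontal fusion means, with hypothetical names rather than the actual diffusers implementation: the separate q/k/v weight matrices are concatenated along the output dimension so that three matmuls become one, and the output is split back into q, k, and v afterwards.

import torch
import torch.nn as nn

dim = 8
to_q, to_k, to_v = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)

# Concatenate the weights and biases along the output dimension into one projection.
fused = nn.Linear(dim, 3 * dim)
fused.weight.data = torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data])
fused.bias.data = torch.cat([to_q.bias.data, to_k.bias.data, to_v.bias.data])

x = torch.randn(2, dim)
q, k, v = fused(x).split(dim, dim=-1)  # one matmul, then split back into q/k/v
assert torch.allclose(q, to_q(x), atol=1e-6)

For cross-attention only the key and value projections can be fused this way, since they share the encoder hidden states as input while the query is computed from a different input.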

Some more comments are inline.

Many thanks to @cpuhrsch for helping.

@patrickvonplaten (Contributor) left a comment

Looks very nice! Could you add some code that would allow testing the speed-ups?

@sayakpaul (Member, Author) commented on Dec 4, 2023

> Looks very nice! Could you add some code that would allow testing the speed-ups?

https://github.com/sayakpaul/sdxl-fast

Also, note that this is just one of many optimizations needed to speed things up. But it does contribute to throughput, which can be checked with the following code:

from diffusers import DiffusionPipeline
import torch
import time

pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")

def benchmark():
    # Average seconds per image over five generations.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(5):
        _ = pipeline("hey", num_inference_steps=25).images[0]
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / 5

pipeline.fuse_qkv_projections()
print(f"fused: {benchmark():.3f} s/image")

pipeline.unfuse_qkv_projections()
print(f"unfused: {benchmark():.3f} s/image")

You should see a higher throughput with the projections fused.

@patrickvonplaten (Contributor) commented
Indeed, getting a nice 5% speed-up!

@patrickvonplaten (Contributor) left a comment

Cool feature!

@sayakpaul (Member, Author) commented

@DN6 any reason why the fetcher is failing?

@patrickvonplaten (Contributor) left a comment

Very nice job here!

@sayakpaul merged commit a2bc2e1 into main on Dec 6, 2023
@sayakpaul deleted the sdxl/feat branch on December 6, 2023 at 02:03
@sayakpaul restored the sdxl/feat branch on December 6, 2023 at 02:12
@@ -289,6 +290,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
        self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)

        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        if sigmas.device.type == "cuda":

A Contributor commented on this diff:

This breaks the add_noise function, and thus inpainting, img2img, and training.
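For context, here is a minimal sketch (hypothetical names, not the scheduler's exact code) of the Euler-style add_noise pattern the comment refers to: the sigma for each sample is found by matching the passed timesteps against the stored schedule, so the schedule tensors must be alignable with the samples' device for the lookup and the final sample + sigma * noise to work.

import torch

def add_noise(samples, noise, timesteps, schedule_timesteps, sigmas):
    # Align the schedule tensors with the samples before matching and indexing.
    schedule_timesteps = schedule_timesteps.to(samples.device)
    sigmas = sigmas.to(device=samples.device, dtype=samples.dtype)
    # Find each timestep's position in the schedule, then gather its sigma.
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps.to(samples.device)]
    sigma = sigmas[step_indices].flatten()
    while sigma.ndim < samples.ndim:
        sigma = sigma.unsqueeze(-1)
    return samples + sigma * noise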

donhardman pushed a commit to donhardman/diffusers that referenced this pull request Dec 18, 2023
[feat] allow SDXL pipeline to run with fused QKV projections (huggingface#6030)

* debug
* from step
* print
* turn sigma a list
* make str
* init_noise_sigma
* comment
* remove prints
* feat: introduce fused projections
* change to a better name
* no grad
* device.
* device
* dtype
* okay
* print
* more print
* fix: unbind -> split
* fix: qkv >-> k
* enable disable
* apply attention processor within the method
* attn processors
* _enable_fused_qkv_projections
* remove print
* add fused projection to vae
* add todos.
* add: documentation and cleanups.
* add: test for qkv projection fusion.
* relax assertions.
* relax further
* fix: docs
* fix-copies
* correct error message.
* Empty-Commit
* better conditioning on disable_fused_qkv_projections
* check
* check processor
* bfloat16 computation.
* check latent dtype
* style
* remove copy temporarily
* cast latent to bfloat16
* fix: vae -> self.vae
* remove print.
* add _change_to_group_norm_32
* comment out stuff that didn't work
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* reflect patrick's suggestions.
* fix imports
* fix: disable call.
* fix more
* fix device and dtype
* fix conditions.
* fix more
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: Patrick von Platen <[email protected]>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
Labels: None yet · Projects: None yet · 4 participants