[feat] allow SDXL pipeline to run with fused QKV projections #6030
Looks very nice! Could you add some code that would allow us to test the speed-ups?
See https://github.com/sayakpaul/sdxl-fast. Also, note that this is just one of the many things needed to speed things up. Throughput-wise it does contribute, though, and one can check with the following code:

    from diffusers import DiffusionPipeline
    import torch

    pipeline = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")

    # Run with fused QKV projections.
    pipeline.fuse_qkv_projections()
    for _ in range(5):
        _ = pipeline("hey", num_inference_steps=25).images[0]

    # Run again with the projections unfused, for comparison.
    pipeline.unfuse_qkv_projections()
    for _ in range(5):
        _ = pipeline("hey", num_inference_steps=25).images[0]

You should see a speed-up in throughput with the fused projections.
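The snippet above runs both configurations but leaves the measurement to the progress bar. A small timing helper along these lines could make the comparison explicit. This is only a sketch: `benchmark` and its `warmup`, `iters`, and `sync` parameters are hypothetical names that are not part of diffusers, and the workload below is a stand-in so the block runs without a GPU.

```python
import time
from typing import Callable, Optional


def benchmark(
    fn: Callable[[], object],
    warmup: int = 1,
    iters: int = 5,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Return the mean wall-clock seconds per call of `fn` (hypothetical helper)."""
    # Warm-up calls are excluded from the measurement.
    for _ in range(warmup):
        fn()
    # On a GPU, queued work should finish before the clock starts and stops.
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync is not None:
        sync()
    return (time.perf_counter() - start) / iters


# Stand-in workload so the sketch runs anywhere.
mean_s = benchmark(lambda: sum(range(10_000)), warmup=1, iters=3)
print(f"{mean_s * 1e3:.3f} ms per call")
```

With the pipeline from the comment, one might call `benchmark(lambda: pipeline("hey", num_inference_steps=25), sync=torch.cuda.synchronize)` once after `pipeline.fuse_qkv_projections()` and once after `pipeline.unfuse_qkv_projections()`, then compare the two means.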
Co-authored-by: Patrick von Platen <[email protected]>
Indeed, getting a nice 5% speed-up!
Cool feature!
Co-authored-by: Patrick von Platen <[email protected]>
@DN6 any reason why the fetcher is failing?
Very nice job here!
@@ -289,6 +290,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
        self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)

        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        if sigmas.device.type == "cuda":
This breaks the add_noise function, and therefore inpainting, img2img, and training.
…face#6030) * debug * from step * print * turn sigma a list * make str * init_noise_sigma * comment * remove prints * feat: introduce fused projections * change to a better name * no grad * device. * device * dtype * okay * print * more print * fix: unbind -> split * fix: qkv >-> k * enable disable * apply attention processor within the method * attn processors * _enable_fused_qkv_projections * remove print * add fused projection to vae * add todos. * add: documentation and cleanups. * add: test for qkv projection fusion. * relax assertions. * relax further * fix: docs * fix-copies * correct error message. * Empty-Commit * better conditioning on disable_fused_qkv_projections * check * check processor * bfloat16 computation. * check latent dtype * style * remove copy temporarily * cast latent to bfloat16 * fix: vae -> self.vae * remove print. * add _change_to_group_norm_32 * comment out stuff that didn't work * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> * reflect patrick's suggestions. * fix imports * fix: disable call. * fix more * fix device and dtype * fix conditions. * fix more * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> --------- Co-authored-by: Patrick von Platen <[email protected]>
What does this PR do?
Adds an option for running the SDXL pipeline with QKV projections fused. For self-attention, all the projection matrices are horizontally fused. For cross-attention, key and value projection matrices are fused.
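The fusion described above can be illustrated with plain PyTorch linear layers. This is a minimal sketch under stated assumptions, not the code from this PR: `fuse_linears` is a hypothetical helper, the layer names `to_q`, `to_k`, `to_v` and the sizes are chosen for illustration, and the projections are assumed to be bias-free `nn.Linear` modules.

```python
import torch
import torch.nn as nn


def fuse_linears(*layers: nn.Linear) -> nn.Linear:
    """Build one Linear whose output is the concatenation of the given layers' outputs.

    Hypothetical helper for illustration; all layers must share the same input size.
    """
    in_features = layers[0].in_features
    use_bias = layers[0].bias is not None
    assert all(layer.in_features == in_features for layer in layers)
    assert all((layer.bias is not None) == use_bias for layer in layers)

    out_features = sum(layer.out_features for layer in layers)
    fused = nn.Linear(in_features, out_features, bias=use_bias)
    with torch.no_grad():
        # nn.Linear stores weights as (out_features, in_features),
        # so stacking along dim 0 concatenates the outputs.
        fused.weight.copy_(torch.cat([layer.weight for layer in layers], dim=0))
        if use_bias:
            fused.bias.copy_(torch.cat([layer.bias for layer in layers], dim=0))
    return fused


torch.manual_seed(0)
dim, ctx_dim = 64, 32  # illustrative sizes, not SDXL's

# Self-attention: Q, K and V all read the same hidden states, so all three fuse.
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))
to_qkv = fuse_linears(to_q, to_k, to_v)
hidden = torch.randn(2, 16, dim)
q, k, v = to_qkv(hidden).split(dim, dim=-1)

# Cross-attention: Q reads the hidden states, while K and V read the
# encoder states, so only K and V share an input and can be fused.
to_k_x, to_v_x = (nn.Linear(ctx_dim, dim, bias=False) for _ in range(2))
to_kv = fuse_linears(to_k_x, to_v_x)
context = torch.randn(2, 8, ctx_dim)
k_x, v_x = to_kv(context).split(dim, dim=-1)
```

The fused layer replaces several small matrix multiplications with one larger one, and the outputs are split back into the individual projections afterwards, so the results match the unfused layers up to floating-point error.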
Some more comments are inline.
Many thanks to @cpuhrsch for helping.