diffusers#7426 fix stable diffusion xl inference on MPS when dtypes shift unexpectedly due to pytorch bugs #7446
Conversation
Force-pushed from ec67986 to 721dd57
@sayakpaul i included the other pipelines just to keep them consistent, even though inference does not normally hit this problem. it is likely nice to fix it for the edge cases where it would.
# compute the previous noisy sample x_t -> x_t-1
old_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if latents.dtype != old_dtype:
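For context, the change under review boils down to remembering the incoming latents dtype and casting back after the scheduler step. A minimal sketch (the helper name is made up for illustration, and the cast-back is an assumption about what the fix does, not a quote of the merged code):

```python
def scheduler_step_keep_dtype(scheduler, noise_pred, t, latents, **extra_step_kwargs):
    # Hypothetical helper, not part of diffusers: run one scheduler step and
    # undo any dtype drift, which some PyTorch builds exhibit on MPS.
    latents_dtype = latents.dtype
    latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
    if latents.dtype != latents_dtype:
        latents = latents.to(latents_dtype)
    return latents
```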
Should this be guarded with torch.backends.mps.is_available() as well? It really seems to be happening only when mps is picked up no?
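For illustration, the guard being asked about might look roughly like this, continuing the snippet above (a sketch of the suggestion, not the code in the PR):

```python
import torch

# Assumes `latents` and `latents_dtype` from the surrounding pipeline code.
if latents.dtype != latents_dtype and torch.backends.mps.is_available():
    # Apply the cast-back only on MPS, where the dtype drift has been observed.
    latents = latents.to(latents_dtype)
```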
(nit: I'd maybe call it latents_dtype; old sounds almost like it'd be ok if it changes).
> Should this be guarded with torch.backends.mps.is_available()

In my opinion, it's ok the way it is as the comment already mentions mps, but no strong opinion.
@sayakpaul @pcuenca it's about the outcome. if any other accelerator behaves in this way, would you rather it crash out of the box so that an issue is filed, or should we just invisibly fix it when we find it?
How about throwing a warning so that the users are at least aware of it?
or raise a specific error asking them to file a report?
after further consideration, i modified the code so that it only executes on mps, as i figured it would be better to have full visibility into any other platform's dtype issues.
the current error, however, is pretty vague: it says two types are broadcast-incompatible, which isn't very helpful to a new user.
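For illustration, the warn-or-raise idea from the thread could be sketched like this (hypothetical helper and error message, not what was merged):

```python
import torch

def restore_or_report_dtype(latents, latents_dtype):
    # Hypothetical helper: silently cast back on MPS (known PyTorch issue) and
    # fail loudly elsewhere, instead of letting a vague broadcast-incompatibility
    # error surface later in the pipeline.
    if latents.dtype == latents_dtype:
        return latents
    if torch.backends.mps.is_available():
        return latents.to(latents_dtype)
    raise RuntimeError(
        f"Scheduler step changed latents dtype from {latents_dtype} to {latents.dtype}; "
        "please open an issue with your PyTorch version and device."
    )
```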
Yeah raising an error sounds like a better idea. @pcuenca WDYT?
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py (outdated review thread, resolved)
sayakpaul left a comment
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
pcuenca left a comment
Agree to add these workarounds to unblock use on mps. Thanks @bghira!
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py (outdated review thread, resolved)
Force-pushed from 0a76491 to bc7c7e8
Force-pushed from bc7c7e8 to bd2a802
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py (outdated review thread, resolved)
@yiyixuxu could you also give this a look?
yiyixuxu left a comment
thanks!
@bghira okay if I applied the safe-guarding logic to the rest of the scripts and prep the PR for merging?
yup
@bghira give this a final look and I will merge then?
nice, lgtm
Thanks for your contributions, @bghira!
fix stable diffusion xl inference on MPS when dtypes shift unexpectedly due to pytorch bugs (#7446)
* mps: fix XL pipeline inference at training time due to upstream pytorch bug
* Update src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
Co-authored-by: Sayak Paul <[email protected]>
* apply the safe-guarding logic elsewhere.
---------
Co-authored-by: bghira <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
What does this PR do?
Fixes #7426
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.