
Conversation

bghira (Contributor) commented Mar 23, 2024

What does this PR do?

Fixes #7426

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

bghira force-pushed the bugfix/mps-inference-xl branch from ec67986 to 721dd57 on March 23, 2024 21:28
bghira (Contributor, Author) commented Mar 23, 2024

@sayakpaul I included the other pipelines just to keep them consistent, even though inference does not normally hit this problem. It is probably nice to fix it for the edge cases where it would.

# compute the previous noisy sample x_t -> x_t-1
old_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if latents.dtype != old_dtype:
Member:
Should this be guarded with torch.backends.mps.is_available() as well? It really seems to be happening only when mps is picked up, no?

Member:

(nit: I'd maybe call it latents_dtype; old sounds almost like it'd be ok if it changes).

Member:

> Should this be guarded with torch.backends.mps.is_available()

In my opinion, it's ok the way it is as the comment already mentions mps, but no strong opinion.

bghira (Contributor, Author):

@sayakpaul @pcuenca It's about the outcome: if any other accelerator behaves this way, would you rather it crash out of the box so that an issue is filed, or should we just invisibly fix it when we find it?

Member:

How about throwing a warning so that the users are at least aware of it?
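
(Purely to illustrate this option, not code from the PR: the check could log instead of fixing things silently. A minimal, self-contained sketch using the standard library logger; the helper name is made up here.)

```python
import logging

import torch

logger = logging.getLogger(__name__)


def warn_if_dtype_changed(latents: torch.Tensor, latents_dtype: torch.dtype) -> torch.Tensor:
    # hypothetical helper: warn when a scheduler step silently changed the latents
    # dtype, then cast back so the rest of the denoising loop keeps working
    if latents.dtype != latents_dtype:
        logger.warning(
            "Scheduler step changed latents dtype from %s to %s; casting back. "
            "This is a known issue on mps due to an upstream pytorch bug.",
            latents_dtype,
            latents.dtype,
        )
        latents = latents.to(latents_dtype)
    return latents
```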

bghira (Contributor, Author):

Or raise a specific error asking them to file a report?

bghira (Contributor, Author):

I modified the code so that it only executes on mps; after further consideration, I figured it would be better to have full visibility into any other platform's dtype issues.

The current error, however, is pretty vague: it says two types are broadcast-incompatible, which isn't very helpful to a new user.

Member:

Yeah raising an error sounds like a better idea. @pcuenca WDYT?
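
(For context, a minimal sketch of the direction settled on here: restore the dtype only on mps, and fail loudly with a clearer message elsewhere so the issue gets reported. This illustrates the discussion, not necessarily the exact code merged; the helper name and the wording of the error are made up, and in the PR the check sits inline in the denoising loop rather than in a separate function.)

```python
import torch


def restore_latents_dtype(latents: torch.Tensor, latents_dtype: torch.dtype) -> torch.Tensor:
    # hypothetical helper mirroring the inline check discussed above
    if latents.dtype == latents_dtype:
        return latents
    if torch.backends.mps.is_available():
        # known upstream pytorch bug on mps: the scheduler step can return latents
        # in a different dtype, so quietly cast them back before the next step
        return latents.to(latents_dtype)
    # on any other accelerator, surface the problem instead of hiding it, with a
    # clearer message than the broadcast error the user would otherwise hit later
    raise RuntimeError(
        f"Scheduler step changed latents dtype from {latents_dtype} to {latents.dtype} "
        "on a non-mps device; please open an issue with your accelerator details."
    )
```

Guarding on torch.backends.mps.is_available() keeps other accelerators' dtype issues visible so they get reported rather than silently patched.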

sayakpaul (Member) left a comment:

To me the changes seem realistic and minimal enough to enable training on a different accelerator. So, I would be in favor of supporting this.

@yiyixuxu @pcuenca WDYT?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

pcuenca (Member) left a comment:

Agree to add these workarounds to unblock use on mps. Thanks @bghira!

bghira force-pushed the bugfix/mps-inference-xl branch from 0a76491 to bc7c7e8 on March 24, 2024 13:38
bghira force-pushed the bugfix/mps-inference-xl branch from bc7c7e8 to bd2a802 on March 24, 2024 13:39
sayakpaul requested a review from yiyixuxu on March 24, 2024 15:11
sayakpaul (Member):
@yiyixuxu could you also give this a look?

yiyixuxu (Collaborator) left a comment:

thanks!

sayakpaul (Member):
@bghira okay if I apply the safe-guarding logic to the rest of the scripts and prep the PR for merging?

bghira (Contributor, Author) commented Mar 26, 2024

yup

sayakpaul (Member):
@bghira give this a final look and I will merge then?

bghira (Contributor, Author) commented Mar 26, 2024

nice, lgtm

sayakpaul (Member):
Thanks for your contributions, @bghira!

sayakpaul merged commit 544710e into huggingface:main on Mar 26, 2024
bghira deleted the bugfix/mps-inference-xl branch on March 26, 2024 16:25
AbhinavGopal pushed a commit to AbhinavGopal/diffusers that referenced this pull request Mar 27, 2024
…hift unexpectedly due to pytorch bugs (huggingface#7446)

* mps: fix XL pipeline inference at training time due to upstream pytorch bug

* Update src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Co-authored-by: Sayak Paul <[email protected]>

* apply the safe-guarding logic elsewhere.

---------

Co-authored-by: bghira <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
sayakpaul added a commit that referenced this pull request Dec 23, 2024


Development

Successfully merging this pull request may close these issues.

[MPS] SDXL pipeline fails inference in fp16 mode

5 participants