
Commit 3d02c92

pcuenca and anton-l authored
mps changes for PyTorch 1.13 (huggingface#926)
* Docs: refer to pre-RC version of PyTorch 1.13.0.
* Remove temporary workaround for unavailable op.
* Update comment to make it less ambiguous.
* Remove use of contiguous in mps. It appears to no longer be necessary.
* Special case: use einsum for much better performance in mps.
* Update mps docs.
* Minor doc update.
* Accept suggestion.

Co-authored-by: Anton Lozhkov <[email protected]>
1 parent 28b134e commit 3d02c92

File tree

6 files changed: +47 -25 lines changed

docs/source/optimization/mps.mdx

Lines changed: 19 additions & 11 deletions
@@ -17,9 +17,13 @@ specific language governing permissions and limitations under the License.
 ## Requirements
 
 - Mac computer with Apple silicon (M1/M2) hardware.
-- macOS 12.3 or later.
+- macOS 12.6 or later (13.0 or later recommended).
 - arm64 version of Python.
-- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.14.0.dev20221007` or later.
+- PyTorch 1.13.0 RC (Release Candidate). You can install it with `pip` using:
+
+```
+pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/test/cpu
+```
 
 ## Inference Pipeline
 
@@ -34,6 +38,9 @@ from diffusers import StableDiffusionPipeline
 pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
 pipe = pipe.to("mps")
 
+# Recommended if your computer has < 64 GB of RAM
+pipe.enable_attention_slicing()
+
 prompt = "a photo of an astronaut riding a horse on mars"
 
 # First-time "warmup" pass (see explanation above)
@@ -43,16 +50,17 @@ _ = pipe(prompt, num_inference_steps=1)
 image = pipe(prompt).images[0]
 ```
 
-## Known Issues
+## Performance Recommendations
 
-- As mentioned above, we are investigating a strange [first-time inference issue](https://github.com/huggingface/diffusers/issues/372).
-- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this might be related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039#issuecomment-1237735249), but we need to investigate in more depth. For now, we recommend to iterate instead of batching.
+M1/M2 performance is very sensitive to memory pressure. The system will automatically swap if it needs to, but performance will degrade significantly when it does.
 
-## Performance
+We recommend you use _attention slicing_ to reduce memory pressure during inference and prevent swapping, particularly if your computer has less than 64 GB of system RAM, or if you generate images at non-standard resolutions larger than 512 × 512 pixels. Attention slicing performs the costly attention operation in multiple steps instead of all at once. It usually has a performance impact of ~20% in computers without universal memory, but we have observed _better performance_ in most Apple Silicon computers, unless you have 64 GB or more.
 
-These are the results we got on a M1 Max MacBook Pro with 64 GB of RAM, running macOS Ventura Version 13.0 Beta (22A5331f). We performed Stable Diffusion text-to-image generation of the same prompt for 50 inference steps, using a guidance scale of 7.5.
+```python
+pipeline.enable_attention_slicing()
+```
 
-| Device | Steps | Time |
-|--------|-------|---------|
-| CPU | 50 | 213.46s |
-| MPS | 50 | 30.81s |
+## Known Issues
+
+- As mentioned above, we are investigating a strange [first-time inference issue](https://github.com/huggingface/diffusers/issues/372).
+- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this is related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039). For now, we recommend iterating instead of batching.
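Taken together, the doc's recommendations (attention slicing, the one-step warmup pass, and iterating over prompts rather than batching) might be combined along these lines. This is only a sketch: the second prompt is a made-up placeholder, and everything else follows the snippets in the diff above.

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe = pipe.to("mps")
pipe.enable_attention_slicing()  # recommended below 64 GB of RAM

# Hypothetical prompt list, used only for illustration
prompts = [
    "a photo of an astronaut riding a horse on mars",
    "a watercolor painting of a lighthouse at dawn",
]

# First-time "warmup" pass
_ = pipe(prompts[0], num_inference_steps=1)

# Iterate instead of batching: batched prompts are currently unreliable on mps
images = [pipe(prompt).images[0] for prompt in prompts]
```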

examples/community/clip_guided_stable_diffusion.py

Lines changed: 1 addition & 1 deletion
@@ -249,7 +249,7 @@ def __call__(
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
-                # randn does not exist on mps
+                # randn does not work reproducibly on mps
                 latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                     self.device
                 )
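The updated comment describes a pattern used throughout these pipelines: draw the seeded initial latents on the CPU, where `torch.randn` with a generator is reproducible, then move them to the `mps` device. A minimal sketch, assuming an illustrative latent shape for a 512 × 512 image:

```python
import torch

# Draw latents on CPU with the seeded generator, then move to the device.
latents_shape = (1, 4, 64, 64)  # illustrative shape for a 512x512 image
generator = torch.Generator("cpu").manual_seed(0)
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=torch.float32)
latents = latents.to("mps")  # denoising then continues on the mps device
```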

examples/community/interpolate_stable_diffusion.py

Lines changed: 1 addition & 1 deletion
@@ -324,7 +324,7 @@ def __call__(
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
-                # randn does not exist on mps
+                # randn does not work reproducibly on mps
                 latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                     self.device
                 )

src/diffusers/models/attention.py

Lines changed: 25 additions & 7 deletions
@@ -207,7 +207,6 @@ def _set_attention_slice(self, slice_size):
         self.attn2._slice_size = slice_size
 
     def forward(self, hidden_states, context=None):
-        hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
         hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
         hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
         hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
@@ -288,10 +287,19 @@ def forward(self, hidden_states, context=None, mask=None):
 
     def _attention(self, query, key, value):
         # TODO: use baddbmm for better performance
-        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        if query.device.type == "mps":
+            # Better performance on mps (~20-25%)
+            attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
+        else:
+            attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
         attention_probs = attention_scores.softmax(dim=-1)
         # compute attention output
-        hidden_states = torch.matmul(attention_probs, value)
+
+        if query.device.type == "mps":
+            hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
+        else:
+            hidden_states = torch.matmul(attention_probs, value)
+
         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
@@ -305,11 +313,21 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
-            attn_slice = (
-                torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
-            )  # TODO: use baddbmm for better performance
+            if query.device.type == "mps":
+                # Better performance on mps (~20-25%)
+                attn_slice = (
+                    torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
+                    * self.scale
+                )
+            else:
+                attn_slice = (
+                    torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
+                )  # TODO: use baddbmm for better performance
             attn_slice = attn_slice.softmax(dim=-1)
-            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
+            if query.device.type == "mps":
+                attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
+            else:
+                attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
 
             hidden_states[start_idx:end_idx] = attn_slice
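As a sanity check on the einsum special case: the expressions in both branches compute the same batched attention products, so the mps path only changes how the work is dispatched, not the result. The shapes below are made up purely for illustration, and the check runs on CPU:

```python
import torch

# Illustrative sizes: batch*heads, query length, key length, head dimension
b, i, j, d = 2, 8, 8, 16
query, key, value = torch.randn(b, i, d), torch.randn(b, j, d), torch.randn(b, j, d)

# Attention scores: matmul vs. einsum formulation used on mps
scores_matmul = torch.matmul(query, key.transpose(-1, -2))
scores_einsum = torch.einsum("b i d, b j d -> b i j", query, key)
assert torch.allclose(scores_matmul, scores_einsum, atol=1e-6)

# Attention output: same comparison for the second product
probs = scores_einsum.softmax(dim=-1)
out_matmul = torch.matmul(probs, value)
out_einsum = torch.einsum("b i j, b j d -> b i d", probs, value)
assert torch.allclose(out_matmul, out_einsum, atol=1e-6)
```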

src/diffusers/models/resnet.py

Lines changed: 0 additions & 4 deletions
@@ -492,10 +492,6 @@ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
     kernel_h, kernel_w = kernel.shape
 
     out = tensor.view(-1, in_h, 1, in_w, 1, minor)
-
-    # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
-    if tensor.device.type == "mps":
-        out = out.to("cpu")
     out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
     out = out.view(-1, in_h * up_y, in_w * up_x, minor)
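The removed block was a workaround for `F.pad` failing on high-dimensional `mps` tensors (https://github.com/pytorch/pytorch/issues/84535). Assuming a PyTorch 1.13 build with `mps` available, the padding step now runs directly on the device, roughly as in this sketch with made-up sizes:

```python
import torch
import torch.nn.functional as F

# Made-up sizes matching the (-1, in_h, 1, in_w, 1, minor) layout above
up_x = up_y = 2
out = torch.randn(1, 8, 1, 8, 1, 3, device="mps")
# Constant padding on a 6-D mps tensor, no CPU round-trip needed anymore
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
```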

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 1 addition & 1 deletion
@@ -287,7 +287,7 @@ def __call__(
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
-                # randn does not exist on mps
+                # randn does not work reproducibly on mps
                 latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                     self.device
                 )
