
Commit 780b3a4

Fix typo in AttnProcessor2_0 symbol (huggingface#2404)
1 parent 07547df commit 780b3a4

File tree

2 files changed: +6 −6 lines changed

docs/source/en/optimization/torch2.0.mdx

Lines changed: 2 additions & 2 deletions
@@ -50,10 +50,10 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
 ```Python
 import torch
 from diffusers import StableDiffusionPipeline
-from diffusers.models.cross_attention import AttnProccesor2_0
+from diffusers.models.cross_attention import AttnProcessor2_0

 pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
-pipe.unet.set_attn_processor(AttnProccesor2_0())
+pipe.unet.set_attn_processor(AttnProcessor2_0())

 prompt = "a photo of an astronaut riding a horse on mars"
 image = pipe(prompt).images[0]
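For context, a minimal runnable sketch of the corrected snippet, assuming PyTorch 2.0 is installed; the explicit `hasattr` guard mirrors the feature check used in `cross_attention.py` below and is added here purely for illustration:

```python
import torch
import torch.nn.functional as F

from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0

# scaled_dot_product_attention only exists on PyTorch >= 2.0
assert hasattr(F, "scaled_dot_product_attention"), "AttnProcessor2_0 requires PyTorch 2.0"

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.unet.set_attn_processor(AttnProcessor2_0())  # note the corrected symbol name

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
```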

src/diffusers/models/cross_attention.py

Lines changed: 4 additions & 4 deletions
@@ -99,10 +99,10 @@ def __init__(
         self.to_out.append(nn.Dropout(dropout))

         # set attention processor
-        # We use the AttnProccesor2_0 by default when torch2.x is used which uses
+        # We use the AttnProcessor2_0 by default when torch2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
         if processor is None:
-            processor = AttnProccesor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
+            processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
         self.set_processor(processor)

     def set_use_memory_efficient_attention_xformers(
@@ -466,10 +466,10 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
         return hidden_states


-class AttnProccesor2_0:
+class AttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProccesor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, inner_dim = hidden_states.shape
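The hunk above shows only the head of `AttnProcessor2_0.__call__`; the rest of the body is elided. As a rough sketch of the technique the in-code comment describes (dispatching attention through `torch.nn.functional.scaled_dot_product_attention`, which can select Flash or memory-efficient kernels), here is a self-contained toy module; the projection names and head handling are illustrative assumptions, not the actual diffusers implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SDPAttention(nn.Module):
    """Toy attention module that routes through F.scaled_dot_product_attention
    (PyTorch >= 2.0). Illustrative only; not the diffusers processor."""

    def __init__(self, dim: int, heads: int = 8):
        super().__init__()
        assert hasattr(F, "scaled_dot_product_attention"), "requires PyTorch 2.0"
        self.heads = heads
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, inner_dim = hidden_states.shape
        # Cross-attention reads keys/values from the encoder states if given,
        # otherwise this is plain self-attention.
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        head_dim = inner_dim // self.heads

        def split_heads(t):
            # (batch, seq, inner_dim) -> (batch, heads, seq, head_dim)
            return t.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

        q = split_heads(self.to_q(hidden_states))
        k = split_heads(self.to_k(context))
        v = split_heads(self.to_v(context))

        # Fused kernel: scaling and softmax happen inside. attention_mask, if
        # given, must be broadcastable to (batch, heads, seq_q, seq_k).
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)

        # Merge heads back: (batch, seq, inner_dim)
        out = out.transpose(1, 2).reshape(batch_size, sequence_length, inner_dim)
        return self.to_out(out)
```

For example, `SDPAttention(dim=64)(torch.randn(2, 16, 64))` returns a `(2, 16, 64)` tensor.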
