fix flash attention in ppdiffuser #1211

Status: Open. Wants to merge 5 commits into base branch develop.
ppdiffusers/ppdiffusers/models/attention_processor.py (2 additions, 2 deletions)
```diff
@@ -665,7 +665,7 @@ def prepare_attention_mask(
         num_heads = self.heads
         if attention_mask is None:
             return attention_mask

         ori_type = attention_mask.dtype
         attention_mask = attention_mask.to(paddle.float32)
@@ -1296,7 +1296,7 @@ def __call__(
         # adapt the scaled_dot_product_attention_ when attention_mask is a bool tensor
         if attention_mask is not None and attention_mask.dtype == paddle.bool:
             L, S = query.shape[1], key.shape[1]
-            attention_mask_tmp = paddle.zeros([1,1, L, S], dtype=query.dtype)
+            attention_mask_tmp = paddle.zeros([1, 1, L, S], dtype=query.dtype)
             attention_mask_tmp = attention_mask_tmp.masked_fill(attention_mask.logical_not(), float("-inf"))
             attention_mask = attention_mask_tmp
```
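For context on the second hunk: when callers pass a boolean mask, the attention path expects an additive float mask, so `True` (attend) becomes `0.0` and `False` (block) becomes `-inf` before the mask is added to the attention logits. A minimal standalone sketch of that conversion (shapes and values are illustrative, not taken from the PR):

```python
import paddle

# Illustrative sizes: L query positions attending over S key positions.
L, S = 4, 6

# Boolean mask: True = attend, False = block. Here the last two keys are blocked.
key_idx = paddle.arange(S)
bool_mask = (key_idx < S - 2).reshape([1, 1, 1, S]).expand([1, 1, L, S])

# Same conversion as in the diff: start from zeros and fill the blocked
# positions with -inf, yielding an additive mask for the attention logits.
additive_mask = paddle.zeros([1, 1, L, S], dtype=paddle.float32)
additive_mask = additive_mask.masked_fill(bool_mask.logical_not(), float("-inf"))

print(additive_mask[0, 0, 0])  # [0., 0., 0., 0., -inf, -inf]
```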
ppdiffusers/ppdiffusers/utils/import_utils.py (12 additions, 10 deletions)
```diff
@@ -42,6 +42,10 @@ def str2bool(v):
         raise ValueError("Not supported value: {}".format(v))


+def is_npu_available():
+    return paddle.device.get_device().startswith("npu")
+
+
 # The package importlib_metadata is in a different place, depending on the python version.
 if sys.version_info < (3, 8):
     import importlib_metadata
```
```diff
@@ -76,19 +80,19 @@

 if _paddle_available:
     try:
-        from paddle.incubate.nn.memory_efficient_attention import (  # noqa
-            memory_efficient_attention,
-        )
-
-        # _ = memory_efficient_attention(
-        #     paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
-        #     paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
-        #     paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
-        # )
+        _ = paddle.nn.functional.scaled_dot_product_attention(
+            paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
+            paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
+            paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
+            attn_mask=paddle.ones((1, 2, 1, 1), dtype=paddle.float16),
+        )
         _ppxformers_available = True
     except Exception:
         _ppxformers_available = False

+    if is_npu_available():
+        _ppxformers_available = False
+
 else:
     logger.info("Disabling Paddle because USE_PADDLE is set")
     _paddle_available = False
```
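The try/except above doubles as a capability probe: if the fused `scaled_dot_product_attention` kernel is unavailable in the current Paddle build, the tiny forward call raises and `_ppxformers_available` stays `False`. A standalone sketch of the same idea (the input shapes are copied from the diff; the helper name is hypothetical):

```python
import paddle

def sdpa_works() -> bool:
    """Hypothetical helper: probe the fused SDPA kernel with a tiny call."""
    try:
        # Paddle's SDPA layout is (batch, seq_len, num_heads, head_dim).
        q = paddle.ones((1, 1, 2, 40), dtype=paddle.float16)
        paddle.nn.functional.scaled_dot_product_attention(
            q, q, q, attn_mask=paddle.ones((1, 2, 1, 1), dtype=paddle.float16)
        )
        return True
    except Exception:
        return False
```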
```diff
@@ -375,8 +379,6 @@ def is_scipy_available():
 def is_librosa_available():
     return _librosa_available

-def is_npu_available():
-    return paddle.device.get_device().startswith("npu")

 def is_ppxformers_available():
     USE_PPXFORMERS = str2bool(os.getenv("USE_PPXFORMERS", True))
```
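The relocated `is_npu_available` helper keys off the current device string: `paddle.device.get_device()` returns values like `"cpu"`, `"gpu:0"`, or `"npu:0"`, so a prefix check suffices. A short usage sketch (the fallback branch is illustrative, not part of the PR):

```python
import paddle

def is_npu_available():
    return paddle.device.get_device().startswith("npu")

if is_npu_available():
    # The PR force-disables the ppxformers path on NPU, presumably because
    # the fused attention kernels probed above are unavailable there.
    print("NPU detected: disabling ppxformers attention")
```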