Replace with flash attention #1133

Open
wants to merge 23 commits into base: develop
Changes from 1 commit
Commits (23)
28cf43a
verify auto_parallel intermediate api in paddlemix
jeff41404 Jan 20, 2025
1b23b6b
Merge branch 'develop' into verify_auto_parallel_intermediate_api_in_…
jeff41404 Jan 20, 2025
24c316c
temporary adjustment dataloader_num_workers=0
jeff41404 Jan 21, 2025
12d22ef
move ViT in dataloader to support auto_parallel sharding parallel
jeff41404 Jan 22, 2025
5f12dd6
Merge branch 'develop' into verify_auto_parallel_intermediate_api_in_…
jeff41404 Feb 8, 2025
646fb69
update preprocess of image to align loss of hybrid qwen2vl
jeff41404 Feb 10, 2025
76d750a
modify the input to 2 fields to meet the requirement of dynamic to st…
jeff41404 Feb 13, 2025
559e66a
modify model to get input_ids, attention_mask, inputs_embeds, labels …
jeff41404 Feb 18, 2025
32731fc
Merge branch 'develop' into verify_auto_parallel_intermediate_api_in_…
jeff41404 Mar 4, 2025
614b494
fix some issue of d2s
jeff41404 Mar 4, 2025
e56625b
fix vit embed
Xing-lil Mar 6, 2025
5c34581
changes code as per the feedback
Xing-lil Mar 7, 2025
8d63838
fix 2D dense_tensor_idx
Xing-lil Mar 7, 2025
acec889
Switch to no dist network
Xing-lil Mar 7, 2025
5f32352
Merge pull request #1 from Xing-lil/lzx_dev
jeff41404 Mar 7, 2025
4e58c65
modify create_attention_module to call modeling_qwen2_vl_network
jeff41404 Mar 11, 2025
afc3105
update dense_tensor_idx
jeff41404 Mar 11, 2025
b995ae0
fix dense_tensor_idx
jeff41404 Mar 11, 2025
be3f458
recover modeling_qwen2_vl
jeff41404 Mar 12, 2025
a91a061
Merge branch 'develop' into verify_auto_parallel_intermediate_api_in_…
pkhk-1 Mar 13, 2025
42cc2af
Clean up and correct annotations
jeff41404 Mar 13, 2025
2243313
Merge branch 'develop' into verify_auto_parallel_intermediate_api_in_…
jeff41404 Mar 17, 2025
3a8a681
replace with flash attention
liym27 Mar 17, 2025
replace with flash attention
liym27 committed Mar 17, 2025
commit 3a8a6818bd1002c72caeba0d1ac9382d6f27305d
20 changes: 6 additions & 14 deletions paddlemix/models/qwen2_vl/modeling_qwen2_vl.py
```diff
@@ -28,6 +28,7 @@
 import paddle.distributed.fleet.meta_parallel as mpu
 import paddle.nn as nn
 import paddle.nn.functional as F
+from paddle.nn.functional.flash_attention import flash_attention
 from paddle import Tensor
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
@@ -936,21 +937,12 @@ def forward(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        # Reashape to the expected shape for Flash Attention
-        # [1, 3599, 12, 128]
-        query_states = query_states.transpose(perm=[0, 2, 1, 3])
-        key_states = key_states.transpose(perm=[0, 2, 1, 3])
-        value_states = value_states.transpose(perm=[0, 2, 1, 3])
-
-        attn_output = self._flash_attention_forward(
-            query_states,
-            key_states,
+        attn_output, attn_weights = flash_attention(
+            query_states,
+            key_states,
             value_states,
-            attention_mask,
-            q_len
-            # dropout=0.0 if not self.training else self.attention_dropout,
-            # causal=self.is_causal,
-        )
+            causal=True,
+            return_softmax=output_attentions)
 
         attn_output = attn_output.reshape([bsz, q_len, -1])
         attn_output = self.o_proj(attn_output)
```
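For context: paddle.nn.functional.flash_attention.flash_attention takes 4-D tensors laid out as [batch_size, seq_len, num_heads, head_dim] and returns an (output, softmax) tuple, which is why the explicit transposes and the attention_mask/q_len arguments of the old _flash_attention_forward path are dropped above. A minimal sketch of the new call, with illustrative shapes taken from the deleted comment and assuming a CUDA build of Paddle with FlashAttention support:

```python
# Illustrative sketch of the flash_attention call used in this diff.
# Shapes follow the deleted "[1, 3599, 12, 128]" comment; a CUDA build of
# Paddle with FlashAttention support is assumed.
import paddle
from paddle.nn.functional.flash_attention import flash_attention

bsz, q_len, num_heads, head_dim = 1, 3599, 12, 128
query_states = paddle.randn([bsz, q_len, num_heads, head_dim]).astype("float16")
key_states = paddle.randn([bsz, q_len, num_heads, head_dim]).astype("float16")
value_states = paddle.randn([bsz, q_len, num_heads, head_dim]).astype("float16")

# Inputs are [batch, seq_len, num_heads, head_dim]; the second return value is
# the softmax tensor, populated only when return_softmax=True.
attn_output, attn_weights = flash_attention(
    query_states,
    key_states,
    value_states,
    causal=True,
    return_softmax=False,
)

# Merge heads back into the hidden dimension before the output projection,
# mirroring the reshape kept in the modified forward().
attn_output = attn_output.reshape([bsz, q_len, -1])
print(attn_output.shape)  # [1, 3599, 1536]
```

Because the diff passes return_softmax=output_attentions, the second return value only carries the attention probabilities when they are explicitly requested.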
21 changes: 6 additions & 15 deletions paddlemix/models/qwen2_vl/modeling_qwen2_vl_network.py
```diff
@@ -34,7 +34,7 @@
 from paddlenlp.transformers.linear_utils import Linear
 from paddlenlp.transformers.model_outputs import BaseModelOutputWithPast, ModelOutput
 from paddlenlp.transformers.model_utils import PretrainedModel
-
+from paddle.nn.functional.flash_attention import flash_attention
 from paddlemix.models.flash_attn_utils import (
     create_attention_module,
     has_flash_attn_func,
@@ -862,21 +862,12 @@ def forward(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        # Reashape to the expected shape for Flash Attention
-        # [1, 3599, 12, 128]
-        query_states = query_states.transpose(perm=[0, 2, 1, 3])
-        key_states = key_states.transpose(perm=[0, 2, 1, 3])
-        value_states = value_states.transpose(perm=[0, 2, 1, 3])
-
-        attn_output = self._flash_attention_forward(
-            query_states,
-            key_states,
+        attn_output, attn_weights = flash_attention(
+            query_states,
+            key_states,
             value_states,
-            attention_mask,
-            q_len
-            # dropout=0.0 if not self.training else self.attention_dropout,
-            # causal=self.is_causal,
-        )
+            causal=True,
+            return_softmax=output_attentions)
 
         attn_output = attn_output.reshape([bsz, q_len, -1])
         attn_output = self.o_proj(attn_output)
```
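The same replacement is applied in modeling_qwen2_vl_network.py: the explicit attention_mask argument is dropped in favor of causal=True. A hedged sanity check, not part of this PR, that compares flash_attention against a naive causal softmax attention on small illustrative shapes (GPU with FlashAttention support assumed):

```python
# Hedged sanity check (not part of this PR): compare flash_attention with
# causal=True against a naive lower-triangular softmax attention.
# Small illustrative shapes; a CUDA build of Paddle with FlashAttention
# support is assumed.
import paddle
import paddle.nn.functional as F
from paddle.nn.functional.flash_attention import flash_attention


def naive_causal_attention(q, k, v):
    # q, k, v: [batch, seq_len, num_heads, head_dim]
    q_t = q.transpose([0, 2, 1, 3])  # -> [batch, num_heads, seq_len, head_dim]
    k_t = k.transpose([0, 2, 1, 3])
    v_t = v.transpose([0, 2, 1, 3])
    scores = paddle.matmul(q_t, k_t, transpose_y=True) / (q.shape[-1] ** 0.5)
    seq_len = q.shape[1]
    # Mask out positions above the diagonal to enforce causality.
    causal_mask = paddle.triu(paddle.full([seq_len, seq_len], float("-inf")), diagonal=1)
    probs = F.softmax(scores + causal_mask, axis=-1)
    out = paddle.matmul(probs, v_t)
    return out.transpose([0, 2, 1, 3])  # back to [batch, seq_len, num_heads, head_dim]


q = paddle.randn([1, 16, 4, 32]).astype("float16")
k = paddle.randn([1, 16, 4, 32]).astype("float16")
v = paddle.randn([1, 16, 4, 32]).astype("float16")

flash_out, _ = flash_attention(q, k, v, causal=True)
ref_out = naive_causal_attention(q.astype("float32"), k.astype("float32"), v.astype("float32"))
print(paddle.allclose(flash_out.astype("float32"), ref_out, atol=1e-2).item())  # expect True
```

Agreement within float16 tolerance indicates that causal=True reproduces a plain lower-triangular mask; any non-causal masking carried by the original attention_mask path would not be covered by this check.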