Skip to content

Qualcomm AI Engine Direct - GA model enablement (T5) #12234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

DannyYuyang-quic
Copy link
Collaborator

@DannyYuyang-quic DannyYuyang-quic commented Jul 4, 2025

Summary

  • e2e script / test case for GA T5 model
    • perf: 16a8w avg encoding time: 4.09ms/inf, avg decoding time: 6ms/inf (SM8750)
    • acc: F1 Score ~= 76% in SQuAD
  • add QA dataset for Seq2SeqLM benchmarking

Test plan

python -m examples.qualcomm.oss_scripts.t5.t5 -b build-android -m ${soc} -H ${host_id} -s ${device_id} -d ./SQuAD-v1.1.csv

cc: @haowhsu-quic,@cccclai

Copy link

pytorch-bot bot commented Jul 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/12234

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 4e79e47 with merge base 1decf7a (image):

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 4, 2025
@DannyYuyang-quic
Copy link
Collaborator Author

@pytorchbot label "release notes: qualcomm"

@pytorch-bot pytorch-bot bot added the release notes: qualcomm Changes to the Qualcomm backend delegate label Jul 4, 2025
@facebook-github-bot
Copy link
Contributor

@cccclai has imported this pull request. If you are a Meta employee, you can view this in D77877631.

logger = logging.get_logger(__name__)


# Copy from transformers/models/t5/modeling_t5.py (transformers=4.47.1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How different it is compared with transformers/models/t5/modeling_t5.py?

Copy link
Collaborator Author

@DannyYuyang-quic DannyYuyang-quic Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two changes compared to transformers/models/t5/modeling_t5.py:
both added to move the computation of relative position out of runtime and into precomputed buffers. This is because T5Attention._relative_position_bucket is not QNN-friendly.

@cccclai
Copy link
Contributor

cccclai commented Jul 7, 2025

Also can you rebase?

@DannyYuyang-quic DannyYuyang-quic force-pushed the dev1/danny/GA_T5 branch 2 times, most recently from fa38382 to f15420a Compare July 8, 2025 03:01
Summary:
 - e2e script / test case for GA T5 model
    - perf: 16a8w avg encoding time: 4.09ms/inf, avg decoding time: 6ms/inf (SM8750)
    - acc: F1 Score ~= 76% in SQuAD
 - add QA dataset for Seq2SeqLM benchmarking
):
super().__init__(config, embed_tokens)

# ====================Qualcomm Changed=================================
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first part of the added is that I precompute the relative_position_bucket using T5Attention._relative_position_bucket and register the result as a buffer.

Comment on lines +234 to +277
# ====================Qualcomm Changed=================================
# The bias is indexed by cache_position to select the correct positions for the current step.
if self.is_decoder:
# For decoder, use the decoder's relative position bias table.
position_bias = (
self.block[0]
.layer[0]
.SelfAttention.relative_attention_bias(
self.decoder_self_attn_position_bias[cache_position]
)
.permute([2, 0, 1])
.unsqueeze(0)
)
else:
# For encoder, use the encoder's relative position bias table.
position_bias = (
self.block[0]
.layer[0]
.SelfAttention.relative_attention_bias(
self.encoder_self_attn_position_bias[cache_position]
)
.permute([2, 0, 1])
.unsqueeze(0)
)
position_bias = position_bias[:, :, -seq_length:, :]
if self.is_decoder:
position_bias = (
position_bias + causal_mask[:, :, :, : self.max_cache_length]
)
else:
position_bias = position_bias + causal_mask[:, :, :, :seq_length]

# For cross-attention in decoder, precompute encoder-decoder position bias as zeros and add encoder attention mask.
encoder_decoder_position_bias = None
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = torch.zeros(
(1, self.config.num_heads, seq_length, self.max_hidden_seq_length),
dtype=encoder_extended_attention_mask.dtype,
)
encoder_decoder_position_bias = (
encoder_decoder_position_bias
+ encoder_extended_attention_mask[:, :, :, : self.max_hidden_seq_length]
)
# ========================================================================
Copy link
Collaborator Author

@DannyYuyang-quic DannyYuyang-quic Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The second part is in the forward, where I retrieve the relative_position_bucket by indexing into the buffer using the correct cache position.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thanks for optimizing the performance. Can the source model definition also lowered?

Copy link
Collaborator Author

@DannyYuyang-quic DannyYuyang-quic Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I still need to write some source transforms or add passes. This is because the function _relative_position_bucket from source T5 computation has two main issues:

  1. Unsupported ops for (int64/int32) datatypes:
    Source function performs operations like abs, min, and neg on int64, but these ops are not supported on int64/int32 in QNN.
  2. Unsupported casting:
    There is a cast from float32 to int64 in source function, but in the 16a8w quantization case, QNN's cast op actually performs a cast from uint16 to int32, which is also unsupported.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Taking a step back from being performant, if user tried and it fails here, and if it's not supported in QNN, should this op fall back to cpu?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, if it's not supported in QNN, this should fall back to cpu.

@facebook-github-bot
Copy link
Contributor

@cccclai has imported this pull request. If you are a Meta employee, you can view this in D77877631.

Copy link
Contributor

@cccclai cccclai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! The diff train is kind of stuck, will try to merge it soon

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. release notes: qualcomm Changes to the Qualcomm backend delegate
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants