
[WIP] Enable causal block mask for sdpa #1348


Draft · mreso wants to merge 2 commits into main

Conversation

@mreso commented Jun 26, 2025

This PR enables a causal block mask for SDPA using nested tensors.

To test run:

CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.local_batch_size=2 --model.flavor=debugmodel_sdpa_block_causal
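For illustration, a minimal, self-contained sketch of the underlying idea: pack documents of different lengths into a jagged nested tensor and let `is_causal=True` act per document, which yields a block-causal mask over the packed sequence. The shapes and variable names below are assumptions for the example, not the code in `torchtitan/models/attention.py`:

```python
import torch
import torch.nn.functional as F

# Hypothetical example, not the PR's attention.py implementation.
n_heads, head_dim = 4, 16
doc_lens = [3, 5, 2]  # lengths of the documents packed into one sample

# One (seq_len, n_heads, head_dim) tensor per document.
qs = [torch.randn(L, n_heads, head_dim) for L in doc_lens]
ks = [torch.randn(L, n_heads, head_dim) for L in doc_lens]
vs = [torch.randn(L, n_heads, head_dim) for L in doc_lens]

# Jagged nested tensors of shape (num_docs, seq*, n_heads, head_dim),
# transposed to (num_docs, n_heads, seq*, head_dim) for SDPA.
q = torch.nested.nested_tensor(qs, layout=torch.jagged).transpose(1, 2)
k = torch.nested.nested_tensor(ks, layout=torch.jagged).transpose(1, 2)
v = torch.nested.nested_tensor(vs, layout=torch.jagged).transpose(1, 2)

# is_causal=True is applied within each document, so no token attends
# across document boundaries, i.e. a block-causal mask over the packed sequence.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Re-pack into a single (sum(doc_lens), n_heads, head_dim) sequence.
packed = torch.cat(out.transpose(1, 2).unbind(0), dim=0)
```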

TODO
[X] Test with debug model
[X] Test with llama3 8B
[ ] Test with llama3 70B
[X] Compare loss vs sdpa_causal vs flex block_causal
The loss curves of flex_attn and sdpa + causal block mask are identical over 3k steps:
[image: loss comparison]
[X] Compare memory with local_batch_size > 1 (e.g. (bs=8, max_seq_len=8192) vs (bs=1, max_seq_len=8*8192))
Active memory is the same in both cases, with a slight advantage over flex attention during initialization:
[image: memory comparison]
Note: full activation checkpointing was enabled because flex does not support selective checkpointing.
[X] Compare throughput:
flex_attn has a slight advantage; sdpa + causal block mask has an advantage over sdpa + causal mask:
[image: throughput comparison]

[X] Test torch.compile
The call to F.scaled_dot_product_attention fails with:

   File "/home/matthias.meta/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 712, in compile_wrapper
      raise e.with_traceback(None) from e.__cause__  # User compiler error
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  torch._dynamo.exc.Unsupported: Data dependent operator
    Explanation: Operator `aten._local_scalar_dense.default` has a non-Tensor output whose value is dependent on the data of Tensor inputs.
    Hint: Enable tracing of data-dependent output operators with `torch._dynamo.config.capture_scalar_outputs = True`

    Developer debug context: aten._local_scalar_dense.default


  from user code:
     File "/home/matthias.meta/venv/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
      return self.checkpoint_fn(  # type: ignore[misc]
    File "/home/matthias.meta/torchtitan/torchtitan/models/llama3/model/model.py", line 300, in forward
      h = x + self.attention(self.attention_norm(x), freqs_cis)
    File "/home/matthias.meta/torchtitan/torchtitan/models/llama3/model/model.py", line 192, in forward
      output = self.sdpa(xq, xk, xv)
    File "/home/matthias.meta/torchtitan/torchtitan/models/attention.py", line 247, in forward
      act_nested = F.scaled_dot_product_attention(

  Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
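For reference, the hinted workaround amounts to setting the Dynamo config flag before compiling; a minimal sketch (the `model` name is just a placeholder for whatever module is being compiled):

```python
import torch

# Follow the Dynamo hint above: allow data-dependent scalar outputs to be traced.
torch._dynamo.config.capture_scalar_outputs = True

model = torch.compile(model)  # `model` stands in for the transformer being compiled
```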

Enabling torch._dynamo.config.capture_scalar_outputs still fails:

    File "/home/matthias.meta/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 712, in compile_wrapper
      raise e.with_traceback(None) from e.__cause__  # User compiler error
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  torch._dynamo.exc.UserError: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)

  Caused by: act_nested = F.scaled_dot_product_attention(  # torchtitan/torchtitan/models/attention.py:249 in forward (nested/_internal/nested_tensor.py:36 in _get_sdpa_extreme_seqlen)
  For more information, run with TORCH_LOGS="dynamic"
  For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
  If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
  For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

  User Stack (most recent call last):
    (snipped, see stack below for prefix)
    File "/home/matthias.meta/venv/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
      return self.checkpoint_fn(  # type: ignore[misc]
    File "/home/matthias.meta/torchtitan/torchtitan/models/llama3/model/model.py", line 300, in forward
      h = x + self.attention(self.attention_norm(x), freqs_cis)
    File "/home/matthias.meta/torchtitan/torchtitan/models/llama3/model/model.py", line 192, in forward
      output = self.sdpa(xq, xk, xv)
    File "/home/matthias.meta/torchtitan/torchtitan/models/attention.py", line 249, in forward
      act_nested = F.scaled_dot_product_attention(

  For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
  For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example

  from user code:
     File "/home/matthias.meta/venv/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
      return self.checkpoint_fn(  # type: ignore[misc]
    File "/home/matthias.meta/torchtitan/torchtitan/models/llama3/model/model.py", line 300, in forward
      h = x + self.attention(self.attention_norm(x), freqs_cis)
    File "/home/matthias.meta/torchtitan/torchtitan/models/llama3/model/model.py", line 192, in forward
      output = self.sdpa(xq, xk, xv)
    File "/home/matthias.meta/torchtitan/torchtitan/models/attention.py", line 249, in forward
      act_nested = F.scaled_dot_product_attention(

  Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

@mreso mreso requested a review from lessw2020 June 26, 2025 19:02
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 26, 2025
@fegin (Contributor) commented Jun 27, 2025

I thought we wanted to pursue SDPA with variable-length support instead of using NestedTensor, given the uncertainty around DTensor + NestedTensor composability. We didn't investigate how to do CP with SDPA + NestedTensor because CP + SDPA relies on the DTensor dispatcher.

cc @drisspg

@mreso (Author) commented Jun 27, 2025

> I thought we wanted to pursue SDPA with variable-length support instead of using NestedTensor, given the uncertainty around DTensor + NestedTensor composability. We didn't investigate how to do CP with SDPA + NestedTensor because CP + SDPA relies on the DTensor dispatcher.
>
> cc @drisspg

Wasn't the variable-length support in SDPA and the nested tensor approach, in this case, just different interfaces to the same thing?

@fegin (Contributor) commented Jun 30, 2025

It's not the same. If SDPA directly supported variable length, that would be the native-tensor case, similar to exposing the FlashAttention varlen interface directly through SDPA without the need for NestedTensor. NestedTensor + DTensor (and CP) is an uncertainty for us.
