[WIP] Enable causal block mask for sdpa #1348
This PR enables causal block mask for sdpa using nested tensors.
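For illustration, here is a minimal sketch of the idea (not the PR's actual code): pack variable-length documents into a jagged nested tensor so that `is_causal=True` in SDPA masks each document independently, which yields a block-causal pattern over the packed sequence. Shapes and helper names below are made up for the example, and it assumes a recent PyTorch build with nested-tensor (`torch.jagged`) SDPA support.

```python
import torch
import torch.nn.functional as F

doc_lens = [5, 8]          # two packed documents of different lengths
n_heads, head_dim = 4, 16

def make_njt():
    # (batch, jagged_seq, heads, head_dim) with one component per document
    return torch.nested.nested_tensor(
        [torch.randn(L, n_heads, head_dim) for L in doc_lens],
        layout=torch.jagged,
    )

# Move heads before the jagged dim, as SDPA expects:
# (batch, heads, jagged_seq, head_dim).
q, k, v = (make_njt().transpose(1, 2) for _ in range(3))

# is_causal=True is applied per nested component, i.e. per document,
# so attention never crosses document boundaries.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Each document keeps its own sequence length in the output.
print([t.shape for t in out.unbind()])  # torch.Size([4, 5, 16]), torch.Size([4, 8, 16])
```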
To test, run:
TODO



[X] Test with debug model
[X] Test with llama3 8B
[ ] Test with llama3 70B
[X] Compare loss: sdpa + causal block mask vs sdpa_causal vs flex block_causal
The loss of flex_attn and sdpa + causal block mask is identical over 3k steps:
[X] Compare memory with local_batch_size>1 (e.g. (bs=8, max_seq_len=8192) vs (bs=1, max_seq_len=8*8192))
Active memory is the same for both cases, with a slight advantage over flex attention during initialization:
Note: full activation checkpointing was enabled because flex does not support selective checkpointing.
[X] Compare throughput:
Slight advantage for flex_attn; sdpa + causal block mask has an advantage over sdpa + causal mask:
[X] Test torch.compile
Calling F.scaled_dot_product_attention fails with:
Enabling torch._dynamo.config.capture_scalar_outputs still fails:
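For reference, the workaround attempted above amounts to flipping the dynamo config flag before compiling; a minimal sketch (still failing per the notes above):

```python
import torch
import torch._dynamo

# Let dynamo trace data-dependent scalar reads (e.g. .item()) instead of
# graph-breaking on them.
torch._dynamo.config.capture_scalar_outputs = True
```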