cyr0930 (Contributor) commented on May 21, 2025

Summary

Attempts to fix #439.

As the issue above points out, chunking the hidden state across the batch dimension brings only limited benefit.
This PR therefore chunks the hidden state across the flattened (batch*seq_len) dimension instead.
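
To make the indexing concrete, here is a minimal sketch of chunking over the flattened (batch*seq_len) axis. The helper names `flat_token_chunks` and `flat_to_batch_pos` are hypothetical and are not part of liger-kernel or of this PR; the sketch only shows that a chunk can straddle a sequence boundary, which is presumably why a per-sequence loss needs extra bookkeeping.

```python
def flat_token_chunks(batch_size, seq_len, chunk_size):
    """Return (start, end) index ranges over the flattened batch*seq_len axis."""
    total = batch_size * seq_len
    return [(start, min(start + chunk_size, total))
            for start in range(0, total, chunk_size)]


def flat_to_batch_pos(index, seq_len):
    """Map a flattened token index back to (batch index, position in sequence)."""
    return divmod(index, seq_len)


# Two sequences of five tokens each, split into chunks of four tokens.
chunks = flat_token_chunks(batch_size=2, seq_len=5, chunk_size=4)
# The second chunk covers flat indices 4..7, so it contains the last token
# of sequence 0 and the first three tokens of sequence 1.
boundary_tokens = [flat_to_batch_pos(i, seq_len=5) for i in range(*chunks[1])]
```

Unlike chunking over the batch axis, the number of chunks here depends on the total token count, so even a batch of one sequence can be split.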

Because this requires a non-trivial online loss computation, the fused forward-backward technique cannot be used in this case.
However, the memory footprint issue can still be addressed by slicing the hidden state into small chunks.
The backward step can then run chunk by chunk instead of all at once, which avoids the high memory spike, although materializing all logits is unavoidable.
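
The idea can be sketched without any framework as follows. This is an illustration under stated assumptions, not the code in this PR: the function name `loss_and_chunked_grads` is hypothetical, the loss (softplus of the negated summed target log-probability) is only a stand-in for a loss that needs every token before any gradient is known, and plain Python lists stand in for tensors. All logits are computed in the forward pass, and the backward pass then builds gradient temporaries for one chunk of tokens at a time.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax of one logit vector."""
    m = max(row)
    lse = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - lse for v in row]


def loss_and_chunked_grads(hidden, weight, targets, chunk_size):
    """hidden: N x H with N = batch*seq_len, weight: V x H, targets: N class ids."""
    n, h_dim, vocab = len(hidden), len(hidden[0]), len(weight)

    # Forward: every logit row is materialized (N x V).
    logits = [[sum(w * x for w, x in zip(w_row, h)) for w_row in weight]
              for h in hidden]

    # Hypothetical stand-in loss that needs all tokens before any gradient
    # is known: softplus(-s), with s the summed target log-probability.
    s = sum(log_softmax(row)[t] for row, t in zip(logits, targets))
    loss = math.log1p(math.exp(-s))
    dloss_ds = -1.0 / (1.0 + math.exp(s))

    grad_hidden = []
    grad_weight = [[0.0] * h_dim for _ in range(vocab)]
    # Backward: only chunk_size x V gradient temporaries are alive at a time.
    for start in range(0, n, chunk_size):
        end = min(start + chunk_size, n)
        grad_logits = []
        for i in range(start, end):
            probs = [math.exp(v) for v in log_softmax(logits[i])]
            # d(log p_target)/d(logit_c) = 1[c == target] - softmax_c
            grad_logits.append([dloss_ds * ((c == targets[i]) - p)
                                for c, p in enumerate(probs)])
        for i, g in zip(range(start, end), grad_logits):
            grad_hidden.append([sum(g[c] * weight[c][k] for c in range(vocab))
                                for k in range(h_dim)])
            for c in range(vocab):
                for k in range(h_dim):
                    grad_weight[c][k] += g[c] * hidden[i][k]
    return loss, grad_hidden, grad_weight


# Toy data: N = 5 flattened tokens, H = 3 hidden units, V = 4 classes.
hidden = [[0.1, -0.2, 0.3], [0.4, 0.0, -0.1], [-0.3, 0.2, 0.5],
          [0.2, 0.1, -0.4], [0.0, -0.5, 0.2]]
weight = [[0.2, -0.1, 0.0], [0.1, 0.3, -0.2], [-0.4, 0.2, 0.1], [0.0, 0.1, 0.3]]
targets = [0, 2, 1, 3, 2]

chunked = loss_and_chunked_grads(hidden, weight, targets, chunk_size=2)
full = loss_and_chunked_grads(hidden, weight, targets, chunk_size=len(hidden))
```

Both calls should yield the same loss and gradients. The chunk size changes only how many per-token gradient rows exist at once, which is the part of the memory spike this approach targets.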

I'm not sure the implementation in this PR is perfect yet; for now I just want to check whether the idea is valid and aligned with the spirit of liger-kernel. Any feedback would be great. Thanks.

Testing Done

I haven't run the tests yet, because I first want to check whether this concept is acceptable.

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Development

Successfully merging this pull request may close these issues.

Bug in chunking?

3 participants