llama : Support llama 4 text-only #12791

Merged · 21 commits · Apr 7, 2025
add chunk attn mask
ngxson committed Apr 7, 2025
commit e6a2809c2d42042cb5e64052117be1e36af53b83
12 changes: 10 additions & 2 deletions src/llama-graph.cpp
@@ -474,9 +474,17 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
             }
 
             // may need to cut off old tokens for sliding window
+            // TODO @ngxson : the check for n_attn_chunk is temporary, need to optimize it
             if (data_swa) {
-                if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
-                    f = -INFINITY;
+                if (hparams.n_attn_chunk) {
+                    llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
+                    if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
+                        f = -INFINITY;
+                    }
Comment on lines +480 to +483

Member:
@ngxson Here in this check, I think that the second condition is always false. So we can simplify to:

...
if (kv_self->cells[i].pos < pos_chunk_start) {
...
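
(Side note, not from the original thread: the redundancy follows from integer division. Since pos_chunk_start is computed as (pos / n_attn_chunk) * n_attn_chunk, it rounds down and can never exceed pos, so pos < pos_chunk_start can never hold. A minimal standalone check is sketched below; the value 8192 mirrors the chunk size hard-coded later in this PR.)

// illustration only, not llama.cpp code
#include <cassert>
#include <cstdint>

int main() {
    const int32_t n_attn_chunk = 8192;
    for (int32_t pos = 0; pos < 4 * n_attn_chunk; ++pos) {
        // same computation as in the diff above
        const int32_t pos_chunk_start = (pos / n_attn_chunk) * n_attn_chunk;
        assert(pos_chunk_start <= pos); // hence `pos < pos_chunk_start` is always false
    }
    return 0;
}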

@ngxson (Collaborator, Author) · May 11, 2025:

Ah yeah thanks for noticing, yes it is redundant.

Btw I'm thinking about how to refactor the mask generation in a way that makes the code easier to understand (i.e. so it reads closer to English). My idea looks like this:

for (int j = 0; j < n_seq_tokens; ++j) {
    const llama_pos batch_pos = ubatch->pos[s*n_seq_tokens + j];
    for (int i = 0; i < n_kv; ++i) {
        const llama_pos cache_pos = kv_self->cells[i].pos;
        bool masked = false; // masked tokens will not be attended to

        if (causal) {
            // mask future tokens outside of the batch
            masked = masked || (cache_pos > batch_pos);
        }

        if (hparams.n_attn_chunk) {
            // mask past tokens outside of the chunk
            llama_pos pos_chunk_start = (batch_pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
            masked = masked || (cache_pos < pos_chunk_start);
        }

        if (!kv_self->cells[i].has_seq_id(seq_id)) {
            // mask tokens that are not in the same sequence
            masked = true;
        }
    }
}

WDYT?

Member:

Yes, we can improve - I'm already doing some improvements in this regard in #13194

+                } else {
+                    if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
+                        f = -INFINITY;
+                    }
                 }
                 data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
             }
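
(Illustration only, not part of the diff: the tiny standalone program below, with a hypothetical chunk size of 4 instead of 8192, prints which cache positions each query position may attend to under the chunked + causal rule implemented above; 'x' marks attended positions, '.' marks masked ones.)

#include <cstdint>
#include <cstdio>

int main() {
    const int32_t n_attn_chunk = 4; // small chunk purely for readability
    const int32_t n_pos        = 8;

    for (int32_t pos = 0; pos < n_pos; ++pos) {              // query position
        const int32_t pos_chunk_start = (pos / n_attn_chunk) * n_attn_chunk;
        for (int32_t kv = 0; kv < n_pos; ++kv) {             // cached position
            // masked if before the current chunk or in the future (causal)
            const bool masked = kv < pos_chunk_start || kv > pos;
            putchar(masked ? '.' : 'x');
        }
        putchar('\n');
    }
    return 0;
}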
1 change: 1 addition & 0 deletions src/llama-hparams.h
@@ -114,6 +114,7 @@ struct llama_hparams {
 
     uint32_t n_moe_layer_step = 0;
     bool use_kq_norm = true;
+    uint32_t n_attn_chunk = 0;
     // values below seems to be fixed on llama4
     uint32_t n_no_rope_layer_step = 4;
     uint32_t n_attn_temp_floor_scale = 8192;
5 changes: 5 additions & 0 deletions src/llama-model.cpp
@@ -557,6 +557,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
             ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
             ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
+            // hack: we use SWA to store the chunked attn mask
Member:
Yes, SWA -> AUX makes sense.

Btw, we should soon implement actual SWA / chunked attention that uses less memory. It shouldn't be a big change and will improve memory usage significantly for such models.

Collaborator Author:
The rename requires quite a few more changes than I expected, so I think I'll do it in another PR so it can be tested more thoroughly. Here I'll only edit my comment to make it clearer that I'm using the "swa" variable to store the chunked mask.

Collaborator Author:
I'll do this after we have the automated logits-matching test. The problem is that renaming n_swa to n_pattern_aux makes the existing "if (n_swa) then use SWA" check invalid; it would have to become "if (n_pattern_aux && is_swa) then use SWA".

I think it's better to add an enum called llama_mask_aux with 3 values: none, swa, chunked, so that the code becomes clearer.
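
(A rough sketch of what that enum could look like, illustration only: the value names follow the comment above, and the field hparams.mask_aux is a hypothetical placeholder, not part of this PR.)

// illustrative sketch only -- not actual llama.cpp code
enum llama_mask_aux {
    LLAMA_MASK_AUX_NONE,    // no auxiliary mask
    LLAMA_MASK_AUX_SWA,     // sliding-window attention mask
    LLAMA_MASK_AUX_CHUNKED, // chunked attention mask (llama 4)
};

// the old "if (n_swa) then use SWA" check would then become explicit, e.g.:
// if (hparams.mask_aux == LLAMA_MASK_AUX_SWA) { /* use SWA */ }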

+            // luckily, the n_swa_pattern is the same as chunked layer pattern: 3 chunked - 1 full
+            hparams.n_swa_pattern = 4;
+            hparams.n_attn_chunk = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
+            hparams.n_swa = 1; // unused, added to trigger the SWA
 
             switch (hparams.n_expert) {
                 case 16: type = LLM_TYPE_17B_16E; break;