Skip to content

Conversation

@pwilkin
Copy link
Collaborator

@pwilkin pwilkin commented Sep 18, 2025

EDIT: README FIRST
This is an implementation of a new type of attention gating in GGML.
Therefore, this implementation will be focused on CORRECTNESS ONLY.
Speed tuning and support for more architectures will come in future PRs.
Please do not spam this threads with reports about performance, especially on backend architectures (CUDA, Vulkan).

CURRENT STATE: pending merge of #17063

===
It's been a real learning experience, not gonna lie, but if someone with hybrid model implementation experience (@gabe-l-hart ?) has some quick tips, I'd be grateful.

Currently at the stage of "graph builds, but first decode complains about wrong memory model", probably not building the inputs correctly.

Resolves #15940

@github-actions github-actions bot added python python script changes ggml changes relating to the ggml tensor library for machine learning labels Sep 18, 2025
@gabe-l-hart
Copy link
Collaborator

I'll try to get into it in more detail soon, but here are a few general thoughts after quickly skimming the PR:

  1. The structure of what you've got smells correct, so it's likely close, but missing something small yet critical
  2. A full repro with the error it's raising would definitely help debug
  3. My debugging process for this would be:
    1. Make sure tokenization is solid (print statements as necessary to compare tokens before input)
    2. Use llama-eval-callback to dump tensors for a single prefill step
    3. Run an identical single prefill with the reference impl (transformers or otherwise), and inject prints as needed to dump tensors along the way
    4. Visually comb through them (particularly the sum at each point) to see where things start diverging significantly

@bugparty
Copy link
Contributor

It's been a real learning experience, not gonna lie, but if someone with hybrid model implementation experience (@gabe-l-hart ?) has some quick tips, I'd be grateful.

Currently at the stage of "graph builds, but first decode complains about wrong memory model", probably not building the inputs correctly.

Resolves #15940

interesting, maybe we can learn together

@pwilkin pwilkin marked this pull request as draft September 19, 2025 08:07
@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 19, 2025

  1. A full repro with the error it's raising would definitely help debug

Running llama-cli -m reference/qwen3_next_500m/Qwen3_Next_500M-8x417M-BF16.gguf -ngl 999 -p "Who are " yields this weird memory error:

#0  __syscall_cancel_arch () at ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S:56
56      in ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S
#1  0x000070552b29eb63 in __internal_syscall_cancel (a1=<optimized out>, a2=<optimized out>, a3=<optimized out>, a4=<optimized out>, a5=0, a6=0, nr=61) at ./nptl/cancellation.c:49
warning: 49     ./nptl/cancellation.c: No such file or directory
#2  __syscall_cancel (a1=<optimized out>, a2=<optimized out>, a3=<optimized out>, a4=<optimized out>, a5=a5@entry=0, a6=a6@entry=0, nr=61) at ./nptl/cancellation.c:75
75      in ./nptl/cancellation.c
#3  0x000070552b31afdf in __GI___wait4 (pid=<optimized out>, stat_loc=<optimized out>, options=<optimized out>, usage=<optimized out>) at ../sysdeps/unix/sysv/linux/wait4.c:30
warning: 30     ../sysdeps/unix/sysv/linux/wait4.c: No such file or directory
#4  0x000070552bb45c31 in ggml_print_backtrace () at /devel/tools/llama.cpp/ggml/src/ggml.c:196
warning: Source file is more recent than executable.
196             waitpid(child_pid, NULL, 0);
#5  0x000070552bb45de5 in ggml_abort (file=0x70552bbcdac8 "/devel/tools/llama.cpp/ggml/src/ggml-backend.cpp", line=189, fmt=0x70552bbcd8af "GGML_ASSERT(%s) failed") at /devel/tools/llama.cpp/ggml/src/ggml.c:230
230             ggml_print_backtrace();
#6  0x000070552bb6091e in ggml_backend_buffer_get_type (buffer=0x0) at /devel/tools/llama.cpp/ggml/src/ggml-backend.cpp:189
189         GGML_ASSERT(buffer);
#7  0x000070552bb6080e in ggml_backend_buffer_is_host (buffer=0x0) at /devel/tools/llama.cpp/ggml/src/ggml-backend.cpp:170
170         return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer));
#8  0x000070552c07a114 in llm_graph_input_rs::set_input (this=0x5f11bdf6aea0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:241
241             GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
#9  0x000070552c07b03c in llm_graph_input_mem_hybrid::set_input (this=0x5f11bdf6aee0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:437
437         inp_rs->set_input(ubatch);
#10 0x000070552c07b549 in llm_graph_result::set_inputs (this=0x5f11be01ddf0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:480
480             input->set_input(ubatch);
#11 0x000070552c01ddb3 in llama_context::process_ubatch (this=0x5f11c05b5b50, ubatch=..., gtype=LLM_GRAPH_TYPE_DECODER, mctx=0x5f11be00ff00, ret=@0x7fff74d22ea4: 538976288) at /devel/tools/llama.cpp/src/llama-context.cpp:779
779             res->set_inputs(&ubatch);
#12 0x000070552c01f367 in llama_context::decode (this=0x5f11c05b5b50, batch_inp=...) at /devel/tools/llama.cpp/src/llama-context.cpp:1088
1088            const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
#13 0x000070552c025e49 in llama_decode (ctx=0x5f11c05b5b50, batch=...) at /devel/tools/llama.cpp/src/llama-context.cpp:2726
2726        const int ret = ctx->decode(batch);
#14 0x00005f11a2021559 in common_init_from_params (params=...) at /devel/tools/llama.cpp/common/common.cpp:1066
1066                llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
#15 0x00005f11a1e4a3c0 in main (argc=7, argv=0x7fff74d25968) at /devel/tools/llama.cpp/tools/main/main.cpp:140
140         common_init_result llama_init = common_init_from_params(params);

I'll try to merge the op into the ggml_delta_net function call as @ngxson suggested.

@CISC
Copy link
Collaborator

CISC commented Sep 19, 2025

  1. A full repro with the error it's raising would definitely help debug

Running llama-cli -m reference/qwen3_next_500m/Qwen3_Next_500M-8x417M-BF16.gguf -ngl 999 -p "Who are " yields this weird memory error:

...
#6  0x000070552bb6091e in ggml_backend_buffer_get_type (buffer=0x0) at /devel/tools/llama.cpp/ggml/src/ggml-backend.cpp:189
189         GGML_ASSERT(buffer);
#7  0x000070552bb6080e in ggml_backend_buffer_is_host (buffer=0x0) at /devel/tools/llama.cpp/ggml/src/ggml-backend.cpp:170
170         return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer));
...

The backend buffer is NULL.

@ngxson
Copy link
Collaborator

ngxson commented Sep 19, 2025

#9  0x000070552c07b03c in llm_graph_input_mem_hybrid::set_input (this=0x5f11bdf6aee0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:437
437         inp_rs->set_input(ubatch);

The model doesn't seem to have any recurrence layers. This makes the set input fails due to input node not being present in cgraph.

I'll try to merge the op into the ggml_delta_net function call as @ngxson suggested.

Hmm I think I said the reverse: not to merge it but make the op simple

I feel like this op can be implemented using other ggml ops like mul, mul_mat, sum. Which part of the calculation do you think that can't be constructed using existing ops?

This is the more important question: should we try to implement it using existing ops, or add a new op and spend even more time to optimize it cross all backends?

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 19, 2025

Now this is an error I haven't expected to encounter:

GGML_ABORT("not enough space in the context's memory pool");

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 19, 2025

The model doesn't seem to have any recurrence layers. This makes the set input fails due to input node not being present in cgraph.

How do I allocate the memory for the linear layers then? I seem to have misunderstood how build_inp_mem_hybrid() works...

@yarikdevcom
Copy link

@pwilkin any chance to buy you a coffee?(Paterson etc.) so community able to donate for your efforts. Thank you!

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 19, 2025

@pwilkin any chance to buy you a coffee?(Paterson etc.) so community able to donate for your efforts. Thank you!

Added a buymeacoffee link to my profile (do consider first funding the Llama.cpp project itself, though!)

@ServeurpersoCom
Copy link
Collaborator

ServeurpersoCom commented Sep 19, 2025

@pwilkin any chance to buy you a coffee?(Paterson etc.) so community able to donate for your efforts. Thank you!

Added a buymeacoffee link to my profile (do consider first funding the Llama.cpp project itself, though!)

I send a coffee also.

@ngxson
Copy link
Collaborator

ngxson commented Sep 20, 2025

GGML_ABORT("not enough space in the context's memory pool");

Probably there are too many nodes on cgraph, try increasing the limit via llama_context::graph_max_nodes()

Comment on lines 19054 to 19056
Qcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Qcur), n_embd_head, hparams.n_head(il), n_tokens);
Kcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Kcur), n_embd_head, hparams.n_head_kv(il), n_tokens);
Vcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Vcur), n_embd_head, hparams.n_head_kv(il), n_tokens);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these ggml_cont can be removed if Q/gate are separated. ggml_cont is not recommended when dealing with big tensors

Copy link
Collaborator

@CISC CISC Sep 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually none of these need ggml_cont, Q is 3D already, Q/K are RoPEd so can be views and V can also be a 3D view now.

Edit: sorry, not quite true about V, only if QKV is fused, the weird gate fuse threw me off. Nevertheless, K/V are already contiguous at this point.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the problem is that Q is non-contiguous and ggml_rope(_ext) does not work very well with non-cont tensors, it's still buggy on certain backends

Copy link
Collaborator

@CISC CISC Sep 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the problem is that Q is non-contiguous and ggml_rope(_ext) does not work very well with non-cont tensors, it's still buggy on certain backends

Are you sure? AFAIK those issues are fixed.

Edit: Also, if there still are issues they will never get fixed if we work around them. :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the problem is that Q is non-contiguous and ggml_rope(_ext) does not work very well with non-cont tensors, it's still buggy on certain backends

I think all of these cases are fixed now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was an impl of 2D rope that relies on ggml_view: https://github.com/ngxson/ggml-easy/blob/f56e5e499b1f21a4aae73010e9d9582840428457/demo/2d-rope.cpp

It works on CPU and Metal, but doesn't work on CUDA/Vulkan. Couldn't tested on other backends, but feel free to make a PR to address this issue.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes that seems to work. sorry @pwilkin you will need to manually revert the change where I split Q/gate. the tensor shape for Q will be:

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);

layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0);
layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_projection_size }, 0);
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { n_ff, n_embd }, 0);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shape of LLM_TENSOR_ATTN_Q and LLM_TENSOR_SSM_OUT should not contain n_ff

@ngxson
Copy link
Collaborator

ngxson commented Sep 20, 2025

^ proposed fix for the 3 comments above: 46110e0

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 20, 2025

@ngxson Thanks, scale_bias was one op I was missing in my endeavors :>

I got an LLM to rewrite the internal delta into tensor logic. After a day of manually fixing that crap, I think I understand it enough to rewrite it myself ;)

@ngxson
Copy link
Collaborator

ngxson commented Sep 20, 2025

Honestly I would prefer taking time to understand the mamba/ssm implementation then writing the code manually. Code written by LLM are mostly attempts for 1-to-1 translation from pytorch --> GGML which looks quite confusing

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 20, 2025

Honestly I would prefer taking time to understand the mamba/ssm implementation then writing the code manually. Code written by LLM are mostly attempts for 1-to-1 translation from pytorch --> GGML which looks quite confusing

Yeah, for me getting a rough outline then going over it manually is the best way to learn :)

I tried the "one-to-one" approach and ended up with a graph that wouldn't fit in 16 GB of RAM for a 500M model...

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 20, 2025

Aight, I cleaned up the main graph calculation, now I have to figure out how to include conv_states_all in my delta_net function in order to not get the memory error.

pwilkin and others added 7 commits October 31, 2025 23:42
Co-authored-by: Sigbjørn Skjæret <[email protected]>
Co-authored-by: Sigbjørn Skjæret <[email protected]>
Co-authored-by: Sigbjørn Skjæret <[email protected]>
Co-authored-by: Sigbjørn Skjæret <[email protected]>
Co-authored-by: Sigbjørn Skjæret <[email protected]>
Co-authored-by: Sigbjørn Skjæret <[email protected]>
@Rob-P-Smith

This comment was marked as off-topic.

@lovedheart
Copy link

Not sure if it is related to this PR. When evaulating Qwen3-Next with Vulkan backend on iGPU, the model was loaded to both CPU and Vulkan, which doubles the RAM.

load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: offloading 48 repeating layers to GPU
load_tensors: offloading output layer to GPU
load_tensors: offloaded 49/49 layers to GPU
load_tensors: CPU_Mapped model buffer size = 25265.84 MiB
load_tensors: Vulkan0 model buffer size = 27246.99 MiB
..........................................................................................

See logs

PS E:\LLM\qwen3_next_llama> .\build\bin\Release\llama-cli.exe -m E:\LLM\Qwen3-Next-80B-A3B-Instruct-IQ1_S_M.gguf ggml_vulkan: Found 1 Vulkan devices: ggml_vulkan: 0 = AMD Radeon 780M Graphics (AMD proprietary driver) | uma: 1 | fp16: 1 | bf16: 1 | warp size: 64 | shared memory: 32768 | int dot: 1 | matrix cores: KHR_coopmat build: 7337 (61667c3) with MSVC 19.44.35217.0 for x64 main: llama backend init main: load the model and apply lora adapter, if any llama_model_load_from_file_impl: using device Vulkan0 (AMD Radeon 780M Graphics) (unknown id) - 46478 MiB free llama_model_loader: loaded meta data with 45 key-value pairs and 807 tensors from E:\LLM\Qwen3-Next-80B-A3B-Instruct-IQ1_S_M.gguf (version GGUF V3 (latest)) llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output. llama_model_loader: - kv 0: general.architecture str = qwen3next llama_model_loader: - kv 1: general.type str = model llama_model_loader: - kv 2: general.name str = Qwen3 Next 80B A3B Instruct llama_model_loader: - kv 3: general.finetune str = Instruct llama_model_loader: - kv 4: general.basename str = Qwen3-Next llama_model_loader: - kv 5: general.size_label str = 80B-A3B llama_model_loader: - kv 6: general.license str = apache-2.0 llama_model_loader: - kv 7: general.license.link str = https://huggingface.co/Qwen/Qwen3-Nex... llama_model_loader: - kv 8: general.tags arr[str,1] = ["text-generation"] llama_model_loader: - kv 9: qwen3next.block_count u32 = 48 llama_model_loader: - kv 10: qwen3next.context_length u32 = 262144 llama_model_loader: - kv 11: qwen3next.embedding_length u32 = 2048 llama_model_loader: - kv 12: qwen3next.feed_forward_length u32 = 5120 llama_model_loader: - kv 13: qwen3next.attention.head_count u32 = 16 llama_model_loader: - kv 14: qwen3next.attention.head_count_kv u32 = 2 llama_model_loader: - kv 15: qwen3next.rope.freq_base f32 = 10000000.000000 llama_model_loader: - kv 16: qwen3next.attention.layer_norm_rms_epsilon f32 = 0.000001 llama_model_loader: - kv 17: qwen3next.expert_used_count u32 = 10 llama_model_loader: - kv 18: qwen3next.attention.key_length u32 = 256 llama_model_loader: - kv 19: qwen3next.attention.value_length u32 = 256 llama_model_loader: - kv 20: qwen3next.expert_count u32 = 512 llama_model_loader: - kv 21: qwen3next.expert_feed_forward_length u32 = 512 llama_model_loader: - kv 22: qwen3next.expert_shared_feed_forward_length u32 = 512 llama_model_loader: - kv 23: qwen3next.ssm.conv_kernel u32 = 4 llama_model_loader: - kv 24: qwen3next.ssm.state_size u32 = 128 llama_model_loader: - kv 25: qwen3next.ssm.group_count u32 = 16 llama_model_loader: - kv 26: qwen3next.ssm.time_step_rank u32 = 32 llama_model_loader: - kv 27: qwen3next.ssm.inner_size u32 = 4096 llama_model_loader: - kv 28: qwen3next.rope.dimension_count u32 = 64 llama_model_loader: - kv 29: tokenizer.ggml.model str = gpt2 llama_model_loader: - kv 30: tokenizer.ggml.pre str = qwen2 llama_model_loader: - kv 31: tokenizer.ggml.tokens arr[str,151936] = ["!", "\"", "#", "$", "%", "&", "'", ... llama_model_loader: - kv 32: tokenizer.ggml.token_type arr[i32,151936] = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... llama_model_loader: - kv 33: tokenizer.ggml.merges arr[str,151387] = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",... llama_model_loader: - kv 34: tokenizer.ggml.eos_token_id u32 = 151645 llama_model_loader: - kv 35: tokenizer.ggml.padding_token_id u32 = 151643 llama_model_loader: - kv 36: tokenizer.ggml.bos_token_id u32 = 151643 llama_model_loader: - kv 37: tokenizer.ggml.add_bos_token bool = false llama_model_loader: - kv 38: tokenizer.chat_template str = {%- if tools %}\n {{- '<|im_start|>... llama_model_loader: - kv 39: general.quantization_version u32 = 2 llama_model_loader: - kv 40: general.file_type u32 = 24 llama_model_loader: - kv 41: quantize.imatrix.file str = E:\LLM\llama_qwen3_next\imatrix.gguf llama_model_loader: - kv 42: quantize.imatrix.dataset str = ..\calibration_dataset.txt llama_model_loader: - kv 43: quantize.imatrix.entries_count u32 = 540 llama_model_loader: - kv 44: quantize.imatrix.chunks_count u32 = 350 llama_model_loader: - type f32: 313 tensors llama_model_loader: - type q2_K: 20 tensors llama_model_loader: - type q4_K: 9 tensors llama_model_loader: - type q6_K: 6 tensors llama_model_loader: - type iq2_xxs: 10 tensors llama_model_loader: - type iq3_xxs: 4 tensors llama_model_loader: - type iq1_s: 14 tensors llama_model_loader: - type iq4_xs: 8 tensors llama_model_loader: - type iq1_m: 72 tensors llama_model_loader: - type bf16: 351 tensors print_info: file format = GGUF V3 (latest) print_info: file type = IQ1_S - 1.5625 bpw print_info: file size = 27.20 GiB (2.93 BPW) load: printing all EOG tokens: load: - 151643 ('<|endoftext|>') load: - 151645 ('<|im_end|>') load: - 151662 ('<|fim_pad|>') load: - 151663 ('<|repo_name|>') load: - 151664 ('<|file_sep|>') load: special tokens cache size = 26 load: token to piece cache size = 0.9311 MB print_info: arch = qwen3next print_info: vocab_only = 0 print_info: n_ctx_train = 262144 print_info: n_embd = 2048 print_info: n_layer = 48 print_info: n_head = 16 print_info: n_head_kv = 2 print_info: n_rot = 64 print_info: n_swa = 0 print_info: is_swa_any = 0 print_info: n_embd_head_k = 256 print_info: n_embd_head_v = 256 print_info: n_gqa = 8 print_info: n_embd_k_gqa = 512 print_info: n_embd_v_gqa = 512 print_info: f_norm_eps = 0.0e+00 print_info: f_norm_rms_eps = 1.0e-06 print_info: f_clamp_kqv = 0.0e+00 print_info: f_max_alibi_bias = 0.0e+00 print_info: f_logit_scale = 0.0e+00 print_info: f_attn_scale = 0.0e+00 print_info: n_ff = 5120 print_info: n_expert = 512 print_info: n_expert_used = 10 print_info: n_expert_groups = 0 print_info: n_group_used = 0 print_info: causal attn = 1 print_info: pooling type = 0 print_info: rope type = 2 print_info: rope scaling = linear print_info: freq_base_train = 10000000.0 print_info: freq_scale_train = 1 print_info: n_ctx_orig_yarn = 262144 print_info: rope_finetuned = unknown print_info: ssm_d_conv = 4 print_info: ssm_d_inner = 4096 print_info: ssm_d_state = 128 print_info: ssm_dt_rank = 32 print_info: ssm_n_group = 16 print_info: ssm_dt_b_c_rms = 0 print_info: model type = ?B print_info: model params = 79.67 B print_info: general.name = Qwen3 Next 80B A3B Instruct print_info: vocab type = BPE print_info: n_vocab = 151936 print_info: n_merges = 151387 print_info: BOS token = 151643 '<|endoftext|>' print_info: EOS token = 151645 '<|im_end|>' print_info: EOT token = 151645 '<|im_end|>' print_info: PAD token = 151643 '<|endoftext|>' print_info: LF token = 198 'Ċ' print_info: FIM PRE token = 151659 '<|fim_prefix|>' print_info: FIM SUF token = 151661 '<|fim_suffix|>' print_info: FIM MID token = 151660 '<|fim_middle|>' print_info: FIM PAD token = 151662 '<|fim_pad|>' print_info: FIM REP token = 151663 '<|repo_name|>' print_info: FIM SEP token = 151664 '<|file_sep|>' print_info: EOG token = 151643 '<|endoftext|>' print_info: EOG token = 151645 '<|im_end|>' print_info: EOG token = 151662 '<|fim_pad|>' print_info: EOG token = 151663 '<|repo_name|>' print_info: EOG token = 151664 '<|file_sep|>' print_info: max token length = 256 load_tensors: loading model tensors, this can take a while... (mmap = true) load_tensors: offloading 48 repeating layers to GPU load_tensors: offloading output layer to GPU load_tensors: offloaded 49/49 layers to GPU load_tensors: CPU_Mapped model buffer size = 25265.84 MiB load_tensors: Vulkan0 model buffer size = 27246.99 MiB .......................................................................................... llama_context: constructing llama_context llama_context: n_seq_max = 1 llama_context: n_ctx = 4096 llama_context: n_ctx_per_seq = 4096 llama_context: n_batch = 2048 llama_context: n_ubatch = 512 llama_context: causal_attn = 1 llama_context: flash_attn = auto llama_context: kv_unified = false llama_context: freq_base = 10000000.0 llama_context: freq_scale = 1 llama_context: n_ctx_per_seq (4096) < n_ctx_train (262144) -- the full capacity of the model will not be utilized llama_context: Vulkan_Host output buffer size = 0.58 MiB llama_kv_cache: Vulkan0 KV buffer size = 96.00 MiB llama_kv_cache: size = 96.00 MiB ( 4096 cells, 12 layers, 1/1 seqs), K (f16): 48.00 MiB, V (f16): 48.00 MiB llama_memory_recurrent: Vulkan0 RS buffer size = 75.38 MiB llama_memory_recurrent: size = 75.38 MiB ( 1 cells, 48 layers, 1 seqs), R (f32): 3.38 MiB, S (f32): 72.00 MiB llama_context: Flash Attention was auto, set to enabled llama_context: Vulkan0 compute buffer size = 307.75 MiB llama_context: Vulkan_Host compute buffer size = 105.01 MiB llama_context: graph nodes = 9168 llama_context: graph splits = 402 (with bs=512), 438 (with bs=1) common_init_from_params: added <|endoftext|> logit bias = -inf common_init_from_params: added <|im_end|> logit bias = -inf common_init_from_params: added <|fim_pad|> logit bias = -inf common_init_from_params: added <|repo_name|> logit bias = -inf common_init_from_params: added <|file_sep|> logit bias = -inf common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096 common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable) main: llama threadpool init, n_threads = 8 main: chat template is available, enabling conversation mode (disable it with -no-cnv) main: chat template example: <|im_start|>system You are a helpful assistant<|im_end|> <|im_start|>user Hello<|im_end|> <|im_start|>assistant Hi there<|im_end|> <|im_start|>user How are you?<|im_end|> <|im_start|>assistant

system_info: n_threads = 8 (n_threads_batch = 8) / 16 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | AVX512 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 |

main: interactive mode on.
sampler seed: 3858617728
sampler params:
repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 4096
top_k = 40, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, top_n_sigma = -1.000, temp = 0.800
mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-n-sigma -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist
generate: n_ctx = 4096, n_batch = 2048, n_predict = -1, n_keep = 0

== Running in interactive mode. ==

  • Press Ctrl+C to interject at any time.
  • Press Return to return control to the AI.
  • To return control without starting a new line, end your input with '/'.
  • If you want to submit another line, end your input with ''.
  • Not using system message. To change it, set a different value via -sys PROMPT

help me design a website for selling vps
Absolutely! Here's a comprehensive, professional website design plan for selling VPS (Virtual Private Server) hosting services. This plan includes structure, UI/UX recommendations, key content, technical considerations, and conversion optimization tips.


🌐 Website Name Suggestion

HostVPS.pro (or your brand name)
Tagline: “Power. Control. Scalability. Your VPS, Perfected.”


🧩 1. Website Structure & Pages

Homepage (Landing Page)

Goal: Convert visitors into leads/customers in <10 seconds.

Key Sections:

  1. Hero Section

    • Headline: “High-Performance VPS Hosting – Start in Seconds, Scale Forever”
    • Subheadline: Dedicated resources. Root access. Linux & Windows. No shared noise.
    • CTA Button: “Get Started for $4.99/mo” (bold, contrasting color)
    • Background: Animated cloud/server graphic or subtle video loop of data centers.
  2. Features Grid (3-Column)

    • 🚀 Instant Provisioning – Deploy in <60 seconds
    • 🔒 Root Access – Full control over your server
    • 📈 SSD Storage & 99.9% Uptime – Reliability you can trust
  3. Pricing Comparison Table (Hero Feature)

    Plan CPU RAM SSD Bandwidth Price Button
    Starter 1 1GB 20GB 1TB $4.99/mo [Choose]
    Professional 2 2GB 40GB 2TB $9.99/mo [Choose]
    Enterprise 4 4GB 80GB 5TB $19.99/mo [Choose]
    Custom ✔️ ✔️ ✔️ ✔️ Contact Us [Request Quote]

    Highlight the “Starter” plan as “Most Popular”
    Add “Unlimited Bandwidth” as upsell option

  4. Trust Indicators

    • Logos of trusted partners (e.g., Cloudflare, SSL providers)
    • “Trusted by 10,000+ developers & businesses”
    • 4.9/5 ★★★★★ (from verified reviews)
  5. Video or Demo (Optional but Powerful)

    • 30-sec screen recording: “How to deploy your VPS in 3 clicks”

Features Page

Goal: Educate visitors on why VPS beats shared hosting.

Sections:

  • What is a VPS? (Simple analogy: “Like renting your own apartment in a building, vs. sharing a dorm room”)
  • Why Choose Our VPS?
    • Dedicated Resources (no “noisy neighbor

llama_perf_sampler_print: sampling time = 87.45 ms / 708 runs ( 0.12 ms per token, 8096.43 tokens per second)
llama_perf_context_print: load time = 73043.81 ms
llama_perf_context_print: prompt eval time = 1877.24 ms / 17 tokens ( 110.43 ms per token, 9.06 tokens per second)
llama_perf_context_print: eval time = 131838.77 ms / 690 runs ( 191.07 ms per token, 5.23 tokens per second)
llama_perf_context_print: total time = 161926.39 ms / 707 tokens
llama_perf_context_print: graphs reused = 0
llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted |
llama_memory_breakdown_print: | - Vulkan0 (780M Graphics) | 48924 = 18751 + (27726 = 27246 + 171 + 307) + 2446 |
llama_memory_breakdown_print: | - Host | 25370 = 25265 + 0 + 105 |
Interrupted by user

@Mushoz
Copy link

Mushoz commented Nov 1, 2025

@lovedheart please stop spamming this topic with irrelevant posts. This is a CPU only implementation focussed on correctness. GPU support will come in followup PRs

@ggerganov
Copy link
Member

please stop spamming this topic with irrelevant posts.

The OP needs to be updated with instructions - we cannot expect people to read all 430 comments in the thread.

@pwilkin
Copy link
Collaborator Author

pwilkin commented Nov 1, 2025

The OP needs to be updated with instructions - we cannot expect people to read all 430 comments in the thread.

Updated the first post with a README.

Copy link
Member

@ggerganov ggerganov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested parallel generations and they seem to work:

make -j && ./bin/llama-parallel -m ../models/qwen3-next/ggml-model-q8_0.gguf -np 5 -ns 8 --temp 0 

To move this forward, I suggest doing the following:

  • Extract all new ops into a separate PR
  • After it is approved, rebase this PR on top of the new ops
  • Try to reduce the nodes in the compute graph if possible (I'll take a look into this now)

ggml_tensor * get_s_l(int32_t il) const;

int32_t s_copy(int i) const;
bool has_previous_state() const;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not used anymore

Comment on lines +17 to +21
#include <mutex>

// Forward declarations for internal cache access
struct llama_memory_hybrid;
struct llama_memory_recurrent;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed

@ggerganov
Copy link
Member

@pwilkin Here is a first pass at cleaning up the linear attention graph:

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index afa40189f..7986bea1e 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -20462,19 +20462,19 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         ggml_build_forward_expand(gf, cur);
     }
 
-    struct ggml_tensor * delta_net_unified(struct ggml_context * ctx,
-                                           struct ggml_tensor *  q,
-                                           struct ggml_tensor *  k,
-                                           struct ggml_tensor *  v,
-                                           struct ggml_tensor *  g,
-                                           struct ggml_tensor *  beta,
-                                           struct ggml_tensor *  state,
-                                           struct ggml_tensor *  causal_mask,
-                                           struct ggml_tensor *  identity,
-                                           bool                  use_qk_l2norm,
-                                           float                 eps_norm,
-                                           int                   il
-                                        ) {
+    ggml_tensor * delta_net_unified(
+            ggml_context * ctx,
+            ggml_tensor  * q,
+            ggml_tensor  * k,
+            ggml_tensor  * v,
+            ggml_tensor  * g,
+            ggml_tensor  * beta,
+            ggml_tensor  * state,
+            ggml_tensor  * causal_mask,
+            ggml_tensor  * identity,
+            bool           use_qk_l2norm,
+            float          eps_norm,
+            int            il) {
         GGML_ASSERT(ggml_is_contiguous(q));
         GGML_ASSERT(ggml_is_contiguous(k));
         GGML_ASSERT(ggml_is_contiguous(v));
@@ -20511,7 +20511,7 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
 
         beta = ggml_sigmoid(ctx, beta);
 
-        struct ggml_tensor * causal_diag_mask = ggml_add(ctx, causal_mask, identity);
+        ggml_tensor * causal_diag_mask = ggml_add(ctx, causal_mask, identity);
 
         cb(q, "q_in", il);
         cb(k, "k_in", il);
@@ -20519,11 +20519,12 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         cb(beta, "beta_in", il);
         cb(g, "g_in", il);
 
-        q    = ggml_cont_4d(ctx, ggml_permute(ctx, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-        k    = ggml_cont_4d(ctx, ggml_permute(ctx, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-        v    = ggml_cont_4d(ctx, ggml_permute(ctx, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+        q = ggml_cont_4d(ctx, ggml_permute(ctx, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+        k = ggml_cont_4d(ctx, ggml_permute(ctx, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+        v = ggml_cont_4d(ctx, ggml_permute(ctx, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+        g = ggml_cont_4d(ctx, ggml_permute(ctx, g, 2, 0, 3, 1), n_tokens, 1,   H_k, n_seqs);
+
         beta = ggml_cont(ctx, ggml_permute(ctx, beta, 2, 0, 1, 3));
-        g = ggml_cont(ctx, ggml_permute(ctx, g, 2, 0, 3, 1));
         state = ggml_reshape_4d(ctx, state, S_v, S_v, H_v, n_seqs);
 
         cb(q, "q_perm", il);
@@ -20536,39 +20537,32 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
         GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
         GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
-        GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 &&
-                    beta->ne[3] == n_seqs);
-        GGML_ASSERT(g->ne[0] == n_tokens && g->ne[2] == H_k && g->ne[1] == 1 && g->ne[3] == n_seqs);
-
-        struct ggml_tensor * v_beta = ggml_mul(ctx, v, beta);
-        v_beta                      = ggml_reshape_4d(ctx, v_beta, S_v, n_tokens, H_k, n_seqs);
-        struct ggml_tensor * k_beta = ggml_mul(ctx, k, beta);
-        k_beta                      = ggml_reshape_4d(ctx, k_beta, S_v, n_tokens, H_k, n_seqs);
-        k                           = ggml_reshape_4d(ctx, k, S_v, n_tokens, H_k, n_seqs);
-        q                           = ggml_reshape_4d(ctx, q, S_v, n_tokens, H_k, n_seqs);
-        v                           = ggml_reshape_4d(ctx, v, S_v, n_tokens, H_v, n_seqs);
-        g                           = ggml_reshape_4d(ctx, g, n_tokens, 1, H_k, n_seqs);
-        struct ggml_tensor * g_cumsum = ggml_cumsum(ctx, g);
+        GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
+
+        ggml_tensor * v_beta = ggml_mul(ctx, v, beta);
+        ggml_tensor * k_beta = ggml_mul(ctx, k, beta);
+
+        ggml_tensor * g_cumsum = ggml_cumsum(ctx, g);
 
         cb(k_beta, "k_beta", il);
         cb(v_beta, "v_beta", il);
         cb(g_cumsum, "g_cumsum", il);
 
-        struct ggml_tensor * gcs_i = ggml_cont_4d(ctx, g_cumsum, n_tokens, 1, H_v,
+        ggml_tensor * gcs_i = ggml_cont_4d(ctx, g_cumsum, n_tokens, 1, H_v,
                                                   n_seqs);  // [chunk_size, 1, n_tokens, n_seqs]
-        struct ggml_tensor * gcs_j = ggml_cont_4d(ctx, g_cumsum, 1, n_tokens, H_v,
+        ggml_tensor * gcs_j = ggml_cont_4d(ctx, g_cumsum, 1, n_tokens, H_v,
                                                   n_seqs);  // [1, chunk_size, n_tokens, n_seqs]
 
         // Broadcast both tensors to [chunk_size, chunk_size, H_v, n_seqs]
-        // struct ggml_tensor * gcs_i_broadcast =
+        // ggml_tensor * gcs_i_broadcast =
         //     ggml_repeat_4d(ctx, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v,
         //                     n_seqs);  // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
         // Don't need this, this one will get auto-broadcast
-        struct ggml_tensor * gcs_j_broadcast =
+        ggml_tensor * gcs_j_broadcast =
             ggml_repeat_4d(ctx, gcs_j, n_tokens, n_tokens, H_v,
                            n_seqs);  // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
 
-        struct ggml_tensor * decay_mask = ggml_sub(ctx, gcs_j_broadcast, gcs_i);
+        ggml_tensor * decay_mask = ggml_sub(ctx, gcs_j_broadcast, gcs_i);
 
         // Apply lower triangular mask to ensure attention is causal (only past tokens influence current)
         decay_mask = ggml_mul(ctx, decay_mask, causal_diag_mask);
@@ -20580,12 +20574,12 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         cb(decay_mask, "decay_mask", il);
 
         // attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
-        struct ggml_tensor * kmulkbeta = ggml_mul_mat(ctx, ggml_cont(ctx, k), ggml_cont(ctx, k_beta));
+        ggml_tensor * kmulkbeta = ggml_mul_mat(ctx, k, k_beta);
 
         cb(kmulkbeta, "kmulkbeta", il);
 
-        struct ggml_tensor * k_decay   = ggml_mul(ctx, kmulkbeta, decay_mask);
-        struct ggml_tensor * attn      = ggml_neg(ctx, ggml_mul(ctx, k_decay, causal_mask));
+        ggml_tensor * k_decay   = ggml_mul(ctx, kmulkbeta, decay_mask);
+        ggml_tensor * attn      = ggml_neg(ctx, ggml_mul(ctx, k_decay, causal_mask));
 
         cb(attn, "attn_pre_rec", il);
 
@@ -20597,29 +20591,28 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         //
         // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A)
         ggml_tensor * attn_lower = ggml_mul(ctx, attn, causal_mask);
-        struct ggml_tensor * lhs =
-            ggml_sub(ctx, ggml_repeat_4d(ctx, identity, identity->ne[0], identity->ne[1], attn_lower->ne[2], attn_lower->ne[3]), attn_lower);
+        ggml_tensor * lhs = ggml_sub(ctx, ggml_repeat(ctx, identity, attn_lower), attn_lower);
 
-        struct ggml_tensor * lin_solve = ggml_solve_tri(ctx, lhs, attn);
+        ggml_tensor * lin_solve = ggml_solve_tri(ctx, lhs, attn);
         attn = ggml_mul(ctx, lin_solve, causal_mask);
-        attn = ggml_cont(ctx, ggml_add(ctx, attn, identity));
+        attn = ggml_add(ctx, attn, identity);
 
         // value = attn @ v_beta
-        v = ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, attn, ggml_cont(ctx, ggml_transpose(ctx0, v_beta)))));
+        v = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx0, v_beta)), attn);
 
         cb(v, "value_beta", il);
 
         // k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
-        struct ggml_tensor * g_cumsum_t = ggml_cont(ctx, ggml_transpose(ctx, g_cumsum));
-        struct ggml_tensor * gexp = ggml_exp(ctx, g_cumsum_t);
+        ggml_tensor * g_cumsum_t = ggml_cont(ctx, ggml_transpose(ctx, g_cumsum));
+        ggml_tensor * gexp = ggml_exp(ctx, g_cumsum_t);
 
         cb(gexp, "g_cum_exp", il);
 
-        struct ggml_tensor * kbeta_gexp = ggml_mul(ctx, ggml_cont(ctx, k_beta), gexp);
+        ggml_tensor * kbeta_gexp = ggml_mul(ctx, k_beta, gexp);
 
         cb(kbeta_gexp, "kbeta_gexp", il);
 
-        struct ggml_tensor * k_cumdecay =
+        ggml_tensor * k_cumdecay =
             ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, attn, ggml_cont(ctx, ggml_transpose(ctx, kbeta_gexp)))));
 
         cb(k_cumdecay, "k_cumdecay", il);
@@ -20631,28 +20624,32 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
 
         cb(attn, "attn_decay_key", il);
 
+        ggml_tensor * state_t = ggml_cont(ctx, ggml_transpose(ctx, state));
+
         // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
-        struct ggml_tensor * v_prime = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, state)), k_cumdecay);
+        ggml_tensor * v_prime = ggml_mul_mat(ctx, state_t, k_cumdecay);
 
         cb(v_prime, "v_prime", il);
 
         // v_new = v_i - v_prime
-        struct ggml_tensor * v_new = ggml_sub(ctx, ggml_repeat_4d(ctx, v, v_prime->ne[0], v_prime->ne[1], v_prime->ne[2], v_prime->ne[3]), v_prime);
+        ggml_tensor * v_new = ggml_sub(ctx, ggml_repeat(ctx, v, v_prime), v_prime);
+
+        ggml_tensor * v_new_t = ggml_cont(ctx, ggml_transpose(ctx, v_new));
 
         cb(v_new, "v_new", il);
 
         // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-        struct ggml_tensor * q_g_exp = ggml_mul(ctx, q, gexp);
-        struct ggml_tensor * attn_inter = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, state)), q_g_exp);
+        ggml_tensor * q_g_exp = ggml_mul(ctx, q, gexp);
+        ggml_tensor * attn_inter = ggml_mul_mat(ctx, state_t, q_g_exp);
 
         cb(attn_inter, "attn_inter", il);
 
         // core_attn_out[:, :, i] = attn_inter + attn @ v_new
-        struct ggml_tensor * v_attn = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, v_new)), attn);
+        ggml_tensor * v_attn = ggml_mul_mat(ctx, v_new_t, attn);
 
         cb(v_attn, "v_attn", il);
 
-        struct ggml_tensor * core_attn_out = ggml_add(ctx, attn_inter, v_attn);
+        ggml_tensor * core_attn_out = ggml_add(ctx, attn_inter, v_attn);
 
         cb(core_attn_out, "core_attn_out", il);
 
@@ -20662,22 +20659,20 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
         // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
 
-        gexp = ggml_cont(ctx, gexp);
-
         ggml_tensor * g_cum_last = ggml_cont(ctx, ggml_view_4d(ctx, g_cumsum_t, g_cumsum_t->ne[0], 1, g_cumsum_t->ne[2], g_cumsum_t->ne[3], g_cumsum_t->nb[1],
                                                 g_cumsum_t->nb[2], g_cumsum_t->nb[3], g_cumsum_t->nb[0] * (g_cumsum_t->ne[1] - 1)));
 
         cb(g_cum_last, "g_cum_last", il);
 
-        ggml_tensor * gexp_last = ggml_cont_4d(ctx, ggml_exp(ctx, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);
+        ggml_tensor * gexp_last = ggml_reshape_4d(ctx, ggml_exp(ctx, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);
 
         cb(g_cum_last, "gexp_last", il);
 
-        ggml_tensor * g_cum_last_3d = ggml_cont_3d(ctx, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);
+        ggml_tensor * g_cum_last_3d = ggml_reshape_3d(ctx, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);
 
         cb(g_cum_last, "g_cum_last_3d", il);
 
-        ggml_tensor * g_cumsum_3d = ggml_cont_3d(ctx, g_cumsum, g_cumsum->ne[0], g_cumsum->ne[2], g_cumsum->ne[3]);
+        ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx, g_cumsum, g_cumsum->ne[0], g_cumsum->ne[2], g_cumsum->ne[3]);
 
         cb(g_cum_last, "g_cumsum_3d", il);
 
@@ -20689,24 +20684,22 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
 
         cb(g_cum_last, "g_diff_exp", il);
 
-        ggml_tensor * key_gdiff = ggml_mul(ctx, k, ggml_cont_4d(ctx, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], g_diff_exp->ne[2] * g_diff_exp->ne[3]));
+        ggml_tensor * key_gdiff = ggml_mul(ctx, k, ggml_reshape_4d(ctx, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], g_diff_exp->ne[2] * g_diff_exp->ne[3]));
 
         cb(g_cum_last, "key_gdiff", il);
 
-        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_cont(ctx, ggml_transpose(ctx, v_new))),
+        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx, v_new_t,
                                         ggml_cont(ctx, ggml_transpose(ctx, key_gdiff)));
 
         cb(kgdmulvnew, "kgdmulvnew", il);
 
-        struct ggml_tensor * new_state =
-            ggml_add(ctx, ggml_mul(ctx, state, ggml_cont_4d(ctx, gexp_last, 1, 1, H_v, ggml_nelements(gexp_last) / H_v)),
-                 kgdmulvnew);
+        ggml_tensor * new_state = ggml_add(ctx, ggml_mul(ctx, state, gexp_last), kgdmulvnew);
 
         cb(new_state, "new_state", il);
 
         // flatten output
-        struct ggml_tensor * flat_output = ggml_cont_1d(ctx, ggml_permute(ctx, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
-        struct ggml_tensor * flat_state = ggml_cont_1d(ctx, new_state, S_v * S_v * H_v * n_seqs);
+        ggml_tensor * flat_output = ggml_cont_1d(ctx, ggml_permute(ctx, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
+        ggml_tensor * flat_state = ggml_cont_1d(ctx, new_state, S_v * S_v * H_v * n_seqs);
 
         return ggml_concat(ctx, flat_output, flat_state, 0);
     }
@@ -20799,15 +20792,14 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         return cur;
     }
 
-
-
-    ggml_tensor * build_qwen3next_linear_attn_layer(llm_graph_input_rs * inp,
-                                                                        ggml_tensor *        cur,
-                                                                        const llama_model &  model,
-                                                                        const llama_ubatch & ubatch,
-                                                                        ggml_tensor *        causal_mask,
-                                                                        ggml_tensor *        identity,
-                                                                        int                  il) {
+    ggml_tensor * build_qwen3next_linear_attn_layer(
+            llm_graph_input_rs * inp,
+            ggml_tensor *        cur,
+            const llama_model &  model,
+            const llama_ubatch & ubatch,
+            ggml_tensor *        causal_mask,
+            ggml_tensor *        identity,
+            int                  il) {
         const auto * mctx_cur = inp->mctx;
 
         const int64_t d_inner  = hparams.ssm_d_inner;
@@ -21050,7 +21042,7 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
 
         // Reshape both attn_out_final and z to 2D tensors for normalization
         // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-        ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+        ggml_tensor * attn_out_2d_final = ggml_cont_2d(ctx0, attn_out_final, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
 
         // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
         ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
@@ -21058,11 +21050,8 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         // Apply gated normalization: self.norm(core_attn_out, z)
         ggml_tensor * attn_out_norm = build_q3n_gated_norm(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
 
-        // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
-        ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
-
         // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
-        ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+        ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
         cb(final_output, "final_output", il);
 
         // Output projection
@@ -21070,7 +21059,7 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         cb(cur, "linear_attn_out", il);
 
         // Reshape back to original dimensions
-        cur = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs));
+        cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
         return cur;
     }
 

This shaves of ~700 nodes from the graph. There are still ~8400 nodes remaining which is a bit excessive, but I think we can work with this for now and we'll probably be able to reduce them further with some proper op rearrangements. But not a showstopper.

I think we can look forward to merging this soon (see my previous comment for the plan). One thing we should pay extra attention is if we can benefit from some weight re-orientations (e.g. reshapes, transpositions) during model conversions that would facilitate later graph optimizations (e.g. avoid reshapes, perumtes) since after we settle on the weights, we won't be able to easily make changes to them later. Haven't spotted anything specific so far though.

Comment on lines +20991 to +21011

// if head keys and value keys are different, repeat to force tensors into matching shapes
if (num_k_heads != num_v_heads) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
int64_t repeat_factor = num_v_heads / num_k_heads;

// repeat interleave: reshape to (repeat part, 1, remaining part), do repeat, then reshape back
ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs);
ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs);

// Repeat along the third dimension (the new dimension with size 1)
ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1);
ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1);

// Reshape back to merge the head and repeat dimensions
// From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs]
// Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs]
q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is new - wondering if we can avoid this explicit repeat by utilizing broadcasts. Though, not super important for now.

Co-authored-by: Georgi Gerganov <[email protected]>
@ross-rosario
Copy link

This is off-topic/spam but I CAN'T WAIT FOR THIS PR TO GET MERGED!

@mattepiu
Copy link

mattepiu commented Nov 6, 2025

Meanwhile, gated deltanet are now explained in detail: https://sebastianraschka.com/llms-from-scratch/ch04/08_deltanet/

This also explains the high reduction of KV_cache ( reduced by a factor of 2*n_tokens on the deltanet, for a whopping 1/4 on Qwen-Next and Kimi-Linear !!! )

@rombodawg
Copy link

rombodawg commented Nov 8, 2025

This is off-topic/spam but I CAN'T WAIT FOR THIS PR TO GET MERGED!

I really dont get why people are disliking comments like this. Like people are exited for this to happen. Why be upset about someone elses happiness. No one is rushing the project, we just want to share our excitement for its completion

Edit: Ooops sorry 😅

@k3d3
Copy link

k3d3 commented Nov 8, 2025

There are hundreds of people subscribed to this thread. Every message like yours and theirs means everyone gets notified of an update that is not relevant to the feature at hand.

This is why the comment is getting thumbs down.

I've already made a comment here #16095 (comment) about not spamming, but it seems like the comment got lost due to how active this thread is. Regardless, I was really hoping I wouldn't have to make a second comment.

If you want to show support, add a heart react to the initial message. If you want updates or to know when it's ready, subscribe like the rest of us. But please, stop spamming this thread. Especially if you're fully aware that what you're doing is spamming, as Ross Rosario mentions in their comment, just don't.

Also, like with my last message, please don't respond to my comment - just thumbs it up or down, to reduce the spam. I don't like adding to the noise either, but when people start defending the spam, it needs to be reiterated.

@mattepiu

This comment has been minimized.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

examples ggml changes relating to the ggml tensor library for machine learning Nvidia GPU Issues specific to Nvidia GPUs python python script changes testing Everything test related

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Feature Request: Qwen3-Next support