diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index bf6bc68380b19..8f9d21143d6b1 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -5669,11 +5669,20 @@ def set_gguf_parameters(self):
         logger.info("gguf: (granite) logits_scale = %s", logits_scale)
 
 
-@ModelBase.register("GraniteMoeForCausalLM")
+@ModelBase.register("GraniteMoeForCausalLM", "GraniteMoeSharedForCausalLM")
 class GraniteMoeModel(GraniteModel):
     """Conversion for IBM's GraniteMoeForCausalLM"""
     model_arch = gguf.MODEL_ARCH.GRANITE_MOE
 
+    def set_gguf_parameters(self):
+        """GraniteMoeShared uses GraniteMoe parameters plus the following:
+        - shared_intermediate_size
+        """
+        super().set_gguf_parameters()
+        if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"):
+            self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length)
+            logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length)
+
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         """In modeling_granitemoe, the JetMoe implementation of parallel experts
         is used. This essentially merges w1 and w3 into a single tensor with 2x
@@ -5684,12 +5693,21 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         if name.endswith("block_sparse_moe.input_linear.weight"):
             ffn_dim = self.hparams["intermediate_size"]
             assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
-            gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
+            gate, up = data_torch.split(ffn_dim, dim=-2)
             return [
                 (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate),
                 (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
             ]
 
+        if name.endswith("shared_mlp.input_linear.weight"):
+            ffn_dim = self.hparams["shared_intermediate_size"]
+            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
+            gate, up = data_torch.split(ffn_dim, dim=-2)
+            return [
+                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
+                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
+            ]
+
         return super().modify_tensors(data_torch, name, bid)
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 7dd7bb6d1b5d9..7f0968bb4806e 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -1893,6 +1893,9 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_GATE_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
     ],
     MODEL_ARCH.CHAMELEON: [
         MODEL_TENSOR.TOKEN_EMBD,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 003b0172c77b0..83c5b27ae1199 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -428,6 +428,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_expert.down_proj",  # qwen2moe
             "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
             "language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
+            "model.layers.{bid}.shared_mlp.output_linear",     # granitemoe
         ),
 
         MODEL_TENSOR.ATTN_Q_NORM: (
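A note on the converter change above: GraniteMoeShared fuses the shared expert's gate and up projections into a single `shared_mlp.input_linear.weight`, stacked along dim `-2`, mirroring the routed experts' `block_sparse_moe.input_linear.weight`. A minimal PyTorch sketch (shapes illustrative, not taken from a real checkpoint) showing that the new one-call `split` matches the slicing it replaces:

```python
import torch

n_embd, ffn_dim = 64, 128

# stand-in for shared_mlp.input_linear.weight: gate stacked on top of up along dim -2
fused = torch.randn(2 * ffn_dim, n_embd)

# the patch's one-call split ...
gate, up = fused.split(ffn_dim, dim=-2)

# ... is equivalent to the slicing it replaces
assert torch.equal(gate, fused[..., :ffn_dim, :])
assert torch.equal(up,   fused[..., ffn_dim:, :])
```

Both forms produce views rather than copies; the `split` call just states the intent in one place.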
"blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 21b12339a221b..65f80da9372fe 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1385,6 +1385,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { // Add additional layer/vocab/etc checks here for other model sizes default: type = LLM_TYPE_UNKNOWN; } + + // For Granite MoE Shared + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); } break; case LLM_ARCH_CHAMELEON: { @@ -1768,6 +1771,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } } } } break; @@ -4381,10 +4391,13 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } - if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE) { + if (arch == LLM_ARCH_MINICPM || + arch == LLM_ARCH_GRANITE || + arch == LLM_ARCH_GRANITE_MOE) { LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } if (arch == LLM_ARCH_BAILINGMOE) { @@ -4594,11 +4607,6 @@ struct llm_build_llama : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); @@ -4670,11 +4678,6 @@ struct llm_build_llama : public llm_graph_context { cb(cur, "ffn_moe_out", il); } - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -4697,11 +4700,6 @@ struct llm_build_llama : public llm_graph_context { // lm_head cur = build_lora_mm(model.output, cur); - // For Granite architecture - if (hparams.f_logit_scale) { - cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); - } - cb(cur, "result_output", -1); res->t_logits = cur; @@ -4812,11 +4810,6 @@ struct llm_build_deci : public llm_graph_context { continue; } - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - // modified to support attention-free layer of Llama-3_1-Nemotron-51B ggml_tensor * ffn_inp = cur; if (n_head > 0) { @@ 
@@ -4594,11 +4607,6 @@ struct llm_build_llama : public llm_graph_context {
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
 
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
 
@@ -4670,11 +4678,6 @@ struct llm_build_llama : public llm_graph_context {
                 cb(cur, "ffn_moe_out", il);
             }
 
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
@@ -4697,11 +4700,6 @@ struct llm_build_llama : public llm_graph_context {
         // lm_head
         cur = build_lora_mm(model.output, cur);
 
-        // For Granite architecture
-        if (hparams.f_logit_scale) {
-            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
-        }
-
         cb(cur, "result_output", -1);
         res->t_logits = cur;
 
@@ -4812,11 +4810,6 @@ struct llm_build_deci : public llm_graph_context {
                 continue;
             }
 
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
             // modified to support attention-free layer of Llama-3_1-Nemotron-51B
             ggml_tensor * ffn_inp = cur;
             if (n_head > 0) {
@@ -4840,11 +4833,6 @@ struct llm_build_deci : public llm_graph_context {
                 cb(cur, "ffn_out", il);
             }
 
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
@@ -4867,11 +4855,6 @@ struct llm_build_deci : public llm_graph_context {
         // lm_head
         cur = build_lora_mm(model.output, cur);
 
-        // For Granite architecture
-        if (hparams.f_logit_scale) {
-            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
-        }
-
         cb(cur, "result_output", -1);
         res->t_logits = cur;
 
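The deletions above strip the Granite-only scaling that `llm_build_llama` and `llm_build_deci` previously applied behind runtime checks; with the dedicated `llm_build_granite` builder added below, those graphs stay clean, and the Granite graph applies the scales unconditionally, since Granite GGUFs are expected to carry explicit `f_residual_scale` and `f_logit_scale` values. The two patterns, sketched in PyTorch (function names hypothetical, values illustrative):

```python
import torch

def residual_step(residual, branch_out, f_residual_scale=1.0):
    # the attention/FFN branch output is scaled before the residual add
    return residual + f_residual_scale * branch_out

def lm_logits(hidden, w_out, f_logit_scale=1.0):
    # logits are divided by f_logit_scale after the lm_head projection
    return (hidden @ w_out.T) / f_logit_scale

h = residual_step(torch.zeros(2, 8), torch.ones(2, 8), f_residual_scale=0.22)
logits = lm_logits(h, torch.randn(16, 8), f_logit_scale=8.0)
assert logits.shape == (2, 16)
```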
@@ -12210,6 +12193,195 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
     }
 };
 
+
+struct llm_build_granite : public llm_graph_context {
+    llm_build_granite(
+        const llama_model & model,
+        const llm_graph_params & params,
+        ggml_cgraph * gf,
+        const bool use_rope = true)
+        : llm_graph_context(params) {
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - built only if rope enabled
+        ggml_tensor * inp_pos = nullptr;
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and (optionally) RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                if (use_rope) {
+
+                    if (!inp_pos) {
+                        inp_pos = build_inp_pos();
+                    }
+                    ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                    Qcur = ggml_rope_ext(
+                            ctx0, Qcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );
+
+                    Kcur = ggml_rope_ext(
+                            ctx0, Kcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );
+                }
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // For Granite architectures - scale residual
+            cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network (non-MoE)
+            if (model.layers[il].ffn_gate_inp == nullptr) {
+
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+
+            } else {
+                // MoE branch
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                ggml_tensor * moe_out = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+                cb(moe_out, "ffn_moe_out", il);
+
+                // For Granite MoE Shared
+                if (hparams.n_ff_shexp > 0) {
+                    ggml_tensor * ffn_shexp = build_ffn(cur,
+                            model.layers[il].ffn_up_shexp,   NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
+                            NULL,
+                            LLM_FFN_SILU, LLM_FFN_PAR, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                } else {
+                    cur = moe_out;
+                }
+            }
+
+            // For Granite architectures - scale residual
+            cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        // For Granite architectures - scale logits
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 // ref: https://github.com/facebookresearch/chameleon
 // based on the original build_llama() function, changes:
 // * qk-norm
@@ -12917,8 +13089,6 @@ llm_graph_result_ptr llama_model::build_graph(
         case LLM_ARCH_LLAMA:
         case LLM_ARCH_LLAMA4:
         case LLM_ARCH_MINICPM:
-        case LLM_ARCH_GRANITE:
-        case LLM_ARCH_GRANITE_MOE:
            {
                llm = std::make_unique<llm_build_llama>(*this, params, gf);
            } break;
@@ -13149,6 +13319,11 @@ llm_graph_result_ptr llama_model::build_graph(
            {
                llm = std::make_unique<llm_build_arwkv7>(*this, params, gf);
            } break;
+        case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
+            {
+                llm = std::make_unique<llm_build_granite>(*this, params, gf);
+            } break;
         case LLM_ARCH_CHAMELEON:
            {
                llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
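One behavioral detail of the new builder worth noting: attention scaling is parameterized as well. The `kq_scale` ternary at the top of `llm_build_granite` treats `f_attention_scale == 0.0f` as "unset" and falls back to the standard `1/sqrt(head_dim)`. A Python transliteration of that selection (values illustrative):

```python
import math

def kq_scale(n_embd_head: int, f_attention_scale: float = 0.0) -> float:
    # 0.0 acts as a sentinel for "unset": use standard 1/sqrt(head_dim) scaling
    return 1.0 / math.sqrt(n_embd_head) if f_attention_scale == 0.0 else f_attention_scale

print(kq_scale(128))          # ~0.0884, the llama default
print(kq_scale(128, 0.0156))  # a Granite-style override wins when set
```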