llama: add initial support for Falcon-H1 model family #14534

Open
wants to merge 74 commits into master
Commits (74)
991de6c
v1
younesbelkada Jul 3, 2025
f897efd
push more fixes
younesbelkada Jul 3, 2025
71a6848
another fix
younesbelkada Jul 3, 2025
03568c9
fix
younesbelkada Jul 3, 2025
0c93ef6
more fixes
younesbelkada Jul 3, 2025
fdd5cff
minor fix
younesbelkada Jul 3, 2025
14c37ec
more cleaning on python code
younesbelkada Jul 3, 2025
8bea922
python fixes
ibrahimkhadraoui Jul 4, 2025
071f4b7
changed precision for multipliers float 32->64
ibrahimkhadraoui Jul 4, 2025
50eadc7
fixes
younesbelkada Jul 4, 2025
a39a842
merge
younesbelkada Jul 4, 2025
1415cd8
another fix
younesbelkada Jul 4, 2025
243e4d1
fix
younesbelkada Jul 4, 2025
cce3549
pre-norm -> norm
younesbelkada Jul 4, 2025
22de62c
fix
younesbelkada Jul 4, 2025
2fe057c
Revert "fix"
ibrahimkhadraoui Jul 4, 2025
d22b4ea
Merge branch 'add-fh1-rebased' of https://github.com/tiiuae/llama.cpp…
ibrahimkhadraoui Jul 4, 2025
6c7d9e2
fix
younesbelkada Jul 4, 2025
15138df
small fix ffn_norm
ibrahimkhadraoui Jul 4, 2025
a6d0067
Merge branch 'add-fh1-rebased' of https://github.com/tiiuae/llama.cpp…
ibrahimkhadraoui Jul 4, 2025
1fd0574
try
younesbelkada Jul 4, 2025
250b4f1
mix instead of max
younesbelkada Jul 4, 2025
3ee7983
fix vocab size
ibrahimkhadraoui Jul 4, 2025
2aa48dd
Merge branch 'add-fh1-rebased' of https://github.com/tiiuae/llama.cpp…
ibrahimkhadraoui Jul 4, 2025
9760c8b
conflict solve
ibrahimkhadraoui Jul 4, 2025
7a25441
fixed multipliers
ibrahimkhadraoui Jul 4, 2025
280dd2d
falcon-h1 specefic vocab resolved
ibrahimkhadraoui Jul 7, 2025
c56ec07
read arch from gguf.MODEL_ARCH
ibrahimkhadraoui Jul 7, 2025
c4af0f3
mamba_d_ssm added to d_inner find_hparam
ibrahimkhadraoui Jul 7, 2025
53304c8
remove unused functions from gguf_writer.py
ibrahimkhadraoui Jul 7, 2025
441d8d6
override modify_tensors instead of get_tensors
ibrahimkhadraoui Jul 7, 2025
6c39e77
fix conversion and d_inner
younesbelkada Jul 7, 2025
8c50893
added some cb functions for debugging puposes
ibrahimkhadraoui Jul 7, 2025
49d7420
inp_out_ids moved outside of layers loop
ibrahimkhadraoui Jul 7, 2025
97011d7
mup_vec create as float64
ibrahimkhadraoui Jul 7, 2025
286e1fa
fix rope_theta
ibrahimkhadraoui Jul 7, 2025
b3bc1fb
Merge branch 'add-fh1-rebased' of https://github.com/tiiuae/llama.cpp…
ibrahimkhadraoui Jul 7, 2025
a9f3a63
injected mup
younesbelkada Jul 7, 2025
e96cc73
clean ups
younesbelkada Jul 7, 2025
3afb2a8
Merge pull request #1 from tiiuae/injected-mup
ibrahimkhadraoui Jul 7, 2025
0ad3502
rm extra space
ibrahimkhadraoui Jul 7, 2025
53446f7
rm unused MAMBA_CHUNK_SIZE
ibrahimkhadraoui Jul 7, 2025
ae937f4
rm unused key
ibrahimkhadraoui Jul 7, 2025
b6df0a4
add bos False
ibrahimkhadraoui Jul 7, 2025
935d46f
changed ROPE_TYPE
ibrahimkhadraoui Jul 7, 2025
624699c
cleaning debugging stuff
ibrahimkhadraoui Jul 7, 2025
042e5ff
cleaning debug quant
ibrahimkhadraoui Jul 7, 2025
f74e266
fix comment
younesbelkada Jul 7, 2025
632861e
some cleanups
younesbelkada Jul 7, 2025
084873c
some cleanups
younesbelkada Jul 7, 2025
fd20330
Update src/llama-model-loader.cpp
younesbelkada Jul 7, 2025
68cb784
more cleanups
younesbelkada Jul 7, 2025
d2f46f1
moe cleanuips
younesbelkada Jul 7, 2025
7d7da0b
d_ssm -> d_inner;
younesbelkada Jul 8, 2025
67b2664
cleaning unused hparams
ibrahimkhadraoui Jul 8, 2025
da8a338
Merge branch 'add-fh1-rebased' of https://github.com/tiiuae/llama.cpp…
ibrahimkhadraoui Jul 8, 2025
e63ee46
cleanup
ibrahimkhadraoui Jul 8, 2025
d473d42
more cleanups
younesbelkada Jul 8, 2025
8555ee8
more cleanups on python conversion;
younesbelkada Jul 8, 2025
7846c67
minor cleanups
ibrahimkhadraoui Jul 8, 2025
2dee7cf
Apply suggestions from code review
younesbelkada Jul 8, 2025
a846d02
remove todo
younesbelkada Jul 8, 2025
f028a43
Merge branch 'add-fh1-rebased' of https://github.com/tiiuae/llama.cpp…
ibrahimkhadraoui Jul 8, 2025
d41f111
Merge branch 'add-fh1-rebased' of https://github.com/tiiuae/llama.cpp…
ibrahimkhadraoui Jul 8, 2025
f266d14
added falcon-h1
ibrahimkhadraoui Jul 8, 2025
4bc9e0c
tensor not required
younesbelkada Jul 8, 2025
2834a4a
clean
ibrahimkhadraoui Jul 8, 2025
823696b
remove unneeded attributes
younesbelkada Jul 8, 2025
adff470
more cleanups and fixed conversion
younesbelkada Jul 8, 2025
097df0e
remove final_norm
younesbelkada Jul 8, 2025
9a048d8
flake8 fixes
ibrahimkhadraoui Jul 8, 2025
52d1ef3
Merge branch 'add-fh1-rebased' of https://github.com/tiiuae/llama.cpp…
ibrahimkhadraoui Jul 8, 2025
58e3866
Update src/llama-model.cpp
younesbelkada Jul 8, 2025
d28c31a
Merge branch 'master' into add-fh1-rebased
younesbelkada Jul 8, 2025
140 changes: 136 additions & 4 deletions convert_hf_to_gguf.py
@@ -4899,7 +4899,7 @@ def set_vocab(self):
def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
head_dim = self.find_hparam(["head_dim"], optional=True) or 64
n_group = self.find_hparam(["n_groups"], optional=True) or 1
@@ -4908,8 +4908,10 @@ def set_gguf_parameters(self):

# Fail early for models which don't have a block expansion factor of 2
# TODO: does this really matter?
assert d_inner == 2 * d_model
assert d_inner % head_dim == 0
# skip the assertion for FalconH1 Model
if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
assert d_inner == 2 * d_model
assert d_inner % head_dim == 0

self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
@@ -4946,7 +4948,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
data_torch = data_torch.reshape((*data_torch.shape, 1))
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
n_group = self.hparams.get("n_groups", 1)
data_torch = data_torch.reshape((n_group, d_inner // n_group))

@@ -6539,6 +6541,135 @@ def set_gguf_parameters(self):
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])


@ModelBase.register("FalconH1ForCausalLM")
class FalconH1Model(Mamba2Model):
model_arch = gguf.MODEL_ARCH.FALCON_H1

def __init__(self, *args, **kwargs):
# Set the hparam prefixes for Falcon Mamba2
self.hparam_prefixes = ["mamba"]

# Initialize the base Mamba2Model
super().__init__(*args, **kwargs)

# Use Llama conversion for attention
self._transformer_model_class = LlamaModel

# n_group and d_inner are used during reshape_tensors for mamaba2
self.d_model = self.find_hparam(["hidden_size", "d_model"])
self.n_group = self.find_hparam(["n_groups"])
self.d_inner = self.find_hparam(["expand"]) * self.d_model

# Initialize any Falcon Mamba2 specific attributes
self.has_attention = True # Falcon Mamba2 has attention components

# Load Falcon-H1 multipliers from hyperparameters
self.attention_in_multiplier = self.find_hparam(["attention_in_multiplier"], optional=True)
self.attention_out_multiplier = self.find_hparam(["attention_out_multiplier"], optional=True)
self.ssm_in_multiplier = self.find_hparam(["ssm_in_multiplier"], optional=True)
self.ssm_out_multiplier = self.find_hparam(["ssm_out_multiplier"], optional=True)
self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True)
self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True)
self.intermediate_size = self.find_hparam(["intermediate_size"])
self.key_multiplier = self.find_hparam(["key_multiplier"], optional=True)

def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
prefixed = []
for pfx in self.hparam_prefixes:
prefixed.extend(
"_".join([pfx, k])
for k in keys
)
keys = list(keys) + prefixed
return super().find_hparam(keys, *args, **kwargs)

def set_vocab(self):
self._set_vocab_gpt2()

def _generate_mup_vector(self, block_id: int) -> torch.Tensor:
zxbcdt_multipliers = self.hparams["ssm_multipliers"]
intermediate_size = self.hparams["mamba_d_ssm"]
groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
vector_shape = (2 * intermediate_size + 2 * groups_time_state_size + self.hparams["mamba_n_heads"])

mup_vector = torch.ones(1, 1, vector_shape, dtype=torch.float64)
mup_vector[:, :, :intermediate_size] *= zxbcdt_multipliers[0]
mup_vector[:, :, intermediate_size:2 * intermediate_size] *= zxbcdt_multipliers[1]
mup_vector[:, :, 2 * intermediate_size:2 * intermediate_size + groups_time_state_size] *= zxbcdt_multipliers[2]
mup_vector[:, :, 2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size] *= zxbcdt_multipliers[3]
mup_vector[:, :, 2 * intermediate_size + 2 * groups_time_state_size:] *= zxbcdt_multipliers[4]

return mup_vector

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
tensors = list(super().modify_tensors(data_torch, name, bid))
tensor = tensors[0][1]

if "down_proj" in name:
tensor = tensor * self.mlp_multipliers[1]
elif "gate_proj" in name:
tensor = tensor * self.mlp_multipliers[0]
elif "k_proj" in name:
tensor = tensor * self.key_multiplier * self.attention_in_multiplier
elif "q_proj" in name:
tensor = tensor * self.attention_in_multiplier
elif "v_proj" in name:
tensor = tensor * self.attention_in_multiplier
elif "o_proj" in name:
tensor = tensor * self.attention_out_multiplier
elif "out_proj" in name:
tensor = tensor * self.ssm_out_multiplier
elif "in_proj" in name:
tensor = tensor * self.ssm_in_multiplier
zxbcdt_multipliers = self.hparams["ssm_multipliers"]
intermediate_size = self.hparams["mamba_d_ssm"]
groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
tensor[:intermediate_size, :] *= zxbcdt_multipliers[0]
tensor[intermediate_size:2 * intermediate_size, :] *= zxbcdt_multipliers[1]
tensor[2 * intermediate_size:2 * intermediate_size + groups_time_state_size, :] *= zxbcdt_multipliers[2]
tensor[2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size, :] *= zxbcdt_multipliers[3]
tensor[2 * intermediate_size + 2 * groups_time_state_size:, :] *= zxbcdt_multipliers[4]
elif "lm_head" in name:
tensor = tensor * self.hparams["lm_head_multiplier"]
elif "embed_tokens" in name:
tensor = tensor * self.hparams["embedding_multiplier"]

tensors = [(tensors[0][0], tensor)]
return tensors

def set_gguf_parameters(self):
super().set_gguf_parameters()

## General Params ##
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
self.gguf_writer.add_add_bos_token(False)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])

## Mamba mixer params ##
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
self.gguf_writer.add_ssm_group_count(self.n_group)
self.gguf_writer.add_ssm_inner_size(self.find_hparam(["mamba_d_ssm"]))
self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"]))
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))

## Attention params ##
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in self.hparams else self.hparams["num_attention_heads"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_key_length(self.hparams["head_dim"])
self.gguf_writer.add_value_length(self.hparams["head_dim"])

## Validation ##
assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads"], optional=True) or self.find_hparam(["num_attention_heads"]))

# Add any other Falcon Mamba2 specific configuration
self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))


@ModelBase.register("HunYuanMoEV1ForCausalLM")
class HunYuanMoEModel(TextModel):
model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
@@ -6687,6 +6818,7 @@ def prepare_tensors(self):
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


###### CONVERSION LOGIC ######


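(Editor's note, not part of the diff.) At conversion time the µP multipliers are folded directly into the weights rather than applied at runtime. For in_proj, the five ssm_multipliers scale the concatenated segments of the projection output (in the Mamba-2 layout these would correspond to z, x, B, C and dt), the same split that _generate_mup_vector builds. A minimal standalone sketch with made-up sizes; all values below are hypothetical, not Falcon-H1's real configuration:

import torch

# Hypothetical hyperparameters, for illustration only
d_ssm = 8                        # stands in for mamba_d_ssm
n_groups, d_state = 1, 4         # stand in for mamba_n_groups, mamba_d_state
n_heads = 2                      # stands in for mamba_n_heads
gts = n_groups * d_state         # groups_time_state_size

multipliers = [1.0, 0.5, 2.0, 2.0, 0.25]   # stand-in ssm_multipliers
vec = torch.ones(1, 1, 2 * d_ssm + 2 * gts + n_heads, dtype=torch.float64)
vec[:, :, :d_ssm] *= multipliers[0]                                   # z segment
vec[:, :, d_ssm:2 * d_ssm] *= multipliers[1]                          # x segment
vec[:, :, 2 * d_ssm:2 * d_ssm + gts] *= multipliers[2]                # B segment
vec[:, :, 2 * d_ssm + gts:2 * d_ssm + 2 * gts] *= multipliers[3]      # C segment
vec[:, :, 2 * d_ssm + 2 * gts:] *= multipliers[4]                     # dt segment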
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
@@ -86,6 +86,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "falcon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", },
{"name": "bert-bge", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", },
{"name": "falcon3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon3-7B-Base", },
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", },
{"name": "bert-bge-large", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/BAAI/bge-large-zh-v1.5", },
{"name": "mpt", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", },
{"name": "starcoder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", },
37 changes: 37 additions & 0 deletions gguf-py/gguf/constants.py
@@ -172,6 +172,7 @@ class SSM:
TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
GROUP_COUNT = "{arch}.ssm.group_count"
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
HEAD_DIM = "{arch}.ssm.head_dim"
Collaborator (inline review comment):
The head dimension in Mamba-2 is also the time step rank.

I guess it could be clearer to use a more appropriate name like this, though.

I'm not against it; this is only to let you know.
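(Editor's aside, not part of the review thread.) For reference, this PR's converter writes both values: n_heads goes into the existing time-step-rank key via add_ssm_time_step_rank, and the head dimension into the new key via add_ssm_head_dim, with only divisibility between them being checked. A sketch with stand-in numbers:

# Stand-in values for illustration only, not Falcon-H1's real configuration
d_inner = 4096                   # SSM inner size
n_heads = 64                     # written via add_ssm_time_step_rank(...)
d_head = 64                      # written via the new add_ssm_head_dim(...)
assert d_inner % d_head == 0     # mirrors the converter's validation check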


class WKV:
HEAD_SIZE = "{arch}.wkv.head_size"
@@ -288,6 +289,7 @@ class MODEL_ARCH(IntEnum):
LLAMA4 = auto()
DECI = auto()
FALCON = auto()
FALCON_H1 = auto()
BAICHUAN = auto()
GROK = auto()
GPT2 = auto()
@@ -661,6 +663,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.ERNIE4_5: "ernie4_5",
MODEL_ARCH.FALCON_H1: "falcon_h1",
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
}

@@ -2213,6 +2216,40 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.FALCON_H1: [
# Token embedding
MODEL_TENSOR.TOKEN_EMBD,

# Input layernorm
MODEL_TENSOR.ATTN_NORM,

# Attention components
MODEL_TENSOR.ATTN_Q, # Query projection
MODEL_TENSOR.ATTN_K, # Key projection
MODEL_TENSOR.ATTN_V, # Value projection
MODEL_TENSOR.ATTN_OUT, # Output projection

# SSM components (Mamba2 specific)
MODEL_TENSOR.SSM_IN, # Input projection for SSM
MODEL_TENSOR.SSM_CONV1D, # Convolution layer
MODEL_TENSOR.SSM_DT, # Delta time projection
MODEL_TENSOR.SSM_A, # A parameter (log form)
MODEL_TENSOR.SSM_D, # D parameter
MODEL_TENSOR.SSM_NORM, # Normalization in SSM
MODEL_TENSOR.SSM_OUT, # Output projection

# Pre-feedforward layernorm
MODEL_TENSOR.FFN_PRE_NORM,

# Feed-forward network components
MODEL_TENSOR.FFN_GATE, # Gate projection (SwiGLU)
MODEL_TENSOR.FFN_DOWN, # Down projection
MODEL_TENSOR.FFN_UP, # Up projection

# Post-feedforward layernorm
MODEL_TENSOR.OUTPUT_NORM, # Final layer norm
MODEL_TENSOR.OUTPUT, # Output projection (lm_head)
],
MODEL_ARCH.HUNYUAN_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -867,6 +867,9 @@ def add_ssm_group_count(self, value: int) -> None:
def add_ssm_dt_b_c_rms(self, value: bool) -> None:
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)

def add_ssm_head_dim(self, value: int) -> None:
self.add_uint32(Keys.SSM.HEAD_DIM.format(arch=self.arch), value)

def add_tokenizer_model(self, model: str) -> None:
self.add_string(Keys.Tokenizer.MODEL, model)

10 changes: 10 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -286,12 +286,14 @@ class TensorNameMap:
# Post feed-forward norm
MODEL_TENSOR.FFN_PRE_NORM: (
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
"model.layers.{bid}.pre_ff_layernorm.weight",
),

# Post feed-forward norm
MODEL_TENSOR.FFN_POST_NORM: (
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
"model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
"model.layers.{bid}.feed_forward.up_proj",
),

MODEL_TENSOR.FFN_GATE_INP: (
@@ -363,6 +365,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
"model.layers.{bid}.feed_forward.down_proj",
"model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
),

@@ -553,11 +556,13 @@ class TensorNameMap:
MODEL_TENSOR.SSM_IN: (
"model.layers.{bid}.in_proj",
"backbone.layers.{bid}.mixer.in_proj",
"model.layers.{bid}.mamba.in_proj",
),

MODEL_TENSOR.SSM_CONV1D: (
"model.layers.{bid}.conv1d",
"backbone.layers.{bid}.mixer.conv1d",
"model.layers.{bid}.mamba.conv1d",
),

MODEL_TENSOR.SSM_X: (
@@ -568,25 +573,30 @@ class TensorNameMap:
MODEL_TENSOR.SSM_DT: (
"model.layers.{bid}.dt_proj",
"backbone.layers.{bid}.mixer.dt_proj",
"model.layers.{bid}.mamba.dt_proj",
),

MODEL_TENSOR.SSM_A: (
"model.layers.{bid}.A_log",
"backbone.layers.{bid}.mixer.A_log",
"model.layers.{bid}.mamba.A_log",
),

MODEL_TENSOR.SSM_D: (
"model.layers.{bid}.D",
"backbone.layers.{bid}.mixer.D",
"model.layers.{bid}.mamba.D",
),

MODEL_TENSOR.SSM_NORM: (
"model.layers.{bid}.mamba.norm", # falcon-h1
"backbone.layers.{bid}.mixer.norm", # mamba2
),

MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj",
"backbone.layers.{bid}.mixer.out_proj",
"model.layers.{bid}.mamba.out_proj", # falcon-h1
),

MODEL_TENSOR.TIME_MIX_W0: (
32 changes: 29 additions & 3 deletions src/llama-arch.cpp
@@ -46,6 +46,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_MAMBA2, "mamba2" },
{ LLM_ARCH_FALCON_H1, "falcon_h1" },
{ LLM_ARCH_XVERSE, "xverse" },
{ LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_COHERE2, "cohere2" },
@@ -128,7 +129,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
{ LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
{ LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },

{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
{ LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
@@ -1023,6 +1024,30 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{
LLM_ARCH_FALCON_H1,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_XVERSE,
{
@@ -1949,9 +1974,10 @@ bool llm_arch_is_recurrent(const llm_arch & arch) {
}

bool llm_arch_is_hybrid(const llm_arch & arch) {
// TODO: There are currently no hybrid models! Once there are, this will be
// the place to identify them
// List all mamba-attention hybrid models here
switch (arch) {
case LLM_ARCH_FALCON_H1:
return true;
default:
return false;
}
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -50,6 +50,7 @@ enum llm_arch {
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_MAMBA2,
LLM_ARCH_FALCON_H1,
LLM_ARCH_XVERSE,
LLM_ARCH_COMMAND_R,
LLM_ARCH_COHERE2,