diff --git a/common/arg.cpp b/common/arg.cpp index b6bfe6f89bead..e0f4a15cc7784 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1570,7 +1570,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.image.emplace_back(value); } - ).set_examples({LLAMA_EXAMPLE_LLAVA})); + ).set_examples({LLAMA_EXAMPLE_LLAVA, LLAMA_EXAMPLE_VISION})); if (llama_supports_rpc()) { add_opt(common_arg( {"--rpc"}, "SERVERS", diff --git a/common/common.h b/common/common.h index 1c0f199774976..88becc7f3181b 100644 --- a/common/common.h +++ b/common/common.h @@ -80,6 +80,7 @@ enum llama_example { LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_PARALLEL, LLAMA_EXAMPLE_TTS, + LLAMA_EXAMPLE_VISION, LLAMA_EXAMPLE_COUNT, }; diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index b5d95bd5639f3..a26df6d3eafe5 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast from itertools import chain +from transformers import AutoConfig import math import numpy as np import torch @@ -66,6 +67,13 @@ class Model: metadata_override: Path | None dir_model_card: Path + # for vision model + vision_arch: gguf.MODEL_ARCH | None = None + preprocessor_config: dict[str, Any] | None = None + vparams: dict[str, Any] | None = None + v_tensor_map: gguf.TensorNameMap | None = None + v_tensor_names: set[str] | None + # subclasses should define this! model_arch: gguf.MODEL_ARCH @@ -126,6 +134,16 @@ def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any: return None raise KeyError(f"could not find any of: {keys}") + def find_vparams(self, keys: Iterable[str], optional: bool = False) -> Any: + if self.vparams is None: + raise ValueError("vision model parameters not set") + key = next((k for k in keys if k in self.vparams), None) + if key is not None: + return self.vparams[key] + if optional: + return None + raise KeyError(f"(vision) could not find any of: {keys}") + def set_vocab(self): self._set_vocab_gpt2() @@ -186,9 +204,10 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]: f"Missing tensors: {missing}\n" f"Extra tensors: {extra}") - def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str: - if key not in gguf.MODEL_TENSORS[self.model_arch]: - raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}") + def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight", is_vision = False) -> str: + arch = self.vision_arch if is_vision and self.vision_arch is not None else self.model_arch + if key not in gguf.MODEL_TENSORS[arch]: + raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {arch!r}") name: str = gguf.TENSOR_NAMES[key] if "{bid}" in name: assert bid is not None @@ -210,9 +229,13 @@ def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str: new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes) - if new_name is None: + new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes) if self.v_tensor_map is not None else None + if new_name is not None: + return new_name + elif new_name_vision is not None: + return new_name_vision + else: raise ValueError(f"Can not map tensor {name!r}") - return new_name def 
set_gguf_parameters(self): self.gguf_writer.add_block_count(self.block_count) @@ -257,6 +280,23 @@ def set_gguf_parameters(self): self.gguf_writer.add_key_length(head_dim) self.gguf_writer.add_value_length(head_dim) + # Vision model parameters + if self.vparams is not None and self.preprocessor_config is not None and self.vision_arch is not None: + self.gguf_writer.add_vision_type("vit") + self.gguf_writer.add_vision_image_size(self.vparams["image_size"]) + self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"]) + self.gguf_writer.add_vision_vit_architecture(gguf.MODEL_ARCH_NAMES[self.vision_arch]) + self.gguf_writer.add_vision_vit_block_count(self.vparams["num_hidden_layers"]) + self.gguf_writer.add_vision_vit_embedding_length(self.vparams["hidden_size"]) + self.gguf_writer.add_vision_vit_feed_forward_length(self.vparams["intermediate_size"]) + self.gguf_writer.add_vision_vit_head_count(self.vparams["num_attention_heads"]) + self.gguf_writer.add_vision_vit_image_mean(self.preprocessor_config["image_mean"]) + self.gguf_writer.add_vision_vit_image_std(self.preprocessor_config["image_std"]) + try: + self.gguf_writer.add_vision_vit_select_layer(self.find_hparam(["vision_feature_layer", "mm_vision_select_layer"])) + except KeyError: + self.gguf_writer.add_vision_vit_select_layer(0) + self.gguf_writer.add_file_type(self.ftype) logger.info(f"gguf: file type = {self.ftype}") @@ -466,7 +506,20 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str] @staticmethod def load_hparams(dir_model: Path): with open(dir_model / "config.json", "r", encoding="utf-8") as f: - return json.load(f) + hparams = json.load(f) + if "text_config" in hparams: + hparams = {**hparams["text_config"], **hparams} + return hparams + + @staticmethod + def load_preprocessor_config(dir_model: Path): + # TODO: this varies vastly among models, need to handle more cases in the future + file_path = dir_model / "preprocessor_config.json" + if os.path.exists(file_path): + with open(file_path, "r", encoding="utf-8") as f: + return json.load(f) + else: + raise Exception(f"Preprocessor config not found at {file_path}") @classmethod def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: @@ -519,7 +572,9 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: toktypes: list[int] = [] from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + # DEBIAN_FRONTEND=noninteractive means that the script is running in a non-interactive environment (i.e. CI), so we cannot answer Y/N when it asks for user input + is_cli_non_interactive = os.environ.get("DEBIAN_FRONTEND", "") == "noninteractive" + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=is_cli_non_interactive) vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) assert max(tokenizer.vocab.values()) < vocab_size @@ -954,6 +1009,29 @@ def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0]) +# TODO: maybe merge this with Model in the future +class VisionModelHelper: + model: Model + tok_embd_tensor: Tensor | None = None + + def __init__(self, model: Model): + self.model = model + # TODO: how to do this without reading the whole safetensor file? 
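Aside: the `text_config` flattening in `load_hparams` above is easy to misread; a minimal runnable sketch of its merge semantics, using made-up values rather than any real config:

```python
# Sketch of the load_hparams() merge above, with invented values.
# Keys from "text_config" are promoted to the top level; existing
# top-level keys win because **hparams is spread last.
hparams = {"model_type": "llava", "text_config": {"hidden_size": 4096, "model_type": "llama"}}
if "text_config" in hparams:
    hparams = {**hparams["text_config"], **hparams}
assert hparams["hidden_size"] == 4096    # promoted from text_config
assert hparams["model_type"] == "llava"  # top-level value preserved
```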
+ for tname, tensor in model.get_tensors(): + if tname.endswith("embed_tokens.weight"): + self.tok_embd_tensor = tensor + + def get_embd_for_tokens(self, map_token_to_tensor_name: Iterable[tuple[str, gguf.MODEL_TENSOR]], tensor_name_postfix = '.weight') -> Iterable[tuple[str, Tensor]]: + if self.tok_embd_tensor is None: + raise ValueError("Token embedding tensor not found") + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.model.dir_model, trust_remote_code=True) + for token, tensor_name in map_token_to_tensor_name: + tok_id = tokenizer.get_vocab()[token] + row = self.tok_embd_tensor[tok_id] + yield gguf.TENSOR_NAMES[tensor_name] + tensor_name_postfix, row + + @Model.register("GPTNeoXForCausalLM") class GPTNeoXModel(Model): model_arch = gguf.MODEL_ARCH.GPTNEOX @@ -1566,10 +1644,39 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed norms: {norms}") -@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM") +@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "MobileLlamaForCausalLM", "Idefics3ForConditionalGeneration") class LlamaModel(Model): model_arch = gguf.MODEL_ARCH.LLAMA + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + model_type = self.hparams.get("model_type") + self.vision_arch = None + + # only tested with https://huggingface.co/llava-hf/llava-1.5-7b-hf + if "vision_config" in self.hparams and model_type == "llava": + self.vparams = self.hparams["vision_config"] + self.preprocessor_config = self.load_preprocessor_config(self.dir_model) + self.vision_arch = gguf.MODEL_ARCH.VISION_LLAVA + + # only tested with https://huggingface.co/mtgv/MobileVLM_V2-1.7B + if "mm_vision_tower" in self.hparams and model_type == "mobilevlm": + from transformers import AutoImageProcessor + vision_model_id = self.hparams["mm_vision_tower"] + self.vparams = AutoConfig.from_pretrained(vision_model_id).to_dict()["vision_config"] + self.preprocessor_config = AutoImageProcessor.from_pretrained(vision_model_id).to_dict() + self.vision_arch = gguf.MODEL_ARCH.VISION_MOBILEVLM + + # only tested with https://huggingface.co/HuggingFaceTB/SmolVLM-500M-Instruct + if "vision_config" in self.hparams and model_type == "idefics3": + self.vparams = self.hparams["vision_config"] + self.preprocessor_config = self.load_preprocessor_config(self.dir_model) + self.vision_arch = gguf.MODEL_ARCH.VISION_IDEFICS3 + + if self.vparams is not None and self.vision_arch is not None: + self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"]) + def set_vocab(self): try: self._set_vocab_sentencepiece() @@ -1619,6 +1726,24 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + # For vision model + if self.vparams is not None: + max_pos_embd = -1 + self.gguf_writer.add_vision_vit_patch_merge_type(gguf.CLIPPatchMergeType.FLAT) + # TODO: should not hardcode these, but they are currently missing from config.json + if self.vision_arch == gguf.MODEL_ARCH.VISION_LLAVA: + self.gguf_writer.add_vision_vit_projector_type(gguf.constants.CLIPProjectorType.MLP) + max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1 + if self.vision_arch == gguf.MODEL_ARCH.VISION_MOBILEVLM: + self.gguf_writer.add_vision_vit_projector_type(gguf.constants.CLIPProjectorType.LDPV2) + max_pos_embd = 
(self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1 + if self.vision_arch == gguf.MODEL_ARCH.VISION_IDEFICS3: + self.gguf_writer.add_vision_vit_projector_type(gguf.constants.CLIPProjectorType.MLP) + self.gguf_writer.add_vision_vit_scale_factor(self.hparams["scale_factor"]) + max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + self.gguf_writer.add_vision_vit_layer_norm_epsilon(1e-05) + self.gguf_writer.add_vision_vit_max_position_embeddings(max_pos_embd) + @staticmethod def permute(weights: Tensor, n_head: int, n_head_kv: int | None): if n_head_kv is not None and n_head != n_head_kv: @@ -1632,11 +1757,23 @@ def permute(weights: Tensor, n_head: int, n_head_kv: int | None): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") - - if name.endswith(("q_proj.weight", "q_proj.bias")): - data_torch = LlamaModel.permute(data_torch, n_head, n_head) - if name.endswith(("k_proj.weight", "k_proj.bias")): - data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + is_vision_tensor = "vision_tower" in name or "vision_model" in name + + if is_vision_tensor: + name = name.replace("model.vision_tower.", "") + if "post_layernorm" in name and self.vision_arch != gguf.MODEL_ARCH.VISION_IDEFICS3: + return [] # skip post_layernorm + + if not is_vision_tensor: + if name.startswith("model.text_model"): + name = name.replace("text_model.", "") # for SmolVLM + elif name.startswith("language_model"): + # language model tensors, remove the prefix + name = name.replace("language_model.", "") + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) # process the experts separately if name.find("block_sparse_moe.experts") != -1: @@ -1713,6 +1850,22 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("LlavaForConditionalGeneration") +class LlavaModel(LlamaModel): + model_arch = gguf.MODEL_ARCH.LLAMA + + def __init__(self, *args, **kwargs): + # quick fix for llava model + # see: https://huggingface.co/llava-hf/llava-1.5-7b-hf/discussions/34 + hparams = Model.load_hparams(kwargs["dir_model"]) + if "vision_config" in hparams and hparams.get("model_type") == "llava": + text_config = hparams["text_config"] + text_config = AutoConfig.from_pretrained(text_config["_name_or_path"]).to_dict() + kwargs["hparams"] = {**text_config, **hparams} + + super().__init__(*args, **kwargs) + + @Model.register("DeciLMForCausalLM") class DeciModel(Model): model_arch = gguf.MODEL_ARCH.DECI @@ -2240,6 +2393,173 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]: yield name, data +@Model.register("MiniCPMV") +class MiniCPMVModel(Qwen2Model): + # MiniCPM-V 2.5 is Qwen2 and 2.6 is Qwen-2.5 + model_arch = gguf.MODEL_ARCH.QWEN2 + proj_type: gguf.constants.CLIPProjectorType | None + resampler_n_embd = 0 + vhelper: VisionModelHelper | None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + model_type = self.hparams.get("model_type", None) + + # only tested with https://huggingface.co/openbmb/MiniCPM-V-2_6 + if "vision_config" in self.hparams and model_type == "minicpmv": + self.vparams = self.hparams["vision_config"] + self.preprocessor_config = self.load_preprocessor_config(self.dir_model) + 
self.vision_arch = gguf.MODEL_ARCH.VISION_MINICPMV + version = str(self.hparams.get("version", "unknown")) + if version == "2.5": + self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_5 + elif version == "2.6": + self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_6 + else: + raise ValueError(f"Unsupported MiniCPM-V version: {version}") + self.vhelper = VisionModelHelper(self) + # TODO: how to do this without reading the whole safetensor file? + for tname, tensor in self.get_tensors(): + if tname == "resampler.ln_post.bias": + self.resampler_n_embd = tensor.shape[0] + if self.resampler_n_embd < 2: + raise ValueError("Failed to detect resampler embedding size") + else: + raise ValueError("Expected vision_config, but not found") + + assert self.vparams is not None + assert self.vision_arch is not None + assert self.preprocessor_config is not None + self.preprocessor_config["image_mean"] = [0.5, 0.5, 0.5] + self.preprocessor_config["image_std"] = [0.5, 0.5, 0.5] + self.hparams["vision_feature_layer"] = 0 + self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"]) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + assert self.vparams is not None and self.proj_type is not None + self.gguf_writer.add_vision_vit_patch_merge_type(gguf.CLIPPatchMergeType.FLAT) + self.gguf_writer.add_vision_vit_projector_type(self.proj_type) + self.gguf_writer.add_vision_vit_layer_norm_epsilon(1e-06) + max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + self.gguf_writer.add_vision_vit_max_position_embeddings(max_pos_embd) + + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + # because the model operates exclusively on 70x70 patches for now, we should precompute the positional embeddings to gain performance + # in the future, we can do it in cpp if we figure out how to do it efficiently + yield ( + self.format_tensor_name(gguf.MODEL_TENSOR.V_RESMPL_POS_EMBD_K, is_vision=True), + torch.from_numpy(self._get_2d_sincos_pos_embed(self.resampler_n_embd, (70, 70))) + ) + assert self.vhelper is not None + added_tokens = [ + ("<image>", gguf.MODEL_TENSOR.V_TOK_EMBD_IMAGE), + ("</image>", gguf.MODEL_TENSOR.V_TOK_EMBD_END_IMAGE), + ("<slice>", gguf.MODEL_TENSOR.V_TOK_EMBD_SLICE), + ("</slice>", gguf.MODEL_TENSOR.V_TOK_EMBD_END_SLICE), + ] + for tensor_name, tensor in self.vhelper.get_embd_for_tokens(added_tokens): + yield tensor_name, tensor + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + # for language part + if name.startswith("llm."): + return [(self.map_tensor_name(name.replace("llm.", "")), data_torch)] + + # split the resampler.attn.in_proj_(weight|bias) tensors into q, k, v + if name.endswith("in_proj_weight") or name.endswith("in_proj_bias"): + assert data_torch.shape[0] == 3 * self.resampler_n_embd + split_tensor = data_torch.chunk(3, dim=0) + name_q = name.replace("in_proj_", "in_proj_q.") # in_proj_q.(weight|bias) + name_k = name.replace("in_proj_", "in_proj_k.") # in_proj_k.(weight|bias) + name_v = name.replace("in_proj_", "in_proj_v.") # in_proj_v.(weight|bias) + return [ + # TODO: permute these + (self.map_tensor_name(name_q), split_tensor[0]), + (self.map_tensor_name(name_k), split_tensor[1]), + (self.map_tensor_name(name_v), split_tensor[2]), + ] + + # append .weight to these tensors + if name == "resampler.proj" or name == "resampler.query": + name += ".weight" + + if name.startswith("resampler.proj"): + data_torch = 
data_torch.transpose(-1, -2).contiguous() + + if "post_layernorm" in name: + return [] # skip post_layernorm + + return [(self.map_tensor_name(name), data_torch)] + + def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: + del name, bid # unused + if "v.resmpl.query" in new_name or "v.resmpl.pos_embd_k" in new_name: + return gguf.GGMLQuantizationType.F32 + if "v.resmpl." in new_name: + return gguf.GGMLQuantizationType.F32 if n_dims == 1 else gguf.GGMLQuantizationType.F16 + return False + + # utils to work with MiniCPM-V resampler + + # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 + def _get_2d_sincos_pos_embed(self, embed_dim: int, grid_size: tuple[int, int] | int, cls_token=False) -> np.ndarray: + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_h_size, grid_w_size = grid_size, grid_size + else: + grid_h_size, grid_w_size = grid_size[0], grid_size[1] + + grid_h = np.arange(grid_h_size, dtype=np.float32) + grid_w = np.arange(grid_w_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) + pos_embed = self._get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + def _get_2d_sincos_pos_embed_from_grid(self, embed_dim: int, grid: np.ndarray) -> np.ndarray: + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + def _get_1d_sincos_pos_embed_from_grid(self, embed_dim: int, pos: np.ndarray) -> np.ndarray: + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. 
/ 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + @Model.register("WavTokenizerDec") class WavTokenizerDecModel(Model): model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC @@ -5034,7 +5354,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - description="Convert a huggingface model to a GGML compatible file") + description="Convert a huggingface model to a GGML compatible file\n\nNote: When converting vision models, this script may use an internet connection to download configuration files via Hugging Face.") parser.add_argument( "--vocab-only", action="store_true", help="extract only the vocab", diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 66cfab2c3b796..41d968ed64531 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -53,6 +53,7 @@ else() add_subdirectory(tokenize) add_subdirectory(tts) add_subdirectory(gen-docs) + add_subdirectory(vision) if (NOT GGML_BACKEND_DL) # these examples use the backends directly and cannot be built with dynamic loading add_subdirectory(convert-llama2c-to-ggml) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 71e053b202cd2..d5cbbf2ed474c 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3156,6 +3156,7 @@ struct server_context { batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, + nullptr, }; const int ret = llama_decode(ctx, batch_view); diff --git a/examples/vision/CMakeLists.txt b/examples/vision/CMakeLists.txt new file mode 100644 index 0000000000000..ab009157a957f --- /dev/null +++ b/examples/vision/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-vision) +add_executable(${TARGET} vision.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/vision/README.md b/examples/vision/README.md new file mode 100644 index 0000000000000..c2468444caa89 --- /dev/null +++ b/examples/vision/README.md @@ -0,0 +1,3 @@ +# llama.cpp/example/vision + +Minimal demo for the vision API diff --git a/examples/vision/vision.cpp b/examples/vision/vision.cpp new file mode 100644 index 0000000000000..359a023ae86e3 --- /dev/null +++ b/examples/vision/vision.cpp @@ -0,0 +1,224 @@ +#include "llama.h" +#include "common.h" +#include "arg.h" +#include "log.h" +#include "sampling.h" +#include <cstdio> +#include <cstring> +#include <string> +#include <vector> +#include <fstream> + +#define STB_IMAGE_IMPLEMENTATION +#include "stb_image.h" + +static void print_usage(int, char ** argv) { + printf("\nexample usage:\n"); + printf("\n %s -m model.gguf [-n n_predict] [-ngl n_gpu_layers] [--image img_path] [-p prompt]\n", argv[0]); + printf("\n"); +} + +static llama_vision_bitmap * load_image_from_file(const char * fname) { + std::ifstream file(fname, std::ios::binary); + if (!file) { + throw std::runtime_error("Unable to open file"); + } + std::vector<char> image_bytes = std::vector<char>( + std::istreambuf_iterator<char>(file), + std::istreambuf_iterator<char>()); + // decode image to byte array + int nx, ny, nc; + auto * bytes = (unsigned char *) image_bytes.data(); + auto * img = stbi_load_from_memory(bytes, image_bytes.size(), &nx, &ny, &nc, 3); + if (!img) { + throw std::runtime_error("failed to decode image bytes"); + } + // printf("nx=%d ny=%d nc=%d\n", nx, ny, nc); + // GGML_ASSERT(nc == 3); + // for (int y = 0; y < ny; y++) { + // for (int x = 0; x < nx; x++) { + // unsigned char * pix = img + x*nc + y*nc*nx; + // printf("%02x%02x%02x ", pix[0], pix[1], pix[2]); + // } + // printf("\n"); + // } + // printf("\n"); + llama_vision_bitmap * result = llama_vision_bitmap_init(nx, ny); + memcpy(result->data, img, nx*ny*3); + stbi_image_free(img); + return result; +}
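Aside: `llama_vision_bitmap` expects `data` to hold exactly `3*nx*ny` bytes of packed row-major RGB, which is what the `stbi_load_from_memory(..., 3)` call in `load_image_from_file` above produces. A minimal sketch of the same layout from Python, assuming Pillow/numpy are available and using a hypothetical `example.jpg`:

```python
# Build the same 3*nx*ny packed RGB buffer that load_image_from_file() fills.
import numpy as np
from PIL import Image

img = Image.open("example.jpg").convert("RGB")  # force 3 channels, like the ", 3" arg to stbi
nx, ny = img.size
data = np.asarray(img, dtype=np.uint8)          # shape (ny, nx, 3), row-major
flat = data.tobytes()                           # exactly 3*nx*ny bytes
assert len(flat) == 3 * nx * ny
```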
+ +// split string by a `std::string delim` instead of `char delim` +static std::vector<std::string> string_split_str(std::string s, const std::string & delimiter) { + std::vector<std::string> tokens; + size_t pos = 0; + std::string token; + while ((pos = s.find(delimiter)) != std::string::npos) { + token = s.substr(0, pos); + tokens.push_back(token); + s.erase(0, pos + delimiter.length()); + } + tokens.push_back(s); + return tokens; +} + +struct tokenized_part { + llama_tokens tokens; + bool is_image; +}; + +// TODO: this function is hacky, needs to be improved +// static const llama_token TOKEN_IMG_PLACEMENT = -1000; +static const std::string IMG_PLACEMENT = "<image>"; +static std::vector<tokenized_part> tokenize_with_img_placement( + const llama_vocab * vocab, + const std::string & text, + bool add_special, + bool parse_special) { + std::vector<std::string> parts = string_split_str(text, IMG_PLACEMENT); + std::vector<tokenized_part> output; + for (const auto & part : parts) { + //printf("tokenizing part: %s\n", part.c_str()); + bool add_bos = &parts.front() == &part; + auto tokens = common_tokenize(vocab, part, add_special && add_bos, parse_special); + if (tokens.empty()) { + continue; + } + output.push_back({std::move(tokens), false}); + if (&parts.back() != &part) { + // add an image part between two consecutive text parts + output.push_back({{}, true}); + } + } + return output; +} + +int main(int argc, char ** argv) { + common_params params; + + // default prompt for llava 1.5 + //params.prompt = "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:<image>\nwhat did you see?\nASSISTANT:"; + // default prompt for minicpmv 2.6 + params.prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<image>\nwhat do you see?<|im_end|>\n<|im_start|>assistant\n"; + params.n_predict = 64; + params.n_batch = 2048; + params.n_ubatch = 1024; + params.n_gpu_layers = 99; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_VISION, print_usage)) { + return 1; + } + + common_init(); + common_init_result llama_init = common_init_from_params(params); + llama_context * ctx = llama_init.context.get(); + const llama_model * model = llama_init.model.get(); + if (!model) { + LOG_ERR("failed to load model\n"); + return 1; + } + const llama_vocab * vocab = llama_model_get_vocab(model); + + llama_vision_context_params vparams = llama_vision_context_default_params(); + vparams.n_threads = llama_n_threads(ctx); + llama_vision_context * vctx = llama_vision_init_from_model(model, vparams); + if (!vctx) { + LOG_ERR("model does not have a vision encoder\n"); + return 1; + } + + struct common_sampler * smpl = common_sampler_init(model, params.sampling); + + llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); + int n_past = 0; + int n_prompt = 0; + + // process image + llama_vision_tokens * img_tokens = nullptr; + { + if (params.image.empty() || params.image[0].empty()) { + LOG_ERR("no image path provided\n"); + return 1; + } + const char * img_path = params.image[0].c_str(); + llama_vision_bitmap * img = load_image_from_file(img_path); + LOG_INF("loaded image %s, size = %d x %d\n", img_path, img->nx, img->ny); + img_tokens = llama_vision_tokenize(vctx, img); + if (!img_tokens) { + LOG_ERR("failed to create image tokens\n"); + return 1; + } + if (llama_vision_encode(vctx, img_tokens)) { + LOG_ERR("failed to encode image\n"); + return 1; + } + LOG_INF("encoded image\n"); + } + + // process prompt + { + std::vector<tokenized_part> parts = tokenize_with_img_placement(vocab, params.prompt, true, true); + for (const tokenized_part & part : parts) { + if (!part.is_image) { + common_batch_clear(batch); + for (const llama_token & token : part.tokens) { + //LOG_INF("%d -> %s\n", token, common_token_to_piece(ctx, token).c_str()); + common_batch_add(batch, token, n_past++, {0}, &part == &parts.back()); + } + LOG_INF("eval text batch (%d tokens)\n", batch.n_tokens); + if (llama_decode(ctx, batch)) { + LOG_ERR("failed to decode text prompt\n"); + return 1; + } + } else { + auto * img_embd = llama_vision_get_output_tensor(vctx); + // std::vector<float> output_debug(ggml_nelements(img_embd)); + // ggml_backend_tensor_get(img_embd, output_debug.data(), 0, ggml_nbytes(img_embd)); + // for (int row = 0; row < 10; row++) { + // int off = row * img_embd->ne[0]; + // printf("... 
%f %f %f\n", output_debug[off], output_debug[off+1], output_debug[off+2]); + // } + // exit(1); + llama_batch batch_img = llama_batch_get_one_from_tensor(img_embd, n_past, 0); + n_past += batch_img.n_tokens; + LOG_INF("eval image batch (%d embeddings)\n", batch_img.n_tokens); + if (llama_decode(ctx, batch_img)) { + LOG_ERR("failed to decode image prompt\n"); + return 1; + } + llama_batch_free(batch_img); + } + } + n_prompt = n_past; + LOG_INF("prompt processed, %d tokens\n", n_prompt); + } + + // generate response + while (true){ + int n_generated = n_past - n_prompt; + if (n_generated > params.n_predict) { + printf("\n"); + break; + } + + llama_token token_id = common_sampler_sample(smpl, ctx, -1); + common_sampler_accept(smpl, token_id, true); + printf("%s", common_token_to_piece(ctx, token_id).c_str()); + fflush(stdout); + + if (llama_vocab_is_eog(vocab, token_id)) { + printf("\n"); + break; + } + + // eval the token + common_batch_clear(batch); + common_batch_add(batch, token_id, n_past++, {0}, true); + if (llama_decode(ctx, batch)) { + LOG_ERR("failed to decode token\n"); + break; + } + } + + return 0; +} diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 19624eae04ece..3f0ccf13d1af9 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -202,6 +202,9 @@ class Tokenizer: FIM_PAD_ID = "tokenizer.ggml.fim_pad_token_id" FIM_REP_ID = "tokenizer.ggml.fim_rep_token_id" FIM_SEP_ID = "tokenizer.ggml.fim_sep_token_id" + # Vision models + IMAGE_START_ID = "tokenizer.ggml.image_start_token_id" + IMAGE_END_ID = "tokenizer.ggml.image_end_token_id" # deprecated: PREFIX_ID = "tokenizer.ggml.prefix_token_id" SUFFIX_ID = "tokenizer.ggml.suffix_token_id" @@ -211,6 +214,32 @@ class Adapter: TYPE = "adapter.type" LORA_ALPHA = "adapter.lora.alpha" + class Vision: + # only support vision.type = "vit" for now + TYPE = "vision.type" + IMAGE_SIZE = "vision.image_size" + PATCH_SIZE = "vision.patch_size" + IMAGE_MEAN = "vision.image_mean" + IMAGE_STD = "vision.image_std" + + class Vit: + ARCHITECTURE = "vision.vit.architecture" + CONTEXT_LENGTH = "vision.vit.context_length" + EMBEDDING_LENGTH = "vision.vit.embedding_length" + BLOCK_COUNT = "vision.vit.block_count" + FEED_FORWARD_LENGTH = "vision.vit.feed_forward_length" + PROJECTION_TYPE = "vision.vit.projection_type" + PROJECTION_DIM = "vision.vit.projection_dim" + USE_GELU = "vision.vit.use_gelu" + MAX_POS_EMBEDDING = "vision.vit.max_position_embeddings" + MAX_SLICES = "vision.vit.max_slices" + PROJECTOR_TYPE = "vision.vit.projector_type" + SELECT_LAYER = "vision.vit.select_layer" + PATCH_MERGE_TYPE = "vision.vit.patch_merge_type" + HEAD_COUNT = "vision.vit.attention.head_count" + LAYERNORM_EPS = "vision.vit.attention.layer_norm_epsilon" + SCALE_FACTOR = "vision.vit.scale_factor" # only used by idefics3 for now + # # recommended mapping of model tensor names for storage in gguf # @@ -280,6 +309,11 @@ class MODEL_ARCH(IntEnum): GRANITE_MOE = auto() CHAMELEON = auto() WAVTOKENIZER_DEC = auto() + # vision models + VISION_LLAVA = auto() + VISION_MOBILEVLM = auto() + VISION_MINICPMV = auto() + VISION_IDEFICS3 = auto() class MODEL_TENSOR(IntEnum): @@ -391,6 +425,7 @@ class MODEL_TENSOR(IntEnum): ENC_OUTPUT_NORM = auto() CLS = auto() # classifier CLS_OUT = auto() # classifier output projection + # wavtokenizer CONV1D = auto() CONVNEXT_DW = auto() CONVNEXT_NORM = auto() @@ -407,6 +442,39 @@ class MODEL_TENSOR(IntEnum): POSNET_ATTN_K = auto() POSNET_ATTN_V = auto() POSNET_ATTN_OUT = auto() + # vision + V_MMPROJ = auto() + 
V_MMPROJ_FC = auto() + V_MMPROJ_MLP = auto() + V_MMPROJ_PEG = auto() + V_ENC_EMBD_CLS = auto() + V_ENC_EMBD_PATCH = auto() + V_ENC_EMBD_POS = auto() + V_ENC_ATTN_Q = auto() + V_ENC_ATTN_K = auto() + V_ENC_ATTN_V = auto() + V_ENC_INPUT_NORM = auto() + V_ENC_OUTPUT = auto() + V_ENC_OUTPUT_NORM = auto() + V_ENC_FFN_UP = auto() + V_ENC_FFN_DOWN = auto() + V_PRE_NORM = auto() + V_POST_NORM = auto() + V_RESMPL_POS_EMBD_K = auto() # minicpmv + V_RESMPL_ATTN_Q = auto() # minicpmv + V_RESMPL_ATTN_K = auto() # minicpmv + V_RESMPL_ATTN_V = auto() # minicpmv + V_RESMPL_ATTN_OUT = auto() # minicpmv + V_RESMPL_KV = auto() # minicpmv + V_RESMPL_KV_NORM = auto() # minicpmv + V_RESMPL_POST_NORM = auto() # minicpmv + V_RESMPL_Q_NORM = auto() # minicpmv + V_RESMPL_PROJ = auto() # minicpmv + V_RESMPL_QUERY = auto() # minicpmv + V_TOK_EMBD_IMAGE = auto() # embedding for <image> token + V_TOK_EMBD_END_IMAGE = auto() # embedding for </image> token + V_TOK_EMBD_SLICE = auto() # embedding for <slice> token + V_TOK_EMBD_END_SLICE = auto() # embedding for </slice> token MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -468,6 +536,11 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GRANITE_MOE: "granitemoe", MODEL_ARCH.CHAMELEON: "chameleon", MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", + # vision + MODEL_ARCH.VISION_LLAVA: "llava", + MODEL_ARCH.VISION_MOBILEVLM: "mobilevlm", + MODEL_ARCH.VISION_MINICPMV: "minicpmv", + MODEL_ARCH.VISION_IDEFICS3: "idefics3", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -595,6 +668,39 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k", MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v", MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output", + # vision + MODEL_TENSOR.V_MMPROJ: "v.mmproj_{bid}", + MODEL_TENSOR.V_MMPROJ_FC: "v.mmproj.fc", + MODEL_TENSOR.V_MMPROJ_MLP: "v.mmproj.mlp.{bid}", + MODEL_TENSOR.V_MMPROJ_PEG: "v.mmproj.peg.{bid}", + MODEL_TENSOR.V_ENC_EMBD_CLS: "v.enc.embd.cls", + MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.enc.embd.patch", + MODEL_TENSOR.V_ENC_EMBD_POS: "v.enc.embd.pos", + MODEL_TENSOR.V_ENC_ATTN_Q: "v.enc.blk.{bid}.attn_q", + MODEL_TENSOR.V_ENC_ATTN_K: "v.enc.blk.{bid}.attn_k", + MODEL_TENSOR.V_ENC_ATTN_V: "v.enc.blk.{bid}.attn_v", + MODEL_TENSOR.V_ENC_INPUT_NORM: "v.enc.blk.{bid}.input_norm", + MODEL_TENSOR.V_ENC_OUTPUT: "v.enc.blk.{bid}.output", + MODEL_TENSOR.V_ENC_OUTPUT_NORM: "v.enc.blk.{bid}.output_norm", + MODEL_TENSOR.V_ENC_FFN_UP: "v.enc.blk.{bid}.ffn_up", + MODEL_TENSOR.V_ENC_FFN_DOWN: "v.enc.blk.{bid}.ffn_down", + MODEL_TENSOR.V_PRE_NORM: "v.pre_norm", + MODEL_TENSOR.V_POST_NORM: "v.post_norm", + MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "v.resmpl.pos_embd_k", + MODEL_TENSOR.V_RESMPL_ATTN_Q: "v.resmpl.attn_q", + MODEL_TENSOR.V_RESMPL_ATTN_K: "v.resmpl.attn_k", + MODEL_TENSOR.V_RESMPL_ATTN_V: "v.resmpl.attn_v", + MODEL_TENSOR.V_RESMPL_ATTN_OUT: "v.resmpl.attn_out", + MODEL_TENSOR.V_RESMPL_KV: "v.resmpl.kv", + MODEL_TENSOR.V_RESMPL_KV_NORM: "v.resmpl.kv_norm", + MODEL_TENSOR.V_RESMPL_POST_NORM: "v.resmpl.post_norm", + MODEL_TENSOR.V_RESMPL_Q_NORM: "v.resmpl.q_norm", + MODEL_TENSOR.V_RESMPL_PROJ: "v.resmpl.proj", + MODEL_TENSOR.V_RESMPL_QUERY: "v.resmpl.query", + MODEL_TENSOR.V_TOK_EMBD_IMAGE: "v.tok_embd.image", + MODEL_TENSOR.V_TOK_EMBD_END_IMAGE: "v.tok_embd.end_image", + MODEL_TENSOR.V_TOK_EMBD_SLICE: "v.tok_embd.slice", + MODEL_TENSOR.V_TOK_EMBD_END_SLICE: "v.tok_embd.end_slice", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -1556,6 +1662,80 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.POSNET_ATTN_V, MODEL_TENSOR.POSNET_ATTN_OUT, ], + 
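Aside: a quick way to sanity-check the new vision tensor mappings is through `gguf.get_tensor_name_map`, the same helper the conversion code above uses; a sketch assuming this branch's `gguf-py` is importable and a 24-layer CLIP encoder:

```python
# Resolve an HF vision tensor name to its GGUF name via the VISION_LLAVA map.
import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.VISION_LLAVA, 24)
hf_name = "vision_tower.vision_model.encoder.layers.0.self_attn.q_proj.weight"
print(tmap.get_name(hf_name, try_suffixes=(".weight", ".bias")))
# expected: v.enc.blk.0.attn_q.weight
```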
MODEL_ARCH.VISION_LLAVA: [ + MODEL_TENSOR.V_MMPROJ, + MODEL_TENSOR.V_ENC_EMBD_CLS, + MODEL_TENSOR.V_ENC_EMBD_PATCH, + MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_ATTN_Q, + MODEL_TENSOR.V_ENC_ATTN_K, + MODEL_TENSOR.V_ENC_ATTN_V, + MODEL_TENSOR.V_ENC_INPUT_NORM, + MODEL_TENSOR.V_ENC_OUTPUT, + MODEL_TENSOR.V_ENC_OUTPUT_NORM, + MODEL_TENSOR.V_ENC_FFN_UP, + MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_PRE_NORM, + MODEL_TENSOR.V_POST_NORM, + ], + MODEL_ARCH.VISION_MOBILEVLM: [ + MODEL_TENSOR.V_MMPROJ_MLP, + MODEL_TENSOR.V_MMPROJ_PEG, + MODEL_TENSOR.V_ENC_EMBD_CLS, + MODEL_TENSOR.V_ENC_EMBD_PATCH, + MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_ATTN_Q, + MODEL_TENSOR.V_ENC_ATTN_K, + MODEL_TENSOR.V_ENC_ATTN_V, + MODEL_TENSOR.V_ENC_INPUT_NORM, + MODEL_TENSOR.V_ENC_OUTPUT, + MODEL_TENSOR.V_ENC_OUTPUT_NORM, + MODEL_TENSOR.V_ENC_FFN_UP, + MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_PRE_NORM, + MODEL_TENSOR.V_POST_NORM, + ], + MODEL_ARCH.VISION_MINICPMV: [ + MODEL_TENSOR.V_ENC_EMBD_PATCH, + MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_ATTN_Q, + MODEL_TENSOR.V_ENC_ATTN_K, + MODEL_TENSOR.V_ENC_ATTN_V, + MODEL_TENSOR.V_ENC_INPUT_NORM, + MODEL_TENSOR.V_ENC_OUTPUT, + MODEL_TENSOR.V_ENC_OUTPUT_NORM, + MODEL_TENSOR.V_ENC_FFN_UP, + MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_RESMPL_POS_EMBD_K, + MODEL_TENSOR.V_RESMPL_ATTN_Q, + MODEL_TENSOR.V_RESMPL_ATTN_K, + MODEL_TENSOR.V_RESMPL_ATTN_V, + MODEL_TENSOR.V_RESMPL_ATTN_OUT, + MODEL_TENSOR.V_RESMPL_KV, + MODEL_TENSOR.V_RESMPL_KV_NORM, + MODEL_TENSOR.V_RESMPL_POST_NORM, + MODEL_TENSOR.V_RESMPL_Q_NORM, + MODEL_TENSOR.V_RESMPL_PROJ, + MODEL_TENSOR.V_RESMPL_QUERY, + MODEL_TENSOR.V_TOK_EMBD_IMAGE, + MODEL_TENSOR.V_TOK_EMBD_END_IMAGE, + MODEL_TENSOR.V_TOK_EMBD_SLICE, + MODEL_TENSOR.V_TOK_EMBD_END_SLICE, + ], + MODEL_ARCH.VISION_IDEFICS3: [ + MODEL_TENSOR.V_MMPROJ_FC, + MODEL_TENSOR.V_ENC_EMBD_PATCH, + MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_ATTN_Q, + MODEL_TENSOR.V_ENC_ATTN_K, + MODEL_TENSOR.V_ENC_ATTN_V, + MODEL_TENSOR.V_ENC_INPUT_NORM, + MODEL_TENSOR.V_ENC_OUTPUT, + MODEL_TENSOR.V_ENC_OUTPUT_NORM, + MODEL_TENSOR.V_ENC_FFN_UP, + MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_POST_NORM, + ], # TODO } @@ -1637,6 +1817,18 @@ class PoolingType(IntEnum): CLS = 2 +class CLIPProjectorType(Enum): + MLP = 'mlp' + LDPV2 = 'ldpv2' + MINICPMV_2_5 = 'minicpmv-2.5' # resampler + MINICPMV_2_6 = 'minicpmv-2.6' # resampler + + +class CLIPPatchMergeType(Enum): + FLAT = 'flat' + SPATIAL_UNPAD = 'spatial_unpad' + + class GGMLQuantizationType(IntEnum): F32 = 0 F16 = 1 diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 080d2b9dce5cb..a31ab736bc20a 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -27,6 +27,8 @@ PoolingType, TokenType, ExpertGatingFuncType, + CLIPPatchMergeType, + CLIPProjectorType, ) from .quants import quant_shape_from_byte_shape @@ -875,6 +877,60 @@ def add_remove_extra_whitespaces(self, value: bool) -> None: def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None: self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap) + def add_vision_type(self, value: str) -> None: + self.add_string(Keys.Vision.TYPE, value) + + def add_vision_image_size(self, value: int) -> None: + self.add_uint32(Keys.Vision.IMAGE_SIZE, value) + + def add_vision_patch_size(self, value: int) -> None: + self.add_uint32(Keys.Vision.PATCH_SIZE, value) + + def add_vision_vit_architecture(self, value: str) -> None: + self.add_string(Keys.Vision.Vit.ARCHITECTURE, value) + + def 
add_vision_vit_context_length(self, value: int) -> None: + self.add_uint32(Keys.Vision.Vit.CONTEXT_LENGTH, value) + + def add_vision_vit_embedding_length(self, value: int) -> None: + self.add_uint32(Keys.Vision.Vit.EMBEDDING_LENGTH, value) + + def add_vision_vit_block_count(self, value: int) -> None: + self.add_uint32(Keys.Vision.Vit.BLOCK_COUNT, value) + + def add_vision_vit_feed_forward_length(self, value: int) -> None: + self.add_uint32(Keys.Vision.Vit.FEED_FORWARD_LENGTH, value) + + def add_vision_vit_head_count(self, value: int) -> None: + self.add_uint32(Keys.Vision.Vit.HEAD_COUNT, value) + + def add_vision_vit_max_position_embeddings(self, value: int) -> None: + self.add_uint32(Keys.Vision.Vit.MAX_POS_EMBEDDING, value) + + def add_vision_vit_projector_type(self, value: CLIPProjectorType) -> None: + self.add_string(Keys.Vision.Vit.PROJECTOR_TYPE, value.value) + + def add_vision_vit_max_slices(self, value: int) -> None: + self.add_uint32(Keys.Vision.Vit.MAX_SLICES, value) + + def add_vision_vit_select_layer(self, value: int) -> None: + self.add_int32(Keys.Vision.Vit.SELECT_LAYER, value) + + def add_vision_vit_patch_merge_type(self, value: CLIPPatchMergeType) -> None: + self.add_string(Keys.Vision.Vit.PATCH_MERGE_TYPE, value.value) + + def add_vision_vit_layer_norm_epsilon(self, value: float) -> None: + self.add_float32(Keys.Vision.Vit.LAYERNORM_EPS, value) + + def add_vision_vit_image_mean(self, value: Sequence[float]) -> None: + self.add_array(Keys.Vision.IMAGE_MEAN, value) + + def add_vision_vit_image_std(self, value: Sequence[float]) -> None: + self.add_array(Keys.Vision.IMAGE_STD, value) + + def add_vision_vit_scale_factor(self, value: int) -> None: + self.add_int32(Keys.Vision.Vit.SCALE_FACTOR, value) + def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None: if not isinstance(value, str): template_default = None diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 617791e240b60..3f247d787ba11 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -787,6 +787,157 @@ class TensorNameMap: MODEL_TENSOR.POSNET_ATTN_OUT: ( "backbone.posnet.{bid}.proj_out", # wavtokenizer ), + + ############################################################################# + + MODEL_TENSOR.V_MMPROJ: ( + "multi_modal_projector.linear_{bid}", + ), + + MODEL_TENSOR.V_MMPROJ_FC: ( + "model.connector.modality_projection.proj", # SmolVLM + ), + + MODEL_TENSOR.V_MMPROJ_MLP: ( + "model.mm_projector.mlp.mlp.{bid}", + ), + + MODEL_TENSOR.V_MMPROJ_PEG: ( + "model.mm_projector.peg.peg.{bid}", + ), + + MODEL_TENSOR.V_ENC_EMBD_CLS: ( + "vision_tower.vision_model.embeddings.class_embedding", + ), + + MODEL_TENSOR.V_ENC_EMBD_PATCH: ( + "vision_tower.vision_model.embeddings.patch_embedding", + "vpm.embeddings.patch_embedding", + "model.vision_model.embeddings.patch_embedding", # SmolVLM + ), + + MODEL_TENSOR.V_ENC_EMBD_POS: ( + "vision_tower.vision_model.embeddings.position_embedding", + "vpm.embeddings.position_embedding", + "model.vision_model.embeddings.position_embedding", # SmolVLM + ), + + MODEL_TENSOR.V_ENC_ATTN_Q: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj", + "vpm.encoder.layers.{bid}.self_attn.q_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM + ), + + MODEL_TENSOR.V_ENC_ATTN_K: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj", + "vpm.encoder.layers.{bid}.self_attn.k_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM + ), + + 
MODEL_TENSOR.V_ENC_ATTN_V: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj", + "vpm.encoder.layers.{bid}.self_attn.v_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM + ), + + MODEL_TENSOR.V_ENC_INPUT_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1", + "vpm.encoder.layers.{bid}.layer_norm1", + "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM + ), + + MODEL_TENSOR.V_ENC_OUTPUT: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj", + "vpm.encoder.layers.{bid}.self_attn.out_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM + ), + + MODEL_TENSOR.V_ENC_OUTPUT_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2", + "vpm.encoder.layers.{bid}.layer_norm2", + "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM + ), + + MODEL_TENSOR.V_ENC_FFN_UP: ( + "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1", + "vpm.encoder.layers.{bid}.mlp.fc1", + "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM + ), + + MODEL_TENSOR.V_ENC_FFN_DOWN: ( + "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2", + "vpm.encoder.layers.{bid}.mlp.fc2", + "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM + ), + + MODEL_TENSOR.V_PRE_NORM: ( + "vision_tower.vision_model.pre_layrnorm", + ), + + MODEL_TENSOR.V_POST_NORM: ( + "vision_tower.vision_model.post_layernorm", + "model.vision_model.post_layernorm", # SmolVLM + ), + + MODEL_TENSOR.V_RESMPL_POS_EMBD_K: ( + "resampler.pos_embed_k", + ), + + MODEL_TENSOR.V_RESMPL_ATTN_Q: ( + "resampler.attn.in_proj_q", # tensor generated from resampler.attn.in_proj + ), + + MODEL_TENSOR.V_RESMPL_ATTN_K: ( + "resampler.attn.in_proj_k", # tensor generated from resampler.attn.in_proj + ), + + MODEL_TENSOR.V_RESMPL_ATTN_V: ( + "resampler.attn.in_proj_v", # tensor generated from resampler.attn.in_proj + ), + + MODEL_TENSOR.V_RESMPL_ATTN_OUT: ( + "resampler.attn.out_proj", + ), + + MODEL_TENSOR.V_RESMPL_KV: ( + "resampler.kv_proj", + ), + + MODEL_TENSOR.V_RESMPL_POST_NORM: ( + "resampler.ln_post", + ), + + MODEL_TENSOR.V_RESMPL_KV_NORM: ( + "resampler.ln_kv", + ), + + MODEL_TENSOR.V_RESMPL_Q_NORM: ( + "resampler.ln_q", + ), + + MODEL_TENSOR.V_RESMPL_PROJ: ( + "resampler.proj", + ), + + MODEL_TENSOR.V_RESMPL_QUERY: ( + "resampler.query", + ), + + MODEL_TENSOR.V_TOK_EMBD_IMAGE:( + "v.tok_embd.image", # tensor generated from token embeddings + ), + + MODEL_TENSOR.V_TOK_EMBD_END_IMAGE:( + "v.tok_embd.end_image", # tensor generated from token embeddings + ), + + MODEL_TENSOR.V_TOK_EMBD_SLICE:( + "v.tok_embd.slice", # tensor generated from token embeddings + ), + + MODEL_TENSOR.V_TOK_EMBD_END_SLICE:( + "v.tok_embd.end_slice", # tensor generated from token embeddings + ), } # architecture-specific block mappings diff --git a/include/llama.h b/include/llama.h index 6a44be404d914..85302c67dec8b 100644 --- a/include/llama.h +++ b/include/llama.h @@ -231,6 +231,20 @@ extern "C" { bool sorted; } llama_token_data_array; + struct llama_vision_context; + + // Structure represents the basic input unit of vision model + // This can be a processed image or slices of images under the hood + struct llama_vision_tokens; + + // represent an RGB image + // size of data must be equal to 3*nx*ny + typedef struct llama_vision_bitmap { + uint32_t nx; + uint32_t ny; + unsigned char * data; + } llama_vision_bitmap; + typedef bool (*llama_progress_callback)(float progress, void * user_data); // Input data for llama_decode @@ -255,6 
+269,8 @@ extern "C" { int32_t * n_seq_id; llama_seq_id ** seq_id; int8_t * logits; // TODO: rename this to "output" + + struct ggml_tensor * embd_tensor; } llama_batch; enum llama_model_kv_override_type { @@ -353,6 +369,10 @@ extern "C" { void * abort_callback_data; }; + struct llama_vision_context_params { + int32_t n_threads; + }; + // model quantization parameters typedef struct llama_model_quantize_params { int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() @@ -390,6 +410,7 @@ extern "C" { // TODO: update API to start accepting pointers to params structs (https://github.com/ggml-org/llama.cpp/discussions/9172) LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void); + LLAMA_API struct llama_vision_context_params llama_vision_context_default_params(void); LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void); LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); @@ -907,6 +928,10 @@ extern "C" { int32_t embd, int32_t n_seq_max); + // Allocates a batch based on a tensor, only used by vision API for now + // Unlike llama_batch_get_one, this will need to be freed after use + LLAMA_API struct llama_batch llama_batch_get_one_from_tensor(struct ggml_tensor * tensor, int32_t p0, int32_t seq_id); + // Frees a batch of tokens allocated with llama_batch_init() LLAMA_API void llama_batch_free(struct llama_batch batch); @@ -1357,6 +1382,35 @@ extern "C" { // TODO: extend in the future //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...); + // + // Vision API + // + + // Vision context + LLAMA_API struct llama_vision_context * llama_vision_init_from_model( + const struct llama_model * model, + struct llama_vision_context_params params); + LLAMA_API void llama_vision_free(struct llama_vision_context * ctx); + + // Container for RGB bitmap + LLAMA_API struct llama_vision_bitmap * llama_vision_bitmap_init(uint32_t nx, uint32_t ny); + LLAMA_API void llama_vision_bitmap_free(struct llama_vision_bitmap * bmp); + + // Create image tokens from the RGB bitmap + LLAMA_API struct llama_vision_tokens * llama_vision_tokenize( + struct llama_vision_context * ctx, + struct llama_vision_bitmap * bmp); + LLAMA_API void llama_vision_tokens_free(struct llama_vision_tokens * img_tokens); + + // User must reserve N number of tokens in tokenized text prompt for each image + // LLAMA_API int32_t llama_vision_get_n_tokens(const llama_vision_img_tokens * img_tokens); + + // Encode patches into embeddings + LLAMA_API int32_t llama_vision_encode( + struct llama_vision_context * ctx, + struct llama_vision_tokens * img_tokens); + LLAMA_API struct ggml_tensor * llama_vision_get_output_tensor(struct llama_vision_context * ctx); + // // Model split // diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b340dae5b28cd..aded67d4efdcc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -27,9 +27,10 @@ add_library(llama llama-quant.cpp llama-sampling.cpp llama-vocab.cpp - unicode-data.cpp - unicode.cpp + llama-vision.cpp unicode.h + unicode.cpp + unicode-data.cpp ) target_include_directories(llama PUBLIC . 
../include ../common) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 28f2bbc8f72bf..f07ef9afe844c 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -3,6 +3,7 @@ #include "llama-impl.h" #include <map> +#include static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA, "llama" }, @@ -63,6 +64,10 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_VISION_LLAVA, "llava" }, + { LLM_ARCH_VISION_MOBILEVLM, "mobilevlm" }, + { LLM_ARCH_VISION_MINICPMV, "minicpmv" }, + { LLM_ARCH_VISION_IDEFICS3, "idefics3" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -191,6 +196,28 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_ADAPTER_TYPE, "adapter.type" }, { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + { LLM_KV_VISION_TYPE, "vision.type" }, + { LLM_KV_VISION_IMAGE_SIZE, "vision.image_size" }, + { LLM_KV_VISION_PATCH_SIZE, "vision.patch_size" }, + { LLM_KV_VISION_IMAGE_MEAN, "vision.image_mean" }, + { LLM_KV_VISION_IMAGE_STD, "vision.image_std" }, + { LLM_KV_VISION_VIT_ARCHITECTURE, "vision.vit.architecture" }, + { LLM_KV_VISION_VIT_CONTEXT_LENGTH, "vision.vit.context_length" }, + { LLM_KV_VISION_VIT_EMBEDDING_LENGTH, "vision.vit.embedding_length" }, + { LLM_KV_VISION_VIT_BLOCK_COUNT, "vision.vit.block_count" }, + { LLM_KV_VISION_VIT_FEED_FORWARD_LENGTH, "vision.vit.feed_forward_length" }, + { LLM_KV_VISION_VIT_PROJECTION_TYPE, "vision.vit.projection_type" }, + { LLM_KV_VISION_VIT_PROJECTION_DIM, "vision.vit.projection_dim" }, + { LLM_KV_VISION_VIT_USE_GELU, "vision.vit.use_gelu" }, + { LLM_KV_VISION_VIT_MAX_POS_EMBD, "vision.vit.max_position_embeddings" }, + { LLM_KV_VISION_VIT_MAX_SLICES, "vision.vit.max_slices" }, + { LLM_KV_VISION_VIT_PROJECTOR_TYPE, "vision.vit.projector_type" }, + { LLM_KV_VISION_VIT_SELECT_LAYER, "vision.vit.select_layer" }, + { LLM_KV_VISION_VIT_PATCH_MERGE_TYPE, "vision.vit.patch_merge_type" }, + { LLM_KV_VISION_VIT_HEAD_COUNT, "vision.vit.attention.head_count" }, + { LLM_KV_VISION_VIT_LAYERNORM_EPS, "vision.vit.attention.layer_norm_epsilon" }, + { LLM_KV_VISION_VIT_SCALE_FACTOR, "vision.vit.scale_factor" }, + // deprecated { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, @@ -1317,6 +1344,95 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, }, }, + // vision + { + LLM_ARCH_VISION_LLAVA, + { + { LLM_TENSOR_V_MMPROJ, "v.mmproj_%d" }, + { LLM_TENSOR_V_ENC_EMBD_CLS, "v.enc.embd.cls" }, + { LLM_TENSOR_V_ENC_EMBD_PATCH, "v.enc.embd.patch" }, + { LLM_TENSOR_V_ENC_EMBD_POS, "v.enc.embd.pos" }, + { LLM_TENSOR_V_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" }, + { LLM_TENSOR_V_ENC_ATTN_K, "v.enc.blk.%d.attn_k" }, + { LLM_TENSOR_V_ENC_ATTN_V, "v.enc.blk.%d.attn_v" }, + { LLM_TENSOR_V_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" }, + { LLM_TENSOR_V_ENC_OUTPUT, "v.enc.blk.%d.output" }, + { LLM_TENSOR_V_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" }, + { LLM_TENSOR_V_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" }, + { LLM_TENSOR_V_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" }, + { LLM_TENSOR_V_PRE_NORM, "v.pre_norm" }, + { LLM_TENSOR_V_POST_NORM, "v.post_norm" }, + } + }, + { + LLM_ARCH_VISION_MOBILEVLM, + { + { LLM_TENSOR_V_MMPROJ_MLP, "v.mmproj.mlp.%d" }, + { LLM_TENSOR_V_MMPROJ_PEG, "v.mmproj.peg.%d" }, + { LLM_TENSOR_V_ENC_EMBD_CLS, "v.enc.embd.cls" }, + { LLM_TENSOR_V_ENC_EMBD_PATCH, "v.enc.embd.patch" }, + { LLM_TENSOR_V_ENC_EMBD_POS, 
"v.enc.embd.pos" }, + { LLM_TENSOR_V_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" }, + { LLM_TENSOR_V_ENC_ATTN_K, "v.enc.blk.%d.attn_k" }, + { LLM_TENSOR_V_ENC_ATTN_V, "v.enc.blk.%d.attn_v" }, + { LLM_TENSOR_V_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" }, + { LLM_TENSOR_V_ENC_OUTPUT, "v.enc.blk.%d.output" }, + { LLM_TENSOR_V_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" }, + { LLM_TENSOR_V_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" }, + { LLM_TENSOR_V_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" }, + { LLM_TENSOR_V_PRE_NORM, "v.pre_norm" }, + { LLM_TENSOR_V_POST_NORM, "v.post_norm" }, + } + }, + { + LLM_ARCH_VISION_MINICPMV, + { + { LLM_TENSOR_V_ENC_EMBD_PATCH, "v.enc.embd.patch" }, + { LLM_TENSOR_V_ENC_EMBD_POS, "v.enc.embd.pos" }, + { LLM_TENSOR_V_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" }, + { LLM_TENSOR_V_ENC_ATTN_K, "v.enc.blk.%d.attn_k" }, + { LLM_TENSOR_V_ENC_ATTN_V, "v.enc.blk.%d.attn_v" }, + { LLM_TENSOR_V_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" }, + { LLM_TENSOR_V_ENC_OUTPUT, "v.enc.blk.%d.output" }, + { LLM_TENSOR_V_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" }, + { LLM_TENSOR_V_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" }, + { LLM_TENSOR_V_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" }, + { LLM_TENSOR_V_RESMPL_POS_EMBD_K, "v.resmpl.pos_embd_k" }, + { LLM_TENSOR_V_RESMPL_ATTN_Q, "v.resmpl.attn_q" }, + { LLM_TENSOR_V_RESMPL_ATTN_K, "v.resmpl.attn_k" }, + { LLM_TENSOR_V_RESMPL_ATTN_V, "v.resmpl.attn_v" }, + { LLM_TENSOR_V_RESMPL_ATTN_OUT, "v.resmpl.attn_out" }, + { LLM_TENSOR_V_RESMPL_KV, "v.resmpl.kv" }, + { LLM_TENSOR_V_RESMPL_KV_NORM, "v.resmpl.kv_norm" }, + { LLM_TENSOR_V_RESMPL_POST_NORM, "v.resmpl.post_norm" }, + { LLM_TENSOR_V_RESMPL_Q_NORM, "v.resmpl.q_norm" }, + { LLM_TENSOR_V_RESMPL_PROJ, "v.resmpl.proj" }, + { LLM_TENSOR_V_RESMPL_QUERY, "v.resmpl.query" }, + { LLM_TENSOR_V_TOK_EMBD_IMAGE, "v.tok_embd.image" }, + { LLM_TENSOR_V_TOK_EMBD_END_IMAGE, "v.tok_embd.end_image" }, + { LLM_TENSOR_V_TOK_EMBD_SLICE, "v.tok_embd.slice" }, + { LLM_TENSOR_V_TOK_EMBD_END_SLICE, "v.tok_embd.end_slice" }, + } + }, + { + LLM_ARCH_VISION_IDEFICS3, + { + { LLM_TENSOR_V_MMPROJ_FC, "v.mmproj.fc" }, + { LLM_TENSOR_V_ENC_EMBD_CLS, "v.enc.embd.cls" }, + { LLM_TENSOR_V_ENC_EMBD_PATCH, "v.enc.embd.patch" }, + { LLM_TENSOR_V_ENC_EMBD_POS, "v.enc.embd.pos" }, + { LLM_TENSOR_V_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" }, + { LLM_TENSOR_V_ENC_ATTN_K, "v.enc.blk.%d.attn_k" }, + { LLM_TENSOR_V_ENC_ATTN_V, "v.enc.blk.%d.attn_v" }, + { LLM_TENSOR_V_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" }, + { LLM_TENSOR_V_ENC_OUTPUT, "v.enc.blk.%d.output" }, + { LLM_TENSOR_V_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" }, + { LLM_TENSOR_V_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" }, + { LLM_TENSOR_V_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" }, + { LLM_TENSOR_V_PRE_NORM, "v.pre_norm" }, + { LLM_TENSOR_V_POST_NORM, "v.post_norm" }, + } + }, { LLM_ARCH_UNKNOWN, { @@ -1466,6 +1582,39 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + // vision + {LLM_TENSOR_V_MMPROJ, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_MMPROJ_MLP, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_MMPROJ_PEG, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_ENC_EMBD_CLS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_ADD}}, + {LLM_TENSOR_V_ENC_EMBD_PATCH, {LLM_TENSOR_LAYER_INPUT, GGML_OP_ADD}}, + {LLM_TENSOR_V_ENC_EMBD_POS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_ADD}}, + 
{LLM_TENSOR_V_ENC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_ENC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_ENC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_ENC_INPUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_V_ENC_OUTPUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_V_ENC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_ENC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_PRE_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}}, + {LLM_TENSOR_V_POST_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_V_RESMPL_POS_EMBD_K, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_ADD}}, + {LLM_TENSOR_V_RESMPL_ATTN_Q, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_ATTN_K, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_ATTN_V, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_ATTN_OUT, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_KV, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_KV_NORM, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL}}, + {LLM_TENSOR_V_RESMPL_POST_NORM, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL}}, + {LLM_TENSOR_V_RESMPL_Q_NORM, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL}}, + {LLM_TENSOR_V_RESMPL_PROJ, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_QUERY, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + // special token embeddings for image + {LLM_TENSOR_V_TOK_EMBD_IMAGE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_CONCAT}}, + {LLM_TENSOR_V_TOK_EMBD_END_IMAGE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_CONCAT}}, + {LLM_TENSOR_V_TOK_EMBD_SLICE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_CONCAT}}, + {LLM_TENSOR_V_TOK_EMBD_END_SLICE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_CONCAT}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-arch.h b/src/llama-arch.h index 2ec2e2362eba1..4aa682cf60028 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -67,6 +67,11 @@ enum llm_arch { LLM_ARCH_GRANITE_MOE, LLM_ARCH_CHAMELEON, LLM_ARCH_WAVTOKENIZER_DEC, + // vision + LLM_ARCH_VISION_LLAVA, + LLM_ARCH_VISION_MOBILEVLM, + LLM_ARCH_VISION_MINICPMV, + LLM_ARCH_VISION_IDEFICS3, LLM_ARCH_UNKNOWN, }; @@ -195,6 +200,28 @@ enum llm_kv { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, LLM_KV_CONVNEXT_BLOCK_COUNT, + LLM_KV_VISION_TYPE, + LLM_KV_VISION_IMAGE_SIZE, + LLM_KV_VISION_PATCH_SIZE, + LLM_KV_VISION_IMAGE_MEAN, + LLM_KV_VISION_IMAGE_STD, + LLM_KV_VISION_VIT_ARCHITECTURE, + LLM_KV_VISION_VIT_CONTEXT_LENGTH, + LLM_KV_VISION_VIT_EMBEDDING_LENGTH, + LLM_KV_VISION_VIT_BLOCK_COUNT, + LLM_KV_VISION_VIT_FEED_FORWARD_LENGTH, + LLM_KV_VISION_VIT_PROJECTION_TYPE, + LLM_KV_VISION_VIT_PROJECTION_DIM, + LLM_KV_VISION_VIT_USE_GELU, + LLM_KV_VISION_VIT_MAX_POS_EMBD, + LLM_KV_VISION_VIT_MAX_SLICES, + LLM_KV_VISION_VIT_PROJECTOR_TYPE, + LLM_KV_VISION_VIT_SELECT_LAYER, + LLM_KV_VISION_VIT_PATCH_MERGE_TYPE, + LLM_KV_VISION_VIT_HEAD_COUNT, + LLM_KV_VISION_VIT_LAYERNORM_EPS, + LLM_KV_VISION_VIT_SCALE_FACTOR, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, @@ -328,11 +355,46 @@ enum llm_tensor { LLM_TENSOR_POS_NET_ATTN_K, LLM_TENSOR_POS_NET_ATTN_V, LLM_TENSOR_POS_NET_ATTN_OUT, + // vision + LLM_TENSOR_V_MMPROJ, + LLM_TENSOR_V_MMPROJ_FC, + LLM_TENSOR_V_MMPROJ_MLP, + LLM_TENSOR_V_MMPROJ_PEG, + 
LLM_TENSOR_V_ENC_EMBD_CLS, + LLM_TENSOR_V_ENC_EMBD_PATCH, + LLM_TENSOR_V_ENC_EMBD_POS, + LLM_TENSOR_V_ENC_ATTN_Q, + LLM_TENSOR_V_ENC_ATTN_K, + LLM_TENSOR_V_ENC_ATTN_V, + LLM_TENSOR_V_ENC_INPUT_NORM, + LLM_TENSOR_V_ENC_OUTPUT, + LLM_TENSOR_V_ENC_OUTPUT_NORM, + LLM_TENSOR_V_ENC_FFN_UP, + LLM_TENSOR_V_ENC_FFN_DOWN, + LLM_TENSOR_V_PRE_NORM, + LLM_TENSOR_V_POST_NORM, + // vision - minicpmv + LLM_TENSOR_V_RESMPL_POS_EMBD_K, + LLM_TENSOR_V_RESMPL_ATTN_Q, + LLM_TENSOR_V_RESMPL_ATTN_K, + LLM_TENSOR_V_RESMPL_ATTN_V, + LLM_TENSOR_V_RESMPL_ATTN_OUT, + LLM_TENSOR_V_RESMPL_KV, + LLM_TENSOR_V_RESMPL_KV_NORM, + LLM_TENSOR_V_RESMPL_POST_NORM, + LLM_TENSOR_V_RESMPL_Q_NORM, + LLM_TENSOR_V_RESMPL_PROJ, + LLM_TENSOR_V_RESMPL_QUERY, + LLM_TENSOR_V_TOK_EMBD_IMAGE, + LLM_TENSOR_V_TOK_EMBD_END_IMAGE, + LLM_TENSOR_V_TOK_EMBD_SLICE, + LLM_TENSOR_V_TOK_EMBD_END_SLICE, }; enum llm_tensor_layer { LLM_TENSOR_LAYER_INPUT, LLM_TENSOR_LAYER_REPEATING, + LLM_TENSOR_LAYER_PROJECTION, LLM_TENSOR_LAYER_OUTPUT, }; diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 01d5ca57fd82b..c656e16093520 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -31,6 +31,7 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { /*n_seq_id =*/ ubatch_n_seq_id.data(), /*seq_id =*/ ubatch_seq_id.data(), /*output =*/ ubatch_output.data(), + /*embd_tensor =*/ nullptr, }; return ubatch; } @@ -55,7 +56,10 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s } else { ubatch.token = nullptr; } - if (batch->embd) { + if (batch->embd_tensor) { + // TODO @ngxson : we also need to split the tensor by doing a ggml_view + ubatch.embd_tensor = batch->embd_tensor; + } else if (batch->embd) { if (ubatch.equal_seqs) { for (size_t i = 0; i < length; ++i) { memcpy( @@ -139,7 +143,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) { n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; - llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr); ubatch.equal_seqs = false; if (!seq.empty()) { llama_sbatch_seq & s = seq[0]; @@ -152,7 +156,7 @@ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) { llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) { n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; - llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr); if (!seq.empty()) { size_t length = 0; size_t n_tokens_in_ubatch = 0; @@ -179,7 +183,7 @@ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) { llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; - llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr); if (!seq.empty()) { llama_sbatch_seq & s = seq[seq.size() - 1]; size_t length = s.length < n_ubatch ? 
s.length : n_ubatch; @@ -320,6 +324,7 @@ struct llama_batch llama_batch_get_one( /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, /*logits =*/ nullptr, + /*embd_tensor =*/ nullptr, }; } @@ -332,6 +337,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, /*logits =*/ nullptr, + /*embd_tensor =*/ nullptr, }; if (embd) { @@ -353,6 +359,35 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ return batch; } +struct llama_batch llama_batch_get_one_from_tensor(struct ggml_tensor * tensor, int32_t p0, int32_t seq_id) { + GGML_ASSERT(tensor->ne[2] == 1 && tensor->ne[3] == 1); + int32_t n_tokens = tensor->ne[1]; + llama_batch batch = { + /*n_tokens =*/ n_tokens, + /*tokens =*/ nullptr, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + /*embd_tensor =*/ tensor, + }; + batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens); + batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens + 1)); + for (int i = 0; i < n_tokens; ++i) { + batch.pos [i] = p0 + i; + batch.seq_id [i] = (llama_seq_id *) malloc(sizeof(llama_seq_id)); + batch.seq_id [i][0] = seq_id; + batch.n_seq_id[i] = 1; + } + batch.seq_id[n_tokens] = nullptr; + + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + void llama_batch_free(struct llama_batch batch) { if (batch.token) free(batch.token); if (batch.embd) free(batch.embd); diff --git a/src/llama-batch.h b/src/llama-batch.h index f1df40d27086e..9e3562c8867ed 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -21,6 +21,8 @@ struct llama_ubatch { int32_t * n_seq_id; // [n_seqs] llama_seq_id ** seq_id; // [n_seqs] int8_t * output; // [n_tokens] + + struct ggml_tensor * embd_tensor; }; struct llama_sbatch_seq { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index c2fcce42a7d58..35d65b2ca2ae9 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1194,7 +1194,10 @@ int llama_context::decode(llama_batch & inp_batch) { batch_guard bg(*kv_self); - GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + // TODO @ngxson : we can do better than this + GGML_ASSERT((batch.token && !batch.embd && !batch.embd_tensor) + || (!batch.token && batch.embd && !batch.embd_tensor) + || (!batch.token && !batch.embd && batch.embd_tensor)); // NOLINT if (batch.token) { for (int64_t i = 0; i < n_tokens_all; ++i) { diff --git a/src/llama-context.h b/src/llama-context.h index 04facb544cb1a..5a44075fcc236 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -5,6 +5,7 @@ #include "llama-cparams.h" #include "llama-graph.h" #include "llama-adapter.h" +#include "llama-vision.h" #include "ggml-cpp.h" diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 4e90873397ca4..d5b603ce0a175 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -43,7 +43,7 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens)); } - if (ubatch->embd) { + if (ubatch->embd && !ubatch->embd_tensor) { const int64_t n_embd = embd->ne[0]; const int64_t n_tokens = ubatch->n_tokens; @@ -983,6 +983,10 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { cur = ggml_add(ctx0, cur, inpL_delta); } + } else if (ubatch.embd_tensor) { + inp->embd = 
ubatch.embd_tensor;
+        ggml_set_input(ubatch.embd_tensor);
+
     } else {
         inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
         ggml_set_input(inp->embd);
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index dbb7abd317b6f..6371594827eab 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -99,7 +99,7 @@ struct llama_hparams {
     float f_max_alibi_bias = 0.0f;
     float f_logit_scale    = 0.0f;
 
-    // Additional scale factors (Granite/Granite MoE)
+    // Additional scale factors (Granite/Granite MoE/MiniCPM)
     float f_residual_scale  = 0.0f;
     float f_embedding_scale = 0.0f;
     float f_attention_scale = 0.0f;
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
index 05d58ad90eba9..4ddffcfd84f22 100644
--- a/src/llama-model-loader.cpp
+++ b/src/llama-model-loader.cpp
@@ -375,6 +375,7 @@ namespace GGUFMeta {
     template bool llama_model_loader::get_key<bool>       (enum llm_kv kid, bool & result, bool required);
     template bool llama_model_loader::get_key<float>      (enum llm_kv kid, float & result, bool required);
+    template bool llama_model_loader::get_key<int32_t>    (enum llm_kv kid, int32_t & result, bool required);
     template bool llama_model_loader::get_key<uint32_t>   (enum llm_kv kid, uint32_t & result, bool required);
     template bool llama_model_loader::get_key<std::string>(enum llm_kv kid, std::string & result, bool required);
@@ -439,6 +440,7 @@ namespace GGUFMeta {
     // TODO: this is not very clever - figure out something better
     template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
     template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
+    template bool llama_model_loader::get_key_or_arr<std::array<float, 3>>(enum llm_kv kid, std::array<float, 3> & result, uint32_t n, bool required);
 
 llama_model_loader::llama_model_loader(
         const std::string & fname,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 750a702ff77a4..033717dac2dae 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2,6 +2,7 @@
 
 #include "llama-impl.h"
 #include "llama-mmap.h"
+#include "llama-vision.h"
 #include "llama-batch.h"
 #include "llama-cparams.h"
 #include "llama-model-loader.h"
@@ -222,6 +223,11 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
                 ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1);
                 op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16);
             } break;
+        case GGML_OP_CONCAT:
+            {
+                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
+                op_tensor = ggml_concat(ctx, w, b, 0);
+            } break;
         default:
             GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
     }
@@ -1296,6 +1302,56 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     }
 
     hparams.rope_type = llama_model_rope_type(this);
+
+    // vision model
+    auto & vparams = vit.hparams;
+    std::string vision_type;
+    ml.get_key(LLM_KV_VISION_TYPE, vision_type, false);
+    if (vision_type == "vit") {
+        LLAMA_LOG_INFO("%s: loading ViT vision model\n", __func__);
+        has_vision = true;
+        ml.get_key(LLM_KV_VISION_IMAGE_SIZE, vparams.image_size, true);
+        ml.get_key(LLM_KV_VISION_PATCH_SIZE, vparams.patch_size, true);
+        ml.get_key_or_arr(LLM_KV_VISION_IMAGE_MEAN, vparams.image_mean, 3, true);
+        ml.get_key_or_arr(LLM_KV_VISION_IMAGE_STD,  vparams.image_std,  3, true);
+        ml.get_key(LLM_KV_VISION_VIT_EMBEDDING_LENGTH, vparams.hidden_size, true);
+        ml.get_key(LLM_KV_VISION_VIT_BLOCK_COUNT,      vparams.n_layer, true);
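+        // note: image_mean/image_std above are per-channel RGB triples, which is
+        // why they go through get_key_or_arr with n = 3 rather than plain get_key
+ 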
ml.get_key(LLM_KV_VISION_VIT_FEED_FORWARD_LENGTH, vparams.n_intermediate, true); + ml.get_key(LLM_KV_VISION_VIT_HEAD_COUNT, vparams.n_head, true); + ml.get_key(LLM_KV_VISION_VIT_LAYERNORM_EPS, vparams.eps, true); + ml.get_key(LLM_KV_VISION_VIT_SELECT_LAYER, vparams.select_layer, true); + ml.get_key(LLM_KV_VISION_VIT_MAX_POS_EMBD, vparams.max_pos_embd, true); + ml.get_key(LLM_KV_VISION_VIT_SCALE_FACTOR, vparams.scale_factor, false); + { + std::string name; + ml.get_key(LLM_KV_VISION_VIT_PROJECTOR_TYPE, name, true); + vparams.proj_type = vision_projector_type_from_name(name); + if (vparams.proj_type == VISION_PROJECTOR_TYPE_UNKNOWN) { + throw std::runtime_error(format("unsupported clip projector type: %s", name.c_str())); + } + } + { + std::string name; + ml.get_key(LLM_KV_VISION_VIT_PATCH_MERGE_TYPE, name, false); + vparams.mm_patch_merge_type = mm_patch_merge_from_name(name); + } + { + std::string arch; + ml.get_key(LLM_KV_VISION_VIT_ARCHITECTURE, arch, true); + vparams.arch = llm_arch_from_string(arch); + if (vparams.arch == LLM_ARCH_UNKNOWN) { + throw std::runtime_error(format("unsupported vision arch: %s", arch.c_str())); + } + } + } else if (!vision_type.empty()) { + throw std::runtime_error(format("unsupported vision type: %s", vision_type.c_str())); + } + + // arch-specific CLIP hparams + // switch (vparams.arch) { + // case VISION_ARCH_LLAVA: + // default: (void)0; + // } } void llama_model::load_vocab(llama_model_loader & ml) { @@ -1485,7 +1541,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // sanity checks - if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) { + if (info.layer == LLM_TENSOR_LAYER_PROJECTION) { + // nothing to check + } else if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) { if (tn.bid != -1) { GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str()); } @@ -1507,6 +1565,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_TENSOR_LAYER_REPEATING: buft_list = pimpl->dev_layer.at(tn.bid).buft_list; break; + case LLM_TENSOR_LAYER_PROJECTION: + buft_list = pimpl->dev_layer.back().buft_list; + break; default: GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str()); } @@ -3508,6 +3569,179 @@ bool llama_model::load_tensors(llama_model_loader & ml) { __func__, first_moved_tensor->name, ggml_type_name(first_moved_tensor->type), n_moved_tensors - 1, ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft)); } + + // load tensors for vision model + auto & vparams = vit.hparams; + if (has_vision) { + // language params + const int64_t n_embd = hparams.n_embd; + // vision params + const int64_t n_vlayer = vparams.n_layer; + const int64_t n_vembd = vparams.hidden_size; + const int64_t n_vff = vparams.n_intermediate; + const int64_t max_pos_embd = vparams.max_pos_embd; + const int64_t n_channel = 3; // always RGB + const int64_t patch_size = vparams.patch_size; + const auto tn = LLM_TN(vparams.arch); + + // TODO: vit is cpu only for now + vit.buft = ggml_backend_cpu_buffer_type(); + vit.layers.resize(n_vlayer); + + switch (vparams.arch) { + case LLM_ARCH_VISION_LLAVA: + case LLM_ARCH_VISION_MOBILEVLM: + { + if (vparams.arch == LLM_ARCH_VISION_LLAVA) { + vit.mm_1_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ, "weight", 1), {n_vembd, n_vff}, 0); + vit.mm_1_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ, "bias" , 1), {n_vff}, 0); + vit.mm_2_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ, "weight", 2), {n_vff, n_vff}, 
0); + vit.mm_2_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ, "bias" , 2), {n_vff}, 0); + } else if (vparams.arch == LLM_ARCH_VISION_MOBILEVLM) { + vit.mm_model_mlp_0_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ_MLP, "weight", 0), {n_vembd, n_embd}, 0); + vit.mm_model_mlp_0_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ_MLP, "bias", 0), {n_embd}, 0); + vit.mm_model_mlp_2_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ_MLP, "weight", 2), {n_embd, n_embd}, 0); + vit.mm_model_mlp_2_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ_MLP, "bias", 2), {n_embd}, 0); + vit.mm_model_peg_0_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ_PEG, "weight", 0), {n_channel, n_channel, 1, n_embd}, 0); + vit.mm_model_peg_0_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ_PEG, "bias", 0), {n_embd}, 0); + } + + vit.class_embedding = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_CLS ), {n_vembd}, 0); + vit.patch_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd}, 0); + vit.position_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd}, 0); + + vit.pre_norm_w = create_tensor(tn(LLM_TENSOR_V_PRE_NORM, "weight"), {n_vembd}, 0); + vit.pre_norm_b = create_tensor(tn(LLM_TENSOR_V_PRE_NORM, "bias" ), {n_vembd}, 0); + vit.post_norm_w = create_tensor(tn(LLM_TENSOR_V_POST_NORM, "weight"), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED); + vit.post_norm_b = create_tensor(tn(LLM_TENSOR_V_POST_NORM, "bias" ), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_vlayer; ++i) { + auto & layer = vit.layers[i]; + + layer.k_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "weight", i), {n_vembd, n_vembd}, 0); + layer.k_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "bias" , i), {n_vembd}, 0); + layer.v_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "weight", i), {n_vembd, n_vembd}, 0); + layer.v_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "bias" , i), {n_vembd}, 0); + layer.q_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "weight", i), {n_vembd, n_vembd}, 0); + layer.q_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "bias" , i), {n_vembd}, 0); + + layer.ffn_up_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "weight", i), {n_vembd, n_vff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "bias" , i), {n_vff}, 0); + layer.ffn_down_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "weight", i), {n_vff, n_vembd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "bias" , i), {n_vembd}, 0); + + layer.norm_in_w = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "weight", i), {n_vembd}, 0); + layer.norm_in_b = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "bias" , i), {n_vembd}, 0); + layer.norm_out_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "weight", i), {n_vembd}, 0); + layer.norm_out_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "bias" , i), {n_vembd}, 0); + + layer.output_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "weight", i), {n_vembd, n_vembd}, 0); + layer.output_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "bias" , i), {n_vembd}, 0); + } + } break; + case LLM_ARCH_VISION_MINICPMV: + { + vit.patch_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd}, 0); + vit.patch_bias = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "bias" ), {n_vembd}, 0); + vit.position_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd}, 0); + + // tok embd + vit.mm_tok_embd_image = create_tensor(tn(LLM_TENSOR_V_TOK_EMBD_IMAGE, "weight"), 
{n_embd}, 0); + vit.mm_tok_embd_end_image = create_tensor(tn(LLM_TENSOR_V_TOK_EMBD_END_IMAGE, "weight"), {n_embd}, 0); + vit.mm_tok_embd_slice = create_tensor(tn(LLM_TENSOR_V_TOK_EMBD_SLICE, "weight"), {n_embd}, 0); + vit.mm_tok_embd_end_slice = create_tensor(tn(LLM_TENSOR_V_TOK_EMBD_END_SLICE, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_vlayer; ++i) { + auto & layer = vit.layers[i]; + + layer.k_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "weight", i), {n_vembd, n_vembd}, 0); + layer.k_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "bias" , i), {n_vembd}, 0); + layer.v_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "weight", i), {n_vembd, n_vembd}, 0); + layer.v_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "bias" , i), {n_vembd}, 0); + layer.q_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "weight", i), {n_vembd, n_vembd}, 0); + layer.q_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "bias" , i), {n_vembd}, 0); + + layer.ffn_up_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "weight", i), {n_vembd, n_vff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "bias" , i), {n_vff}, 0); + layer.ffn_down_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "weight", i), {n_vff, n_vembd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "bias" , i), {n_vembd}, 0); + + layer.norm_in_w = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "weight", i), {n_vembd}, 0); + layer.norm_in_b = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "bias" , i), {n_vembd}, 0); + layer.norm_out_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "weight", i), {n_vembd}, 0); + layer.norm_out_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "bias" , i), {n_vembd}, 0); + + layer.output_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "weight", i), {n_vembd, n_vembd}, 0); + layer.output_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "bias" , i), {n_vembd}, 0); + } + + // resampler, we consider it as one layer on top of the encoder + int il = n_vlayer - 1; + int rs_n_embd = llama_vision_n_mmproj_embd(vit); + vit.mm_model_pos_embed_k = create_tensor(tn(LLM_TENSOR_V_RESMPL_POS_EMBD_K, "weight", il), {rs_n_embd, max_pos_embd}, 0); + vit.mm_model_query = create_tensor(tn(LLM_TENSOR_V_RESMPL_QUERY, "weight", il), {rs_n_embd, 64}, 0); // why 64? 
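+        // (64 matches num_query for MiniCPM-V 2.6 in build_minicpmv() below;
+        //  v2.5 uses 96 query tokens, so this shape is presumably version-dependent)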
+ vit.mm_model_proj = create_tensor(tn(LLM_TENSOR_V_RESMPL_PROJ, "weight", il), {rs_n_embd, rs_n_embd}, 0); + vit.mm_model_kv_proj = create_tensor(tn(LLM_TENSOR_V_RESMPL_KV, "weight", il), {n_vembd, rs_n_embd}, 0); + vit.mm_model_attn_q_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_Q, "weight", il), {rs_n_embd, rs_n_embd}, 0); + vit.mm_model_attn_q_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_Q, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_attn_k_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_K, "weight", il), {rs_n_embd, rs_n_embd}, 0); + vit.mm_model_attn_k_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_K, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_attn_v_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_V, "weight", il), {rs_n_embd, rs_n_embd}, 0); + vit.mm_model_attn_v_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_V, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_attn_o_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_OUT, "weight", il), {rs_n_embd, rs_n_embd}, 0); + vit.mm_model_attn_o_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_OUT, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_ln_q_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_Q_NORM, "weight", il), {rs_n_embd}, 0); + vit.mm_model_ln_q_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_Q_NORM, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_ln_kv_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_KV_NORM, "weight", il), {rs_n_embd}, 0); + vit.mm_model_ln_kv_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_KV_NORM, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_ln_post_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_POST_NORM, "weight", il), {rs_n_embd}, 0); + vit.mm_model_ln_post_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_POST_NORM, "bias" , il), {rs_n_embd}, 0); + + } break; + case LLM_ARCH_VISION_IDEFICS3: + { + int scale_factor = vit.hparams.scale_factor; + vit.projection = create_tensor(tn(LLM_TENSOR_V_MMPROJ_FC, "weight"), {n_vembd * scale_factor * scale_factor, n_embd}, 0); + + vit.patch_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd}, 0); + vit.patch_bias = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "bias" ), {n_vembd}, 0); + vit.position_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd}, 0); + + vit.post_norm_w = create_tensor(tn(LLM_TENSOR_V_POST_NORM, "weight"), {n_vembd}, 0); + vit.post_norm_b = create_tensor(tn(LLM_TENSOR_V_POST_NORM, "bias" ), {n_vembd}, 0); + + for (int i = 0; i < n_vlayer; ++i) { + auto & layer = vit.layers[i]; + + layer.k_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "weight", i), {n_vembd, n_vembd}, 0); + layer.k_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "bias" , i), {n_vembd}, 0); + layer.v_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "weight", i), {n_vembd, n_vembd}, 0); + layer.v_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "bias" , i), {n_vembd}, 0); + layer.q_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "weight", i), {n_vembd, n_vembd}, 0); + layer.q_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "bias" , i), {n_vembd}, 0); + + layer.ffn_up_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "weight", i), {n_vembd, n_vff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "bias" , i), {n_vff}, 0); + layer.ffn_down_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "weight", i), {n_vff, n_vembd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "bias" , i), {n_vembd}, 0); + + layer.norm_in_w = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "weight", i), {n_vembd}, 0); + layer.norm_in_b = 
create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "bias" , i), {n_vembd}, 0);
+                        layer.norm_out_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "weight", i), {n_vembd}, 0);
+                        layer.norm_out_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "bias"  , i), {n_vembd}, 0);
+
+                        layer.output_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "weight", i), {n_vembd, n_vembd}, 0);
+                        layer.output_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "bias"  , i), {n_vembd}, 0);
+                    }
+                } break;
+            default:
+                throw std::runtime_error("unknown vision architecture");
+        }
+
+        if (llama_vision_n_mmproj_embd(vit) != hparams.n_embd) {
+            throw std::runtime_error("model has vision, but n_mmproj_embd != n_embd");
+        }
+    }
 }
 
 ml.done_getting_tensors();
@@ -11298,6 +11532,12 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_QWEN2VL:
             return LLAMA_ROPE_TYPE_MROPE;
 
+        case LLM_ARCH_VISION_LLAVA:
+        case LLM_ARCH_VISION_MOBILEVLM:
+        case LLM_ARCH_VISION_MINICPMV:
+        case LLM_ARCH_VISION_IDEFICS3:
+            GGML_ABORT("vision arch does not use RoPE");
+
         // all model arches should be listed explicitly here
         case LLM_ARCH_UNKNOWN:
             GGML_ABORT("unknown architecture");
diff --git a/src/llama-model.h b/src/llama-model.h
index 55c26a92b02d2..34b56a6f07699 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -6,6 +6,7 @@
 #include "llama-hparams.h"
 #include "llama-memory.h"
 #include "llama-vocab.h"
+#include "llama-vision.h"
 
 #include <memory>
 #include <string>
@@ -366,6 +367,10 @@ struct llama_model {
 
     const struct ggml_tensor * get_tensor(const char * name) const;
 
+    // vision
+    bool has_vision = false;
+    llama_vision_model vit;
+
     // TODO: move this to new llm_arch_model_i interface
     llama_memory_i * create_memory() const; // TODO: params
diff --git a/src/llama-vision.cpp b/src/llama-vision.cpp
new file mode 100644
index 0000000000000..f961acc81d079
--- /dev/null
+++ b/src/llama-vision.cpp
@@ -0,0 +1,1343 @@
+#include "llama.h"
+#include "llama-vision.h"
+#include "llama-impl.h"
+#include "llama-context.h"
+
+#include <cstring> // memcpy
+#include <cmath>
+#include <limits>
+
+#ifndef NDEBUG
+// for debugging
+#include <cstdint>
+#include <fstream>
+#include <vector>
+
+// export llama_image_u8 to bmp file for debugging
+// https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c
+static int bmp_export(const struct llama_image_u8 &img, const std::string &location);
+#endif
+
+struct img_size {
+    int width;
+    int height;
+    img_size(int w, int h) : width(w), height(h) {}
+};
+
+// RGB uint8 image
+// Memory layout: RGBRGBRGB...
+struct llama_image_u8 {
+    int nx;
+    int ny;
+    std::vector<uint8_t> buf;
+    llama_image_u8() {}
+    llama_image_u8(const llama_vision_bitmap & bmp) {
+        nx = bmp.nx;
+        ny = bmp.ny;
+        buf.resize(nx*ny*3);
+        memcpy(buf.data(), bmp.data, buf.size());
+    }
+};
+
+uint32_t llama_vision_n_mmproj_embd(const llama_vision_model & vmodel) {
+    auto & proj_type = vmodel.hparams.proj_type;
+    if (proj_type == VISION_PROJECTOR_TYPE_MLP) {
+        return vmodel.mm_2_b
+            ? vmodel.mm_2_b->ne[0]
+            : vmodel.projection->ne[1]; // idefics3
+    } else if (proj_type == VISION_PROJECTOR_TYPE_LDPV2) {
+        return vmodel.mm_model_peg_0_b->ne[0];
+    } else if (proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_5) {
+        return 4096; // resampler
+    } else if (proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_6) {
+        return 3584; // resampler
+    } else {
+        GGML_ASSERT(false && "invalid proj type");
+    }
+}
+
+
+//
+// internal utils
+//
+
+static int get_n_patches_x(const llama_vision_context & ctx) {
+    auto & hparams = ctx.model->hparams;
+    return hparams.image_size / hparams.patch_size;
+}
+
+static int get_n_patches_y(const llama_vision_context & ctx) {
+    return get_n_patches_x(ctx);
+}
+
+static int get_n_patches(const llama_vision_context & ctx) {
+    return get_n_patches_x(ctx) * get_n_patches_y(ctx);
+}
+
+//
+// bitmap utils
+//
+
+/**
+ * Selects the best resolution from a list of possible resolutions based on the original size.
+ *
+ * @param original_size The original size of the image in the format (width, height).
+ * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
+ * @return The best fit resolution in the format (width, height).
+ */
+static img_size select_best_resolution(const img_size & original_size, const std::vector<img_size> & possible_resolutions) {
+    int original_width  = original_size.width;
+    int original_height = original_size.height;
+
+    img_size best_fit(0, 0);
+    int max_effective_resolution = 0;
+    int min_wasted_resolution = std::numeric_limits<int>::max();
+
+    for (const auto & resolution : possible_resolutions) {
+        int width  = resolution.width;
+        int height = resolution.height;
+        float scale = std::min(static_cast<float>(width) / original_width, static_cast<float>(height) / original_height);
+        int downscaled_width  = static_cast<int>(original_width  * scale);
+        int downscaled_height = static_cast<int>(original_height * scale);
+        int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height);
+        int wasted_resolution = (width * height) - effective_resolution;
+        // LOG_DBG("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution);
+        if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) {
+            max_effective_resolution = effective_resolution;
+            min_wasted_resolution = wasted_resolution;
+            best_fit = resolution;
+        }
+    }
+
+    return best_fit;
+}
+
+static bool bicubic_resize(const llama_image_u8 & img, llama_image_u8 & dst, int target_width, int target_height) {
+    auto clip = [](int x, int lower, int upper) -> int {
+        return std::max(lower, std::min(x, upper));
+    };
+
+    const int nx = img.nx;
+    const int ny = img.ny;
+
+    dst.nx = target_width;
+    dst.ny = target_height;
+    dst.buf.resize(3 * target_width * target_height);
+
+    float Cc;
+    float C[5];
+    float d0, d2, d3, a0, a1, a2, a3;
+    int i, j, k, jj;
+    int x, y;
+    float dx, dy;
+    float tx, ty;
+
+    tx = (float)nx / (float)target_width;
+    ty = (float)ny / (float)target_height;
+
+    // Bicubic interpolation; adapted from ViT.cpp, inspired from :
+    //  -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36
+    //  -> https://en.wikipedia.org/wiki/Bicubic_interpolation
+
+    for (i = 0; i < target_height; i++) {
+        for (j = 0; j < target_width; j++) {
+            x = (int)(tx * j);
+            y = (int)(ty * i);
+
+            dx = tx * j - x;
+            dy = ty * i - y;
+
+            for (k = 0; k < 3; k++) {
+                for (jj = 0; jj <= 3; jj++) {
+                    d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                    d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                    d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                    a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+
+                    a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
+                    a2 =  1.0 / 2 * d0 + 1.0 / 2 * d2;
+                    a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3;
+
+                    C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx;
+
+                    d0 = C[0] - C[1];
+                    d2 = C[2] - C[1];
+                    d3 = C[3] - C[1];
+                    a0 = C[1];
+                    a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
+                    a2 =  1.0 / 2 * d0 + 1.0 / 2 * d2;
+                    a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3;
+                    Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy;
+
+                    const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f);
+                    dst.buf[(i * target_width + j) * 3 + k] = float(Cc2);
+                }
+            }
+        }
+    }
+
+    return true;
+}
+
+static std::vector<llama_image_u8> divide_to_patches_u8(const llama_image_u8 & image, int patch_size) {
+    std::vector<llama_image_u8> patches;
+    int width  = image.nx;
+    int height = image.ny;
+    for (int i = 0; i < height; i += patch_size) {
+        for (int j = 0; j < width; j += patch_size) {
+            llama_image_u8 patch;
+            patch.nx = std::min(patch_size, width  - j);
+            patch.ny = std::min(patch_size, height - i);
+            patch.buf.resize(3 * patch.nx * patch.ny);
+            for (int y = 0; y < patch.ny; ++y) {
+                for (int x = 0; x < patch.nx; ++x) {
+                    for (int c = 0; c < 3; ++c) {
+                        patch.buf[3 * (y * patch.nx + x) + c] = image.buf[3 * ((i + y) * width + (j + x)) + c];
+                    }
+                }
+            }
+            patches.push_back(patch);
+        }
+    }
+    return patches;
+}
+
+// llava-1.6 type of resize_and_pad (black)
+static llama_image_u8 resize_and_pad_image(const llama_image_u8 & image, const img_size & target_resolution) {
+    int target_width  = target_resolution.width;
+    int target_height = target_resolution.height;
+
+    float scale_w = static_cast<float>(target_width)  / image.nx;
+    float scale_h = static_cast<float>(target_height) / image.ny;
+
+    int new_width, new_height;
+
+    if (scale_w < scale_h) {
+        new_width  = target_width;
+        new_height = std::min(static_cast<int>(std::ceil(image.ny * scale_w)), target_height);
+    } else {
+        new_height = target_height;
+        new_width  = std::min(static_cast<int>(std::ceil(image.nx * scale_h)), target_width);
+    }
+
+    llama_image_u8 resized_image;
+    // bilinear_resize(image, resized_image, new_width, new_height);
+    bicubic_resize(image, resized_image, new_width, new_height);
+
+    llama_image_u8 padded_image;
+    padded_image.nx = target_width;
+    padded_image.ny = target_height;
+    padded_image.buf.resize(3 * target_width * target_height, 0); // Initialize with black
+
+    // Calculate padding offsets
+    int pad_x = (target_width  - new_width)  / 2;
+    int pad_y = (target_height - new_height) / 2;
+
+    // Copy the resized image into the center of the padded buffer
+    for (int y = 0; y < new_height; ++y) {
+        for (int x = 0; x < new_width; ++x) {
+            for (int c = 0; c < 3; ++c) {
+                padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c];
+            }
+        }
+    }
+    return padded_image;
+}
+
+static void normalize_image_u8_to_f32(const llama_image_u8 & src, std::vector<float> & dst, const std::array<float, 3> & mean, const std::array<float, 3> & std) {
+    dst.resize(src.buf.size());
+
+    for (size_t i = 0; i < src.buf.size(); ++i) {
+        int c = i % 3; // rgb
+        dst[i] = (static_cast<float>(src.buf[i]) / 255.0f - mean[c]) / std[c];
+    }
+}
+
+
+//
+// processor
+//
+
+struct llama_vision_processor {
+    const llama_vision_context & ctx;
+    llama_vision_processor(const llama_vision_context & ctx) : ctx(ctx) {}
+    virtual llama_vision_tokens tokenize(const llama_image_u8 & img) = 0;
+    virtual ~llama_vision_processor() = default;
+};
+
+// inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py
+struct llama_vision_processor_llava : llama_vision_processor {
+    llama_vision_processor_llava(const llama_vision_context & ctx) : llama_vision_processor(ctx) {}
+
+    virtual llama_vision_tokens tokenize(const llama_image_u8 & img) override {
+        bool pad_to_square = true;
+        auto & params = ctx.model->hparams;
+        // The model config actually contains all we need to decide on how to preprocess; here we automatically switch to the new llava-1.6 preprocessing
+        if (params.mm_patch_merge_type == MM_PATCH_MERGE_SPATIAL_UNPAD) {
+            pad_to_square = false;
+        }
+
+        llama_vision_tokens output_slices;
+        output_slices.n_px = get_n_patches_x(ctx);
+        output_slices.n_py = get_n_patches_y(ctx);
+        output_slices.px = params.patch_size;
+        output_slices.py = params.patch_size;
+
+        // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
+        // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
+
+        llama_image_u8 temp;
+        if (pad_to_square && img.nx != img.ny) {
+            // if the image is not square, pad it to a square
+            int longer_side = std::max(img.nx, img.ny);
+            temp.nx = longer_side;
+            temp.ny = longer_side;
+            temp.buf.resize(3 * longer_side * longer_side);
+            const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255)
+
+            // fill with background color
+            for (size_t i = 0; i < temp.buf.size(); i++) {
+                temp.buf[i] = bc[i % 3];
+            }
+
+            // copy from the input image
+            for (int y = 0; y < img.ny; y++) {
+                for (int x = 0; x < img.nx; x++) {
+                    const int i = 3 * (y * img.nx + x);
+                    const int j = 3 * (y * temp.nx + x);
+                    temp.buf[j]   = img.buf[i];
+                    temp.buf[j+1] = img.buf[i+1];
+                    temp.buf[j+2] = img.buf[i+2];
+                }
+            }
+        } else if (params.image_grid_pinpoints[0] != 0) {
+            // "spatial_unpad" with "anyres" processing for llava-1.6
+            std::vector<img_size> possible_resolutions;
+            for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i += 2) {
+                img_size s(0, 0);
+                s.width  = params.image_grid_pinpoints[i];
+                s.height = params.image_grid_pinpoints[i+1];
+                possible_resolutions.push_back(s);
+            }
+            img_size best_resolution = select_best_resolution(img_size(img.nx, img.ny), possible_resolutions);
+            // debug_image_save_to_bmp(*img, "input.bmp");
+            temp = resize_and_pad_image(img, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6
+            // debug_image_save_to_bmp(*temp, "resized.bmp");
+
+            std::vector<llama_image_u8> patches = divide_to_patches_u8(temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6)
+
+            llama_image_u8 image_original_resize;
+            // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
+            bicubic_resize(img, image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
+            patches.insert(patches.begin(), image_original_resize);
+            output_slices.buf.resize(patches.size());
+            int num = 0;
+            for (auto & patch : patches) {
+                normalize_image_u8_to_f32(patch, output_slices.buf[num], params.image_mean, params.image_std);
+                num++;
+            }
+            return output_slices;
+        } else {
+            temp.nx = img.nx;
+            temp.ny = img.ny;
+            temp.buf.resize(img.buf.size());
+            memcpy(temp.buf.data(), img.buf.data(), temp.buf.size());
+        }
+
+        const int nx = temp.nx;
+        const int ny = temp.ny;
+        // bmp_export(temp, "resized_vanilla.bmp");
+
+        const int nx2 = params.image_size;
+        const int ny2 = params.image_size;
+        std::vector<float> res;
+        res.resize(3 * nx2 * ny2);
+
+        const float scale = std::max(nx, ny) / (float)params.image_size;
+
+        const int nx3 = int(nx / scale + 0.5f);
+        const int ny3 = int(ny / scale + 0.5f);
+
+        const auto & m3 = params.image_mean; // {0.48145466f, 0.4578275f, 0.40821073f};
+        const auto & s3 = params.image_std;  // {0.26862954f, 0.26130258f, 0.27577711f};
+
+        for (int y = 0; y < ny3; y++) {
+            for (int x = 0; x < nx3; x++) {
+                for (int c = 0; c < 3; c++) {
+                    // linear interpolation
+                    const float sx = (x + 0.5f) * scale - 0.5f;
+                    const float sy = (y + 0.5f) * scale - 0.5f;
+
+                    const int x0 = std::max(0, (int)std::floor(sx));
+                    const int y0 = std::max(0, (int)std::floor(sy));
+
+                    const int x1 = std::min(x0 + 1, nx - 1);
+                    const int y1 = std::min(y0 + 1, ny - 1);
+
+                    const float dx = sx - x0;
+                    const float dy = sy - y0;
+
+                    const int j00 = 3 * (y0 * nx + x0) + c;
+                    const int j01 = 3 * (y0 * nx + x1) + c;
+                    const int j10 = 3 * (y1 * nx + x0) + c;
+                    const int j11 = 3 * (y1 * nx + x1) + c;
+
+                    const float v00 = temp.buf[j00];
+                    const float v01 = temp.buf[j01];
+                    const float v10 = temp.buf[j10];
+                    const float v11 = temp.buf[j11];
+
+                    const float v0 = v00 * (1.0f - dx) + v01 * dx;
+                    const float v1 = v10 * (1.0f - dx) + v11 * dx;
+
+                    const float v = v0 * (1.0f - dy) + v1 * dy;
+
+                    const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f);
+
+                    const int i = 3 * (y * nx3 + x) + c;
+
+                    res[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c];
+                }
+            }
+        }
+
+        output_slices.buf.resize(1);
+        output_slices.buf[0] = std::move(res);
+
+        return output_slices;
+    }
+};
+
+struct llama_vision_processor_uhd : llama_vision_processor {
+    llama_vision_processor_uhd(const llama_vision_context & ctx) : llama_vision_processor(ctx) {}
+
+    int ensure_divide(int length, int patch_size) {
+        return std::max(static_cast<int>(std::round(static_cast<float>(length) / patch_size) * patch_size), patch_size);
+    }
+
+    img_size find_best_resize(const img_size & original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
+        int width  = original_size.width;
+        int height = original_size.height;
+        if ((width * height > scale_resolution * scale_resolution) || allow_upscale) {
+            float r = static_cast<float>(width) / height;
+            height  = static_cast<int>(scale_resolution / std::sqrt(r));
+            width   = static_cast<int>(height * r);
+        }
+        int best_width  = ensure_divide(width,  patch_size);
+        int best_height = ensure_divide(height, patch_size);
+        return img_size(best_width, best_height);
+    }
+
+    img_size get_refine_size(const img_size & original_size, const img_size & grid, int scale_resolution, int patch_size, bool allow_upscale = false) {
+        int width  = original_size.width;
+        int height = original_size.height;
+        int grid_x = grid.width;
+        int grid_y = grid.height;
+
+        int refine_width  = ensure_divide(width,  grid_x);
+        int refine_height = ensure_divide(height, grid_y);
+
+        int grid_width  = refine_width  / grid_x;
+        int grid_height = refine_height / grid_y;
+
+        // auto best_grid_size = find_best_resize(std::make_tuple(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); (old line)
+        auto best_grid = find_best_resize({grid_width, grid_height}, scale_resolution, patch_size, allow_upscale); // (new line) uses img_size instead of the std::tuple from the Python port
+
+        // img_size refine_size = std::make_tuple(best_grid_width * grid_x, best_grid_height * grid_y); (old line)
+        img_size refine_size = img_size(best_grid.width * grid_x, best_grid.height * grid_y); // (new line)
+        return refine_size;
+    }
+
+    img_size find_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) {
+        std::vector<int> candidate_split_grids_nums;
+        for (int i : {multiple - 1, multiple, multiple + 1}) {
+            if (i == 1 || i > max_slice_nums) {
+                continue;
+            }
+            candidate_split_grids_nums.push_back(i);
+        }
+
+        std::vector<img_size> candidate_grids;
+        for (int split_grids_nums : candidate_split_grids_nums) {
+            int m = 1;
+            while (m <= split_grids_nums) {
+                if (split_grids_nums % m == 0) {
+                    candidate_grids.emplace_back(m, split_grids_nums / m);
+                }
+                ++m;
+            }
+        }
+
+        img_size best_grid = img_size(1, 1);
+        float min_error = std::numeric_limits<float>::infinity();
+        for (const auto & grid : candidate_grids) {
+            float error = std::abs(log_ratio - std::log(1.0 * grid.width / grid.height));
+            if (error < min_error) {
+                best_grid = grid;
+                min_error = error;
+            }
+        }
+        return best_grid;
+    }
+
+    std::vector<std::vector<llama_image_u8>> slice_image(
+            const llama_image_u8 & img,
+            const int max_slice_nums = 9,
+            const int scale_resolution = 448,
+            const int patch_size = 14) {
+        const img_size original_size = img_size(img.nx, img.ny);
+        const int original_width  = img.nx;
+        const int original_height = img.ny;
+        const float log_ratio = log(1.0*original_width/original_height);
+        const float ratio = 1.0 * original_width * original_height / (scale_resolution * scale_resolution);
+        const int multiple = fmin(ceil(ratio), max_slice_nums);
+
+        std::vector<std::vector<llama_image_u8>> images;
+        LLAMA_LOG_DEBUG("%s: multiple %d\n", __func__, multiple);
+        images.push_back(std::vector<llama_image_u8>());
+
+        if (multiple <= 1) {
+            auto best_size = find_best_resize(original_size, scale_resolution, patch_size, true);
+            llama_image_u8 source_image;
+            bicubic_resize(img, source_image, best_size.width, best_size.height);
+            // source_image = image.resize(best_size, Image.Resampling.BICUBIC)
+            images.back().push_back(source_image);
+        } else if (multiple > 1) {
+            auto best_size = find_best_resize(original_size, scale_resolution, patch_size);
+            llama_image_u8 source_image;
+            bicubic_resize(img, source_image, best_size.width, best_size.height);
+            // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
+            LLAMA_LOG_DEBUG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img.nx, img.ny, best_size.width, best_size.height);
+            images.back().push_back(source_image);
+
+            img_size best_grid = find_best_grid(max_slice_nums, multiple, log_ratio);
+            LLAMA_LOG_DEBUG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img.nx, img.ny, best_grid.width, best_grid.height);
+
+            auto refine_size = get_refine_size(original_size, best_grid, scale_resolution, patch_size, true);
+            llama_image_u8 refine_image;
+            // TODO: so far, we spend most of the time in bicubic_resize, we should optimize it
+            bicubic_resize(img, refine_image, refine_size.width, refine_size.height);
+
+            LLAMA_LOG_DEBUG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image.nx, refine_image.ny, refine_size.width, refine_size.height);
+
+            // split_to_patches
+            int width  = refine_image.nx;
+            int height = refine_image.ny;
+            int grid_x = int(width  / best_grid.width);
+            int grid_y = int(height / best_grid.height);
+            for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.height; patches_i += grid_y, ic += 1) {
+                // one row of the grid per iteration; the row is pushed once at the end
+                std::vector<llama_image_u8> patches_out;
+                for (int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.width; patches_j += grid_x, jc += 1) {
+                    llama_image_u8 patch;
+                    patch.nx = grid_x;
+                    patch.ny = grid_y;
+                    patch.buf.resize(3 * patch.nx * patch.ny);
+                    for (int y = patches_i; y < patches_i + grid_y; ++y) {
+                        for (int x = patches_j; x < patches_j + grid_x; ++x) {
+                            const int i = 3 * (y * refine_image.nx + x);
+                            const int j = 3 * ((y-patches_i) * patch.nx + (x-patches_j));
+                            patch.buf[j]   = refine_image.buf[i];
+                            patch.buf[j+1] = refine_image.buf[i+1];
+                            patch.buf[j+2] = refine_image.buf[i+2];
+                        }
+                    }
+                    patches_out.push_back(std::move(patch));
+                }
+                images.push_back(std::move(patches_out));
+            }
+        }
+        return images;
+    }
+
+    virtual llama_vision_tokens tokenize(const llama_image_u8 & img) override {
+        auto & params = ctx.model->hparams;
+
+        std::vector<std::vector<llama_image_u8>> imgs = slice_image(img);
+
+        llama_vision_tokens output;
+        output.n_px = get_n_patches_x(ctx);
+        output.n_py = get_n_patches_y(ctx);
+        output.px = params.patch_size;
+        output.py = params.patch_size;
+
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            for (size_t j = 0; j < imgs[i].size(); ++j) {
+                std::vector<float> res;
+                normalize_image_u8_to_f32(imgs[i][j], res, params.image_mean, params.image_std);
+                output.buf.push_back(res);
+            }
+        }
+
+        return output;
+    }
+};
+
+//
+// cgraph builder
+//
+
+// TODO: move this to llm_build_context in llama.cpp
+struct llama_vision_graph_builder {
+    llama_vision_context & ctx;
+    const llama_vision_model & model;
+    struct ggml_context * ctx0;
+    int batch_size;
+    int hidden_size;
+    int n_head;
+    int d_head;
+    int patch_size;
+    float eps;
+    int num_patches;
+    int num_positions;
+    int img_w;
+    int img_h;
+    bool use_gelu;
+    int n_layers;
+    int rs_n_embd;
+    vision_projector_type proj_type;
+
+    llama_vision_graph_builder(llama_vision_context & ctx, const llama_vision_tokens & inp) : ctx(ctx), model(*ctx.model) {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ ctx.buf_compute_meta.size(),
+            /*.mem_buffer =*/ ctx.buf_compute_meta.data(),
+            /*.no_alloc   =*/ true,
+        };
+        ctx0 = ggml_init(params);
+
+        auto & hparams = ctx.model->hparams;
+
+        batch_size  = inp.buf.size();
+        hidden_size = hparams.hidden_size;
+        n_head      = hparams.n_head;
+        d_head      = hidden_size / n_head;
+        patch_size  = hparams.patch_size;
+        eps         = hparams.eps;
+        num_patches = inp.n_px * inp.n_py;
+        num_positions = num_patches + (model.class_embedding ? 
1 : 0); + img_w = inp.px * inp.n_px; + img_h = inp.py * inp.n_py; + use_gelu = hparams.use_gelu; + n_layers = (int)hparams.n_layer + hparams.select_layer; + proj_type = hparams.proj_type; + } + + ~llama_vision_graph_builder() { + ggml_free(ctx0); + } + + struct ggml_tensor * build_inp() { + struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, img_w, img_h, 3, batch_size); + ggml_set_name(inp_raw, "inp_raw"); + ggml_set_input(inp_raw); + + struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + + inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size); + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); + + if (model.patch_bias) { + inp = ggml_add(ctx0, inp, model.patch_bias); + } + // auto * ne = inp->ne; printf("%d %d %d %d\n", ne[0], ne[1], ne[2], ne[3]); + + struct ggml_tensor * embd = inp; + if (model.class_embedding) { + embd = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size); + ggml_set_name(embd, "inp_embd"); + ggml_set_input(embd); + + embd = ggml_acc(ctx0, embd, model.class_embedding, + embd->nb[1], embd->nb[2], embd->nb[3], 0); + embd = ggml_acc(ctx0, embd, inp, + embd->nb[1], embd->nb[2], embd->nb[3], model.class_embedding->nb[1]); + } + + struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions); + ggml_set_name(positions, "inp_pos"); + ggml_set_input(positions); + + embd = ggml_add(ctx0, + embd, + ggml_get_rows(ctx0, model.position_embeddings, positions)); + + return embd; + } + + struct ggml_tensor * build_pre_norm(struct ggml_tensor * cur) { + if (model.pre_norm_w) { + cur = ggml_norm(ctx0, cur, eps); + ggml_set_name(cur, "pre_ln"); + + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.pre_norm_w), model.pre_norm_b); + } + return cur; + } + + struct ggml_tensor * build_post_norm(struct ggml_tensor * cur) { + if (model.post_norm_w) { + cur = ggml_norm(ctx0, cur, eps); + ggml_set_name(cur, "post_ln"); + + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.post_norm_w), model.post_norm_b); + } + return cur; + } + + struct ggml_tensor * build_layer(struct ggml_tensor * inpL, int il) { + struct ggml_tensor * cur = inpL; + + // layernorm1 + { + cur = ggml_norm(ctx0, cur, eps); + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.layers[il].norm_in_w), + model.layers[il].norm_in_b); + } + + // self-attention + { + struct ggml_tensor * Q = ggml_add(ctx0, + ggml_mul_mat(ctx0, model.layers[il].q_w, cur), + model.layers[il].q_b); + + Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); + Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size); + Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); + Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size); + + struct ggml_tensor * K = ggml_add(ctx0, + ggml_mul_mat(ctx0, model.layers[il].k_w, cur), + model.layers[il].k_b); + + K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); + K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); + K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); + + struct ggml_tensor * V = ggml_add(ctx0, + ggml_mul_mat(ctx0, model.layers[il].v_w, cur), + model.layers[il].v_b); + + V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); + V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); + V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); + + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + KQ = 
ggml_soft_max_inplace(ctx0, KQ); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size); + } + + // attention output + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].output_w, cur), model.layers[il].output_b); + + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, inpL); + + inpL = cur; // inpL = residual, cur = hidden_states + + // layernorm2 + { + cur = ggml_norm(ctx0, cur, eps); + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.layers[il].norm_out_w), + model.layers[il].norm_out_b); + } + + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_up_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ffn_up_b); + + if (use_gelu) { + cur = ggml_gelu_inplace(ctx0, cur); + } else { + cur = ggml_gelu_quick_inplace(ctx0, cur); + } + + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ffn_down_b); + + // residual 2 + cur = ggml_add(ctx0, inpL, cur); + + return cur; + } + + struct ggml_tensor * build_vit() { + struct ggml_tensor * cur = build_inp(); + cur = build_pre_norm(cur); + for (int il = 0; il < n_layers; il++) { + cur = build_layer(cur, il); + } + cur = build_post_norm(cur); + return cur; + } + + // graph for each vision arch + + struct ggml_cgraph * build_llava() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, VISION_GRAPH_MAX_NODE, false); + struct ggml_tensor * cur = build_vit(); + + // llava projector + { + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1]); + + struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches); + ggml_set_name(patches, "inp_patches"); + ggml_set_input(patches); + + // shape [1, 576, 1024] + // ne is whcn, ne = [1024, 576, 1, 1] + cur = ggml_get_rows(ctx0, cur, patches); + + if (proj_type == VISION_PROJECTOR_TYPE_MLP) { + cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); + cur = ggml_add(ctx0, cur, model.mm_1_b); + + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); + cur = ggml_add(ctx0, cur, model.mm_2_b); + + } else if (proj_type == VISION_PROJECTOR_TYPE_LDPV2) { + int n_patch = 24; + struct ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, cur); + mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b); + mlp_0 = ggml_gelu(ctx0, mlp_0); + struct ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0); + mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b); + // mlp_2 ne = [2048, 576, 1, 1] + // // AVG Pool Layer 2*2, strides = 2 + mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3)); + // mlp_2 ne = [576, 2048, 1, 1] + mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]); + // mlp_2 ne [24, 24, 2048, 1] + mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0); + // weight ne = [3, 3, 2048, 1] + struct ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1); + peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3)); + peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b); + mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3)); + peg_0 = ggml_add(ctx0, peg_0, mlp_2); + peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]); + cur = ggml_cont(ctx0, peg_0); + + } else { + GGML_ASSERT(false && "unsupported proj type"); + } + } + + ggml_set_name(cur, 
"output"); + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_minicpmv() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, VISION_GRAPH_MAX_NODE, false); + struct ggml_tensor * cur = build_vit(); + + // minicpmv resampler projector + { + int hidden_size = llama_vision_n_mmproj_embd(*ctx.model); + struct ggml_tensor * q = model.mm_model_query; + // layernorm + { + q = ggml_norm(ctx0, q, eps); + q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b); + } + + struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, cur); + // layernorm + { + v = ggml_norm(ctx0, v, eps); + v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b); + } + + // position + struct ggml_tensor * k = ggml_add(ctx0, v, model.mm_model_pos_embed_k); + + // attention + { + const int d_head = 128; + int n_head = hidden_size/d_head; + int num_query = -1; + if (model.hparams.proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_5) { + num_query = 96; + } else if (model.hparams.proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_6) { + num_query = 64; + } + + struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); + Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); + struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b); + struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b); + // permute + Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size); + Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); // TODO: do this when converting the model + Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size); + K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); + K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); // TODO: do this when converting the model + K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); + V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); + V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); // TODO: do this when converting the model + V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + KQ = ggml_soft_max_inplace(ctx0, KQ); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); // TODO: do this when converting the model + KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size); + + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b); + } + // layernorm + { + cur = ggml_norm(ctx0, cur, eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.mm_model_ln_post_w), model.mm_model_ln_post_b); + } + cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur); + } + + // add and token embeddings + cur = ggml_concat(ctx0, model.mm_tok_embd_image, cur, 1); + cur = ggml_concat(ctx0, cur, model.mm_tok_embd_end_image, 1); + + ggml_set_name(cur, "output"); + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_idefics3() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, VISION_GRAPH_MAX_NODE, false); + struct ggml_tensor * cur = build_vit(); + + // 
+    struct ggml_cgraph * build_idefics3() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, VISION_GRAPH_MAX_NODE, false);
+        struct ggml_tensor * cur = build_vit();
+
+        // pixel shuffle (space-to-depth), see:
+        // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
+        {
+            const int scale_factor = model.hparams.scale_factor;
+            const int n_embd = cur->ne[0];
+            const int seq    = cur->ne[1];
+            const int bsz    = 1; // batch size, always 1 for now since we don't support batching
+            // seq is assumed to be a perfect square (a square grid of patches)
+            const int height = std::sqrt(seq);
+            const int width  = std::sqrt(seq);
+            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                height / scale_factor,
+                width / scale_factor,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                seq / (scale_factor * scale_factor),
+                bsz);
+
+            cur = ggml_mul_mat(ctx0, model.projection, cur);
+        }
+
+        ggml_set_name(cur, "output");
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+};
+
+static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_vision_tokens & inp) {
+    int batch_size = inp.buf.size();
+    auto & model   = *ctx.model;
+    auto & hparams = ctx.model->hparams;
+
+    if (hparams.arch == LLM_ARCH_VISION_LLAVA) {
+        GGML_ASSERT(batch_size == 1); // TODO: support multiple images
+    }
+
+    img_size image_size = img_size((int)hparams.image_size, (int)hparams.image_size);
+    const int patch_size    = hparams.patch_size;
+    const int num_patches   = ((image_size.width / patch_size) * (image_size.height / patch_size));
+    const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
+
+    LLAMA_LOG_DEBUG("%s: image_size = %d\n", __func__, hparams.image_size);
+    LLAMA_LOG_DEBUG("%s: num_positions = %d\n", __func__, num_positions);
+
+    // build the inference graph
+    llama_vision_graph_builder builder(ctx, inp);
+    ggml_cgraph * gf;
+    switch (hparams.arch) {
+        case LLM_ARCH_VISION_LLAVA:
+        case LLM_ARCH_VISION_MOBILEVLM:
+            gf = builder.build_llava();
+            break;
+        case LLM_ARCH_VISION_MINICPMV:
+            gf = builder.build_minicpmv();
+            break;
+        case LLM_ARCH_VISION_IDEFICS3:
+            gf = builder.build_idefics3();
+            break;
+        default:
+            GGML_ASSERT(false && "unsupported vision arch");
+    }
+
+    // alloc memory for graph
+    bool ok = ggml_backend_sched_alloc_graph(ctx.sched.get(), gf);
+    if (!ok) {
+        LLAMA_LOG_ERROR("%s: failed to alloc memory for graph\n", __func__);
+        return -1;
+    }
+
+    // set raw input: convert each image from packed RGBRGB... bytes to the
+    // planar [nx, ny, channel, batch] layout expected by inp_raw
+    {
+        struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
+        std::vector<float> inp_buf(ggml_nelements(inp_raw));
+
+        const int nx = inp.px * inp.n_px;
+        const int ny = inp.py * inp.n_py;
+        const int n  = nx * ny;
+
+        for (int b = 0; b < batch_size; b++) {
+            for (int k = 0; k < 3; k++) {
+                for (int y = 0; y < ny; y++) {
+                    for (int x = 0; x < nx; x++) {
+                        inp_buf[(b * 3 * n) + k * n + y * nx + x] = inp.buf[b][3 * (y * nx + x) + k];
+                    }
+                }
+            }
+        }
+        ggml_backend_tensor_set(inp_raw, inp_buf.data(), 0, ggml_nbytes(inp_raw));
+    }
+
+    if (model.class_embedding) {
+        struct ggml_tensor * inp_embd = ggml_graph_get_tensor(gf, "inp_embd");
+        ggml_set_zero(inp_embd);
+    }
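+    // set the position indices used by the position embeddings:
+    // - minicpmv buckets the 2D patch coordinates into a 70x70 grid
+    //   (siglip-style), encoding each position as row_bucket*70 + col_bucket
+    // - every other arch uses plain sequential positions 0..num_positions-1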
ggml_graph_get_tensor(gf, "inp_pos"); + std::vector buf(ggml_nelements(positions)); + GGML_ASSERT(num_positions == (int)buf.size()); + + int bucket_coords_h[70]; + int bucket_coords_w[70]; + size_t h = inp.py; + size_t w = inp.py; + for (size_t i = 0; i < h; i++) { + bucket_coords_h[i] = std::floor(70.0*i/h); + } + for (size_t i = 0; i < w; i++) { + bucket_coords_w[i] = std::floor(70.0*i/w); + } + for (size_t i = 0, id = 0; i < h; i++){ + for (size_t j = 0; j < w; j++){ + buf[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j]; + } + } + ggml_backend_tensor_set(positions, buf.data(), 0, ggml_nbytes(positions)); + + } else { + struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "inp_pos"); + std::vector pos_buf(ggml_nelements(positions)); + GGML_ASSERT(num_positions == (int)pos_buf.size()); + for (int i = 0; i < num_positions; i++) { + pos_buf[i] = i; + } + ggml_backend_tensor_set(positions, pos_buf.data(), 0, ggml_nbytes(positions)); + } + + struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "inp_patches"); + if (patches) { + std::vector patches_buf(ggml_nelements(patches)); + GGML_ASSERT(num_patches == (int)patches_buf.size()); + for (int i = 0; i < num_patches; i++) { + patches_buf[i] = i + 1; + } + ggml_backend_tensor_set(patches, patches_buf.data(), 0, ggml_nbytes(patches)); + } + + // compute + LLAMA_LOG_DEBUG("%s: compute start\n", __func__); + int64_t t_start = ggml_time_ms(); + ggml_backend_sched_graph_compute(ctx.sched.get(), gf); + + // the last node is the embedding tensor + struct ggml_tensor * output_node = ggml_graph_node(gf, -1); + //LLAMA_LOG_INFO("%s: output tensor shape = %lld %lld %lld %lld\n", __func__, output->ne[0], output->ne[1], output->ne[2], output->ne[3]); + LLAMA_LOG_DEBUG("%s: compute time = %lld ms\n", __func__, ggml_time_ms() - t_start); + + // copy output node to context + if (ctx.ctx_ggml) { + ggml_free(ctx.ctx_ggml); + } + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ctx.ctx_ggml = ggml_init(params); + ctx.output = ggml_dup_tensor(ctx.ctx_ggml, output_node); + ggml_backend_alloc_ctx_tensors_from_buft(ctx.ctx_ggml, ctx.model->buft); + ggml_backend_tensor_copy(output_node, ctx.output); + + return 0; +} + +//////////////////////////////////////////////////////////////////////////////////////// +// public API + +struct llama_vision_context_params llama_vision_context_default_params() { + return { + /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default + }; +} + +struct llama_vision_context * llama_vision_init_from_model(const struct llama_model * model, struct llama_vision_context_params params) { + if (!model->has_vision) { + return nullptr; + } + + llama_vision_context * ctx = new llama_vision_context; + ctx->model = &model->vit; + + // TODO: this looks ugly, mostly copied from llama.cpp, refactor it in the future + + // init backends + { + // add CPU backend + ctx->backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); + if (ctx->backend_cpu == nullptr) { + LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__); + llama_vision_free(ctx); + return nullptr; + } + ctx->backends.emplace_back(ctx->backend_cpu); + + // create a list of the set_n_threads functions in the backends + for (auto & backend : ctx->backends) { + ggml_backend_dev_t dev = ggml_backend_get_device(backend.get()); + ggml_backend_reg_t reg = dev ? 
+////////////////////////////////////////////////////////////////////////////////////////
+// public API
+
+struct llama_vision_context_params llama_vision_context_default_params() {
+    return {
+        /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
+    };
+}
+
+struct llama_vision_context * llama_vision_init_from_model(const struct llama_model * model, struct llama_vision_context_params params) {
+    if (!model->has_vision) {
+        return nullptr;
+    }
+
+    llama_vision_context * ctx = new llama_vision_context;
+    ctx->model = &model->vit;
+
+    // TODO: this looks ugly, mostly copied from llama.cpp, refactor it in the future
+
+    // init backends
+    {
+        // add CPU backend
+        ctx->backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
+        if (ctx->backend_cpu == nullptr) {
+            LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
+            llama_vision_free(ctx);
+            return nullptr;
+        }
+        ctx->backends.emplace_back(ctx->backend_cpu);
+
+        // set the number of threads on every backend that exposes set_n_threads
+        for (auto & backend : ctx->backends) {
+            ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
+            ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+            if (reg) {
+                auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+                if (ggml_backend_set_n_threads_fn) {
+                    ggml_backend_set_n_threads_fn(backend.get(), params.n_threads);
+                }
+            }
+        }
+    }
+
+    // scheduler and compute buffers
+    {
+        // buffer types used for the compute buffer of each backend
+        std::vector<ggml_backend_buffer_type_t> backend_buft;
+        std::vector<ggml_backend_t> backend_ptrs;
+        for (auto & backend : ctx->backends) {
+            auto * buft = ggml_backend_get_default_buffer_type(backend.get());
+            auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
+            if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model->devices.empty()) {
+                // use the host buffer of the first device CPU for faster transfer of the intermediate state
+                auto * dev = model->devices[0];
+                auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
+                if (host_buft) {
+                    buft = host_buft;
+                }
+            }
+            backend_buft.push_back(buft);
+            backend_ptrs.push_back(backend.get());
+        }
+
+        const size_t max_nodes = model->max_nodes();
+
+        // buffer used to store the computation graph and the tensor meta data
+        ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+        // TODO: support pipeline_parallel
+        const bool pipeline_parallel = false;
+
+        ctx->sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
+
+        if (pipeline_parallel) {
+            LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched.get()));
+        }
+    }
+
+    const size_t max_nodes = VISION_GRAPH_MAX_NODE; // TODO: make it dynamic
+    ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+    return ctx;
+}
+
+void llama_vision_free(struct llama_vision_context * ctx) {
+    if (ctx->ctx_ggml) {
+        ggml_free(ctx->ctx_ggml);
+    }
+    delete ctx;
+}
+
+struct llama_vision_bitmap * llama_vision_bitmap_init(uint32_t nx, uint32_t ny) {
+    llama_vision_bitmap * bmp = new llama_vision_bitmap;
+    bmp->nx = nx;
+    bmp->ny = ny;
+    bmp->data = (unsigned char *)malloc(3 * nx * ny);
+    return bmp;
+}
+
+void llama_vision_bitmap_free(llama_vision_bitmap * bmp) {
+    free(bmp->data);
+    delete bmp;
+}
+
+struct llama_vision_tokens * llama_vision_tokenize(
+        struct llama_vision_context * ctx,
+        struct llama_vision_bitmap * bmp) {
+    switch (ctx->model->hparams.arch) {
+        case LLM_ARCH_VISION_LLAVA:
+        case LLM_ARCH_VISION_MOBILEVLM:
+        case LLM_ARCH_VISION_IDEFICS3:
+            return new llama_vision_tokens(llama_vision_processor_llava(*ctx).tokenize(*bmp));
+        case LLM_ARCH_VISION_MINICPMV:
+            // TODO: use a dedicated minicpmv processor; reuses the llava one for now
+            return new llama_vision_tokens(llama_vision_processor_llava(*ctx).tokenize(*bmp));
+        default:
+            GGML_ASSERT(false && "unsupported arch");
+    }
+}
+
+void llama_vision_tokens_free(llama_vision_tokens * p) {
+    delete p;
+}
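+// llama_vision_encode dispatches on mm_patch_merge_type: MM_PATCH_MERGE_FLAT is
+// the llava-1.5 style single-image path implemented by llama_vision_encode_impl,
+// while MM_PATCH_MERGE_SPATIAL_UNPAD (llava-1.6 style multi-crop grids, cf.
+// image_grid_pinpoints) is not implemented yet.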
+int32_t llama_vision_encode(struct llama_vision_context * ctx, struct llama_vision_tokens * p) {
+    if (p->buf.empty()) {
+        LLAMA_LOG_ERROR("%s: nothing to encode\n", __func__);
+        return -1;
+    }
+
+    auto & hparams = ctx->model->hparams;
+    switch (hparams.mm_patch_merge_type) {
+        case MM_PATCH_MERGE_FLAT:
+            {
+                // flat / default llava-1.5 type embedding
+                int32_t encoded = llama_vision_encode_impl(*ctx, *p);
+                if (encoded != 0) {
+                    LLAMA_LOG_ERROR("Unable to encode image\n");
+                    return encoded;
+                }
+            } break;
+        case MM_PATCH_MERGE_SPATIAL_UNPAD:
+            {
+                // TODO: support llava-1.6
+                (void)0;
+            } break;
+        default:
+            GGML_ASSERT(false && "unsupported mm_patch_merge_type");
+    }
+
+    return 0;
+}
+
+struct ggml_tensor * llama_vision_get_output_tensor(struct llama_vision_context * ctx) {
+    return ctx->output;
+}
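+// The tensor returned by llama_vision_get_output_tensor() holds the embeddings
+// of the last llama_vision_encode() call; it lives in ctx->ctx_ggml and stays
+// valid only until the next encode. It is meant to be picked up by
+// llama_decode() as the image embeddings for the language model.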
+////////////////////////////////////////////////////////////////////////////////////////
+// for debugging
+#ifndef NDEBUG
+
+static int bmp_export(const struct llama_image_u8 &img, const std::string &location) {
+    const uint32_t width  = img.nx;
+    const uint32_t height = img.ny;
+    // swap red and blue channel (BMP stores pixels as BGR)
+    std::vector<uint8_t> buffer(width*height*3);
+    for (uint32_t y = 0; y < height; y++) {
+        for (uint32_t x = 0; x < width; x++) {
+            size_t base = x*3 + y*3*width;
+            buffer[base+2] = img.buf[base];
+            buffer[base+1] = img.buf[base+1];
+            buffer[base]   = img.buf[base+2];
+        }
+    }
+    const bool hasAlphaChannel = false;
+
+    std::ofstream fout(location, std::ios::out | std::ios::binary);
+
+    if (fout.fail()) {
+        return 0;
+    }
+
+    // padding (each row is padded to a multiple of 4 bytes)
+    const uint8_t padding = hasAlphaChannel ? 0 : (4 - (width * 3) % 4) % 4;
+
+    // bitmap file header
+    const char signature[2] = { 'B', 'M' };
+    const uint32_t fileSize = buffer.size() * sizeof(uint8_t) + padding * (height - 1) + 14 + 124;
+    const uint32_t offset   = 14 + 124;
+
+    // bitmap information header (BITMAPV5HEADER, 124 bytes)
+    const uint32_t DIBSize = 124;
+    const int32_t  bitmapWidth  = width;
+    const int32_t  bitmapHeight = height;
+    const uint16_t numPlanes    = 1;
+    const uint16_t bitsPerPixel = (hasAlphaChannel) ? 32 : 24;
+    const uint32_t compressionMethod = (hasAlphaChannel) ? 3 : 0; // BI_RGB = 0, BI_BITFIELDS = 3
+    const uint32_t bitmapSize = buffer.size() * sizeof(uint8_t);
+    const int32_t  horizontalResolution = 2834;
+    const int32_t  verticalResolution   = 2834;
+    const uint32_t numColors     = 0;
+    const uint32_t impColorCount = 0;
+    const uint32_t redBitmask   = (hasAlphaChannel) ? 0x0000FF00 : 0; // ARGB32 pixel format
+    const uint32_t greenBitmask = (hasAlphaChannel) ? 0x00FF0000 : 0;
+    const uint32_t blueBitmask  = (hasAlphaChannel) ? 0xFF000000 : 0;
+    const uint32_t alphaBitmask = (hasAlphaChannel) ? 0x000000FF : 0;
+
+    // write the file header and information header to the file
+    std::vector<uint8_t> header(offset, 0);
+    header[0] = signature[0];
+    header[1] = signature[1];
+
+// serialize a value as 32-bit little-endian at byte offset i of the header
+#define BMP_HEADERS(i, variableName) header[i] = variableName; header[i+1] = variableName >> 8; header[i+2] = variableName >> 16; header[i+3] = variableName >> 24;
+
+    BMP_HEADERS(2, fileSize);
+    BMP_HEADERS(6, 0);
+    BMP_HEADERS(10, offset);
+    BMP_HEADERS(14, DIBSize);
+    BMP_HEADERS(18, bitmapWidth);
+    BMP_HEADERS(22, bitmapHeight);
+
+    header[26] = (uint8_t)numPlanes;
+    header[27] = (uint8_t)(numPlanes >> 8);
+    header[28] = (uint8_t)bitsPerPixel;
+    header[29] = (uint8_t)(bitsPerPixel >> 8);
+
+    // note: do not truncate the values to a single byte here, the macro
+    // serializes all four bytes of each field
+    BMP_HEADERS(30, compressionMethod);
+    BMP_HEADERS(34, bitmapSize);
+    BMP_HEADERS(38, horizontalResolution);
+    BMP_HEADERS(42, verticalResolution);
+    BMP_HEADERS(46, numColors);
+    BMP_HEADERS(50, impColorCount);
+    BMP_HEADERS(54, redBitmask);
+    BMP_HEADERS(58, greenBitmask);
+    BMP_HEADERS(62, blueBitmask);
+    BMP_HEADERS(66, alphaBitmask);
+
+#undef BMP_HEADERS
+
+    fout.write((char *)header.data(), sizeof(uint8_t) * header.size());
+
+    // write the pixel array (bottom-up)
+    const uint32_t bWidth = bitsPerPixel / 8 * width;
+
+    for (int i = height - 1; i >= 0; i--) {
+        std::vector<uint8_t> row(buffer.begin() + i * bWidth, buffer.begin() + i * bWidth + bWidth);
+        fout.write((char *)row.data(), row.size() * sizeof(uint8_t));
+        fout.seekp(padding * sizeof(uint8_t), std::ios::cur);
+    }
+
+    fout.close();
+    return 1;
+}
+
+#endif
+
diff --git a/src/llama-vision.h b/src/llama-vision.h
new file mode 100644
index 0000000000000..d1ba10c30cd30
--- /dev/null
+++ b/src/llama-vision.h
@@ -0,0 +1,195 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-cpp.h"
+#include "llama.h"
+#include "llama-arch.h"
+
+#include <array>
+#include <string>
+#include <vector>
+
+#define VISION_GRAPH_MAX_NODE 2048
+
+enum vision_projector_type {
+    VISION_PROJECTOR_TYPE_UNKNOWN,
+    VISION_PROJECTOR_TYPE_MLP,
+    VISION_PROJECTOR_TYPE_LDPV2,
+    VISION_PROJECTOR_TYPE_MINICPMV_2_5,
+    VISION_PROJECTOR_TYPE_MINICPMV_2_6,
+};
+
+enum mm_patch_merge {
+    MM_PATCH_MERGE_UNKNOWN,
+    MM_PATCH_MERGE_FLAT,
+    MM_PATCH_MERGE_SPATIAL_UNPAD,
+};
+
+struct llama_vision_model {
+    struct vision_hparams {
+        llm_arch arch = LLM_ARCH_UNKNOWN;
+
+        uint32_t image_size;
+        uint32_t patch_size;
+        uint32_t hidden_size;
+        uint32_t n_intermediate;
+        uint32_t projection_dim;
+        uint32_t n_head;
+        uint32_t n_layer;
+        uint32_t max_pos_embd;
+        int32_t  select_layer = 0;
+        bool     use_gelu     = false;
+
+        float eps;
+
+        vision_projector_type proj_type = VISION_PROJECTOR_TYPE_UNKNOWN;
+        mm_patch_merge mm_patch_merge_type = MM_PATCH_MERGE_UNKNOWN;
+
+        std::array<float, 3> image_mean;
+        std::array<float, 3> image_std;
+
+        std::array<int32_t, 32> image_grid_pinpoints; // TODO: should this be an array of (x, y) pairs?
+        int32_t image_crop_resolution;
+
+        // idefics3
+        int scale_factor = 0;
+    };
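+    // Illustrative values for the hparams above (not defaults): a llava-1.5
+    // style CLIP-ViT-L/14-336 tower has image_size = 336, patch_size = 14,
+    // hidden_size = 1024, n_intermediate = 4096, n_head = 16, n_layer = 24,
+    // i.e. (336/14)^2 = 576 patches per image.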
+    struct vision_hparams hparams;
+    ggml_backend_buffer_type_t buft;
+
+    // embeddings
+    struct ggml_tensor * class_embedding     = nullptr;
+    struct ggml_tensor * patch_embeddings    = nullptr;
+    struct ggml_tensor * patch_bias          = nullptr;
+    struct ggml_tensor * position_embeddings = nullptr;
+
+    struct ggml_tensor * pre_norm_w = nullptr;
+    struct ggml_tensor * pre_norm_b = nullptr;
+
+    struct vision_layer {
+        // attention
+        struct ggml_tensor * k_w = nullptr;
+        struct ggml_tensor * k_b = nullptr;
+        struct ggml_tensor * q_w = nullptr;
+        struct ggml_tensor * q_b = nullptr;
+        struct ggml_tensor * v_w = nullptr;
+        struct ggml_tensor * v_b = nullptr;
+
+        struct ggml_tensor * output_w = nullptr;
+        struct ggml_tensor * output_b = nullptr;
+
+        // layernorm 1
+        struct ggml_tensor * norm_in_w = nullptr;
+        struct ggml_tensor * norm_in_b = nullptr;
+
+        // ff
+        struct ggml_tensor * ffn_up_w = nullptr;
+        struct ggml_tensor * ffn_up_b = nullptr;
+
+        struct ggml_tensor * ffn_down_w = nullptr;
+        struct ggml_tensor * ffn_down_b = nullptr;
+
+        // layernorm 2
+        struct ggml_tensor * norm_out_w = nullptr;
+        struct ggml_tensor * norm_out_b = nullptr;
+    };
+    std::vector<vision_layer> layers;
+
+    struct ggml_tensor * post_norm_w = nullptr;
+    struct ggml_tensor * post_norm_b = nullptr;
+
+    struct ggml_tensor * projection = nullptr;
+
+    // LLaVA projection
+    struct ggml_tensor * mm_1_w = nullptr;
+    struct ggml_tensor * mm_1_b = nullptr;
+    struct ggml_tensor * mm_2_w = nullptr;
+    struct ggml_tensor * mm_2_b = nullptr;
+
+    // MobileVLM_V2 projection
+    struct ggml_tensor * mm_model_mlp_0_w = nullptr;
+    struct ggml_tensor * mm_model_mlp_0_b = nullptr;
+    struct ggml_tensor * mm_model_mlp_2_w = nullptr;
+    struct ggml_tensor * mm_model_mlp_2_b = nullptr;
+    struct ggml_tensor * mm_model_peg_0_w = nullptr;
+    struct ggml_tensor * mm_model_peg_0_b = nullptr;
+
+    // MINICPMV projection
+    struct ggml_tensor * mm_model_pos_embed_k = nullptr;
+    struct ggml_tensor * mm_model_query   = nullptr;
+    struct ggml_tensor * mm_model_proj    = nullptr;
+    struct ggml_tensor * mm_model_kv_proj = nullptr;
+    struct ggml_tensor * mm_model_attn_q_w = nullptr;
+    struct ggml_tensor * mm_model_attn_q_b = nullptr;
+    struct ggml_tensor * mm_model_attn_k_w = nullptr;
+    struct ggml_tensor * mm_model_attn_k_b = nullptr;
+    struct ggml_tensor * mm_model_attn_v_w = nullptr;
+    struct ggml_tensor * mm_model_attn_v_b = nullptr;
+    struct ggml_tensor * mm_model_attn_o_w = nullptr;
+    struct ggml_tensor * mm_model_attn_o_b = nullptr;
+    struct ggml_tensor * mm_model_ln_q_w = nullptr;
+    struct ggml_tensor * mm_model_ln_q_b = nullptr;
+    struct ggml_tensor * mm_model_ln_kv_w = nullptr;
+    struct ggml_tensor * mm_model_ln_kv_b = nullptr;
+    struct ggml_tensor * mm_model_ln_post_w = nullptr;
+    struct ggml_tensor * mm_model_ln_post_b = nullptr;
+
+    // special tokens
+    struct ggml_tensor * mm_tok_embd_image     = nullptr;
+    struct ggml_tensor * mm_tok_embd_end_image = nullptr;
+    struct ggml_tensor * mm_tok_embd_slice     = nullptr;
+    struct ggml_tensor * mm_tok_embd_end_slice = nullptr;
+};
+
+struct llama_vision_context {
+    // memory buffers used to evaluate the model
+    std::vector<uint8_t> buf_compute_meta;
+    ggml_backend_sched_ptr sched;
+    std::vector<ggml_backend_ptr> backends;
+    ggml_backend_t backend_cpu = nullptr;
+
+    const llama_vision_model * model = nullptr;
+
+    // temporary output data, to be picked up by llama_decode()
+    struct ggml_context * ctx_ggml = nullptr;
+    struct ggml_tensor  * output   = nullptr;
+};
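+// Illustrative example for llama_vision_tokens below: a single 336x336 llava
+// input with 14x14 patches gives px = py = 14, n_px = n_py = 24, and one
+// buf entry of 3*336*336 floats in packed RGBRGB... order.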
+// for now, this only contains:
+// - the instruction for ggml_conv_2d to break the image into patches
+// - the pre-processed image data in f32
+struct llama_vision_tokens {
+    uint32_t px; // size of patch
+    uint32_t py; // size of patch
+    size_t n_px; // number of patches in x direction
+    size_t n_py; // number of patches in y direction
+    // RGB float32 image (NHWC)
+    // Memory layout: RGBRGBRGB...
+    std::vector<std::vector<float>> buf; // preprocessed image data, one entry per image
+};
+
+inline mm_patch_merge mm_patch_merge_from_name(const std::string & name) {
+    if (name == "flat") {
+        return MM_PATCH_MERGE_FLAT;
+    } else if (name == "spatial_unpad") {
+        return MM_PATCH_MERGE_SPATIAL_UNPAD;
+    }
+    return MM_PATCH_MERGE_UNKNOWN;
+}
+
+inline vision_projector_type vision_projector_type_from_name(const std::string & name) {
+    if (name == "mlp") {
+        return VISION_PROJECTOR_TYPE_MLP;
+    } else if (name == "ldpv2") {
+        return VISION_PROJECTOR_TYPE_LDPV2;
+    } else if (name == "minicpmv-2.5") {
+        return VISION_PROJECTOR_TYPE_MINICPMV_2_5;
+    } else if (name == "minicpmv-2.6") {
+        return VISION_PROJECTOR_TYPE_MINICPMV_2_6;
+    }
+    return VISION_PROJECTOR_TYPE_UNKNOWN;
+}
+
+// only for sanity check: must be equal to n_embd of language model
+uint32_t llama_vision_n_mmproj_embd(const llama_vision_model & vmodel);
+
+struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx);
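+// Typical call sequence for the API above (sketch only; error checking and the
+// hand-off of the output tensor into llama_decode() are omitted):
+//
+//   llama_vision_context_params vparams = llama_vision_context_default_params();
+//   llama_vision_context * vctx = llama_vision_init_from_model(model, vparams);
+//
+//   llama_vision_bitmap * bmp = llama_vision_bitmap_init(nx, ny);
+//   memcpy(bmp->data, rgb_pixels, 3 * nx * ny); // packed 8-bit RGB
+//
+//   llama_vision_tokens * toks = llama_vision_tokenize(vctx, bmp);
+//   llama_vision_encode(vctx, toks);
+//   struct ggml_tensor * embd = llama_vision_get_output_tensor(vctx); // overload defined in llama-vision.cpp
+//
+//   llama_vision_tokens_free(toks);
+//   llama_vision_bitmap_free(bmp);
+//   llama_vision_free(vctx);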