(research) experiment with phi-4-multimodal vision support #12274

Draft · wants to merge 3 commits into master
26 changes: 24 additions & 2 deletions convert_hf_to_gguf.py
@@ -2398,9 +2398,23 @@ def set_gguf_parameters(self):
self.gguf_writer.add_add_bos_token(False)


@Model.register("Phi3ForCausalLM")
@Model.register("Phi3ForCausalLM", "Phi4MMForCausalLM")
class Phi3MiniModel(Model):
model_arch = gguf.MODEL_ARCH.PHI3
has_vision: bool = False

# detect whether the checkpoint also ships a vision encoder (vision_lora in hparams)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if "vision_lora" in self.hparams:
logger.info("Detected vision encoder, but it will be ignored")
self.has_vision = True

def write(self):
super().write()
if self.has_vision:
logger.info("NOTE: this script only convert the language model to GGUF")
logger.info(" for the vision model, please use phi4mm_convert_encoder_to_gguf.py")

def set_vocab(self):
# Phi-4 model uses GPT2Tokenizer
@@ -2409,7 +2423,7 @@ def set_vocab(self):
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
tokenizer_config_json = json.load(f)
tokenizer_class = tokenizer_config_json['tokenizer_class']
if tokenizer_class == 'GPT2Tokenizer':
if tokenizer_class == 'GPT2Tokenizer' or tokenizer_class == 'GPT2TokenizerFast':
return self._set_vocab_gpt2()

from sentencepiece import SentencePieceProcessor
@@ -2575,6 +2589,14 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32))
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
if self.has_vision:
if name.startswith("model.embed_tokens_extend") or "lora_" in name:
return []
name = name.replace(".base_layer", "")
return [(self.map_tensor_name(name), data_torch)]


@Model.register("PhiMoEForCausalLM")
class PhiMoeModel(Phi3MiniModel):
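
For illustration, a minimal Python sketch (not part of the PR) of what the new `modify_tensors` override does when a vision encoder is detected: tensors from the vision embedding stack and the LoRA adapters are skipped, and the `.base_layer` infix that the adapter wrapping adds to linear layers is stripped so the remaining weights map onto the usual Phi-3 tensor names. The tensor names below are made-up examples, not the full Phi-4-multimodal checkpoint.

```python
# Sketch of the tensor-name filtering in Phi3MiniModel.modify_tensors
# (illustrative tensor names only).
def keep_for_text_gguf(name: str):
    # vision embedding stack and LoRA adapter weights are not exported
    if name.startswith("model.embed_tokens_extend") or "lora_" in name:
        return None
    # adapted linear layers keep their original weight under ".base_layer"
    return name.replace(".base_layer", "")

for name in [
    "model.embed_tokens_extend.image_embed.img_projection.0.weight",  # skipped
    "model.layers.0.self_attn.qkv_proj.lora_A.weight",                # skipped
    "model.layers.0.self_attn.qkv_proj.base_layer.weight",            # kept, renamed
]:
    print(f"{name} -> {keep_for_text_gguf(name)}")
```
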
7 changes: 7 additions & 0 deletions examples/llava/CMakeLists.txt
@@ -57,3 +57,10 @@ set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-clip-quantize
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

set(TARGET llama-phi4mm-cli)
add_executable(${TARGET} phi4mm-cli.cpp)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-phi4mm-cli)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
26 changes: 24 additions & 2 deletions examples/llava/clip.cpp
@@ -878,6 +878,28 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
}
}

// FIXME: phi-4, wrap this into an "if" condition
int n_tokens = embeddings->ne[1];
int n_tokens_sqrt = sqrtf(n_tokens);
int downscale_factor = 2;
printf("embeddings shape: %d %d %d %d\n", embeddings->ne[0], embeddings->ne[1], embeddings->ne[2], embeddings->ne[3]);
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
embeddings = ggml_reshape_3d(ctx0, embeddings, n_tokens_sqrt, n_tokens_sqrt, hidden_size);
// downscale n_tokens_sqrt*n_tokens_sqrt to n_tokens_sqrt/2*n_tokens_sqrt/2
embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, downscale_factor, downscale_factor, downscale_factor, downscale_factor, 0, 0);
// flatten first two dimensions
embeddings = ggml_reshape_2d(ctx0, embeddings, n_tokens_sqrt/2*n_tokens_sqrt/2, hidden_size);
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
printf("embeddings shape: %d %d %d %d\n", embeddings->ne[0], embeddings->ne[1], embeddings->ne[2], embeddings->ne[3]);
// mlp
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);

embeddings = ggml_gelu(ctx0, embeddings);
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
printf("embeddings shape: %d %d %d %d\n", embeddings->ne[0], embeddings->ne[1], embeddings->ne[2], embeddings->ne[3]);

// llava projector
if (ctx->has_llava_projector) {
embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
@@ -2758,7 +2780,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data);

if (!ctx->has_glm_projector) {
/*if (!ctx->has_glm_projector) {
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
// The patches vector is used to get rows to index into the embeds with;
// we should skip dim 0 only if we have CLS to avoid going out of bounds
@@ -2770,7 +2792,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
free(patches_data);
}
}*/
}
}

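
As a rough sanity check (not from the PR), the token-count arithmetic behind the new average-pooling block: the patch embeddings are treated as an n×n grid and pooled 2×2, so the image token count shrinks by a factor of four. The 1024-token input below is an assumed example, chosen so the result matches the 256 image embeddings the new CLI allocates per image.

```python
import math

# Sketch of the 2x2 average-pool downscale applied to the ViT patch tokens.
# n_tokens = 1024 is an assumed example; the real count depends on the encoder config.
n_tokens = 1024
n = math.isqrt(n_tokens)                # 32 x 32 patch grid
downscale_factor = 2
pooled = (n // downscale_factor) ** 2   # 16 x 16 = 256 tokens passed to the LLM
print(f"{n_tokens} patch tokens -> {pooled} pooled tokens")
```
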
224 changes: 224 additions & 0 deletions examples/llava/phi4mm-cli.cpp
@@ -0,0 +1,224 @@
#include "arg.h"
#include "log.h"
#include "common.h"
#include "sampling.h"
#include "clip.h"
#include "stb_image.h"
#include "llama.h"
#include "ggml.h"

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>
#include <iostream>
#include <fstream>

struct phi4mm_context {
struct clip_ctx * ctx_clip = NULL;
common_init_result llama_init;

llama_model * model;
llama_context * lctx;
llama_adapter_lora * vision_lora;

phi4mm_context(common_params & params) : llama_init(common_init_from_params(params)) {
model = llama_init.model.get();
lctx = llama_init.context.get();
vision_lora = llama_init.lora[0].get();
llama_clear_adapter_lora(lctx);
init_clip_model(params);
}

void init_clip_model(common_params & params) {
const char * clip_path = params.mmproj.c_str();
ctx_clip = clip_model_load(clip_path, params.verbosity > 1);
}

~phi4mm_context() {
clip_free(ctx_clip);
}
};

struct decode_embd_batch {
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id> seq_id_0;
std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
pos .resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
logits .resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = pos_0 + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
}
};

struct inp_bitmap {
int nx;
int ny;
std::vector<unsigned char> data;
};

static void show_additional_info(int /*argc*/, char ** argv) {
GGML_UNUSED(argv);
LOG("TODO\n");
}

static void eval_text(phi4mm_context & ctx, int & n_past, std::string input, bool logits_last = false) {
llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
for (llama_token & t : tokens) {
common_batch_add(batch, t, n_past++, {0}, false);
}
if (logits_last) {
batch.logits[batch.n_tokens - 1] = true;
}
LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
if (llama_decode(ctx.lctx, batch)) {
GGML_ABORT("Failed to decode\n");
}
}

int main(int argc, char ** argv) {
ggml_time_init();

common_params params;

// default values
params.prompt = "<|user|>$what did you see?<|end|><|assistant|>";
params.n_predict = 64;
params.sampling.temp = 0.0f;

if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
return 1;
}

common_init();

if (params.mmproj.empty() || (params.image.empty())) {
show_additional_info(argc, argv);
return 1;
}

if (params.lora_adapters.empty()) {
LOG_ERR("error: no vision lora adapters specified\n");
return 1;
}

phi4mm_context ctx(params);
printf("%s: %s\n", __func__, params.model.c_str());

int n_threads = params.cpuparams.n_threads;
int n_past = 0;

std::vector<std::string> prompt_parts = string_split<std::string>(params.prompt, '$');
GGML_ASSERT(prompt_parts.size() == 2);
eval_text(ctx, n_past, prompt_parts[0], false);

// process images
for (auto & image : params.image) {
//break;
std::vector<float> image_embd_v;
int n_embd = llama_model_n_embd(ctx.model);
int n_tokens = 256;
image_embd_v.resize(n_tokens * n_embd);

bool ok;
struct clip_image_u8 * img_u8 = clip_image_u8_init();
ok = clip_image_load_from_file(image.c_str(), img_u8);
if (!ok) {
LOG_ERR("Unable to load image %s\n", image.c_str());
return 1;
}

clip_image_f32_batch batch_f32;
ok = clip_image_preprocess(ctx.ctx_clip, img_u8, &batch_f32);
if (!ok) {
LOG_ERR("Unable to preprocess image\n");
return 1;
}

LOG("Encoding image %s\n", image.c_str());
ok = clip_image_batch_encode(ctx.ctx_clip, n_threads, &batch_f32, image_embd_v.data());
if (!ok) {
LOG_ERR("Unable to encode image\n");
return 1;
}

// debug
// for (int i = 0; i < 10; i++) {
// LOG("embd[%d] = %f, %f, %f\n", i, image_embd_v[i*n_embd], image_embd_v[i*n_embd+1], image_embd_v[i*n_embd+2]);
// }

clip_image_f32_batch_free(&batch_f32);
clip_image_u8_free(img_u8);

// decode image embeddings
llama_set_adapter_lora(ctx.lctx, ctx.vision_lora, 1.0f);
decode_embd_batch batch_img(image_embd_v.data(), n_tokens, n_past, 0);
if (llama_decode(ctx.lctx, batch_img.batch)) {
LOG_ERR("failed to decode image\n");
return 1;
}
llama_clear_adapter_lora(ctx.lctx);
n_past += n_tokens;
}

eval_text(ctx, n_past, prompt_parts[1], true);

// generate text
struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
const llama_vocab * vocab = llama_model_get_vocab(ctx.model);
int n_prompt = n_past;
llama_batch batch = llama_batch_init(1, 0, 1);
while (true) {
int n_generated = n_past - n_prompt;
if (n_generated > params.n_predict) {
printf("\n");
break;
}

llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
common_sampler_accept(smpl, token_id, true);
printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
fflush(stdout);

if (llama_vocab_is_eog(vocab, token_id)) {
printf("\n");
break;
}

// eval the token
common_batch_clear(batch);
common_batch_add(batch, token_id, n_past++, {0}, true);
if (llama_decode(ctx.lctx, batch)) {
LOG_ERR("failed to decode token\n");
break;
}
}

llama_batch_free(batch);

return 0;
}
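
To summarize the control flow of the new CLI, a small Python-style sketch (the helpers below are inline stand-ins, not llama.cpp API): the prompt is split on `$`, the text before the marker is evaluated first, each image is encoded by CLIP and decoded as an embedding batch with the vision LoRA temporarily enabled, then the trailing text is evaluated and greedy sampling runs until an end-of-generation token or `n_predict` tokens.

```python
# Pseudocode-level sketch of the phi4mm-cli evaluation order.
# eval_text / eval_image are stand-ins defined here, not real llama.cpp calls.
def eval_text(text: str, n_past: int) -> int:
    print(f"eval text  @ pos {n_past}: {text!r}")
    return n_past + len(text.split())        # stand-in for tokenization

def eval_image(path: str, n_past: int, n_img_tokens: int = 256) -> int:
    print(f"eval image @ pos {n_past}: {path} (+{n_img_tokens} embeddings, vision LoRA on)")
    return n_past + n_img_tokens

prompt = "<|user|>$what did you see?<|end|><|assistant|>"
before, after = prompt.split("$")
n_past = eval_text(before, 0)
n_past = eval_image("bliss.png", n_past)     # repeated once per --image argument
n_past = eval_text(after, n_past)            # logits requested on the last token
# ...followed by the sampling loop until EOG or n_predict tokens are generated.
```
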
14 changes: 14 additions & 0 deletions examples/llava/phi4mm-test.sh
@@ -0,0 +1,14 @@
#!/bin/bash

# for convenience, we have this script to ease the development process

# make sure we are in the right directory
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
PROJECT_ROOT="$SCRIPT_DIR/../.."
cd "$PROJECT_ROOT"

./build/bin/llama-phi4mm-cli \
-m ../models/Phi-4-multimodal-instruct/model.gguf \
--mmproj ../models/Phi-4-multimodal-instruct/mmproj.gguf \
--lora ../models/Phi-4-multimodal-instruct/vision_lora.gguf \
--image ../models/bliss.png