
Commit afcc335

clip : Experimental support for Gemma 3 vision (#12344)

* clip : Experimental support for Gemma 3 vision
* fix build
* PRId64

1 parent d0a27d6 commit afcc335

5 files changed: +885 −10 lines changed

examples/llava/CMakeLists.txt (+7)

```diff
@@ -51,6 +51,13 @@ install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
 
+set(TARGET llama-gemma3-cli)
+add_executable(${TARGET} gemma3-cli.cpp)
+set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-gemma3-cli)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
+
 set(TARGET llama-llava-clip-quantize-cli)
 add_executable(${TARGET} clip-quantize-cli.cpp)
 set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-clip-quantize-cli)
```

examples/llava/README-gemma3.md (+30, new file)

````diff
@@ -0,0 +1,30 @@
+# Gemma 3 vision
+
+> [!IMPORTANT]
+>
+> This is very experimental, only used for demo purpose.
+
+## How to get mmproj.gguf?
+
+```bash
+cd gemma-3-4b-it
+python ../llama.cpp/examples/llava/gemma3_convert_encoder_to_gguf.py .
+
+# output file is mmproj.gguf
+```
+
+## How to run it?
+
+What you need:
+- The text model GGUF, can be converted using `convert_hf_to_gguf.py`
+- The mmproj file from step above
+- An image file
+
+```bash
+# build
+cmake -B build
+cmake --build build --target llama-gemma3-cli
+
+# run it
+./build/bin/llama-gemma3-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
+```
````

examples/llava/clip.cpp (+200 −10)

```diff
@@ -136,6 +136,8 @@ static std::string format(const char * fmt, ...) {
 #define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
 #define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
 #define TN_IMAGE_NEWLINE "model.image_newline"
+#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
+#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
 
 #define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
 #define TN_MINICPMV_QUERY "resampler.query"
@@ -162,6 +164,7 @@ enum projector_type {
     PROJECTOR_TYPE_RESAMPLER,
     PROJECTOR_TYPE_GLM_EDGE,
     PROJECTOR_TYPE_MERGER,
+    PROJECTOR_TYPE_GEMMA3,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -172,6 +175,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_RESAMPLER, "resampler"},
     { PROJECTOR_TYPE_GLM_EDGE, "adapter"},
     { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
+    { PROJECTOR_TYPE_GEMMA3, "gemma3"},
 };
 
 
@@ -298,7 +302,7 @@ static projector_type clip_projector_type_from_string(const std::string & name)
             return kv.first;
         }
     }
-    return PROJECTOR_TYPE_UNKNOWN;
+    throw std::runtime_error(format("Unknown projector type: %s", name.c_str()));
 }
 
 #ifdef CLIP_DEBUG_FUNCTIONS
```
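With this change, an unrecognized projector name in the GGUF metadata raises an exception instead of silently mapping to `PROJECTOR_TYPE_UNKNOWN`, so a bad or unsupported mmproj file fails at load time with a clear message. Below is a self-contained sketch of that fail-fast lookup pattern; the trimmed enum and name table are stand-ins for the full tables in clip.cpp:

```cpp
// Sketch of the fail-fast projector lookup; the enum and map are trimmed
// stand-ins for the real tables in clip.cpp.
#include <cstdio>
#include <map>
#include <stdexcept>
#include <string>

enum projector_type { PROJECTOR_TYPE_MLP, PROJECTOR_TYPE_GEMMA3 };

static const std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
    { PROJECTOR_TYPE_MLP,    "mlp"    },
    { PROJECTOR_TYPE_GEMMA3, "gemma3" },
};

static projector_type projector_type_from_string(const std::string & name) {
    for (const auto & kv : PROJECTOR_TYPE_NAMES) {
        if (kv.second == name) {
            return kv.first;
        }
    }
    // previously: return PROJECTOR_TYPE_UNKNOWN; now the loader fails fast
    throw std::runtime_error("Unknown projector type: " + name);
}

int main() {
    try {
        projector_type_from_string("not-a-projector");
    } catch (const std::runtime_error & e) {
        fprintf(stderr, "%s\n", e.what()); // caller sees the error immediately
    }
}
```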
```diff
@@ -555,6 +559,10 @@ struct clip_vision_model {
     struct ggml_tensor * mm_model_ln_kv_b;
     struct ggml_tensor * mm_model_ln_post_w;
     struct ggml_tensor * mm_model_ln_post_b;
+
+    // gemma3
+    struct ggml_tensor * mm_input_proj_w;
+    struct ggml_tensor * mm_soft_emb_norm_w;
 };
 
 struct clip_ctx {
@@ -569,7 +577,7 @@ struct clip_ctx {
     struct clip_vision_model vision_model;
     projector_type proj_type = PROJECTOR_TYPE_MLP;
 
-    int32_t max_feature_layer;
+    int32_t max_feature_layer; // unused in newer models like gemma3
     float image_mean[3];
     float image_std[3];
     bool use_gelu = false;
@@ -595,7 +603,7 @@ struct clip_ctx {
 
     ggml_backend_sched_ptr sched;
 
-    struct clip_image_size * load_image_size;
+    struct clip_image_size * load_image_size = nullptr;
 
     clip_ctx(clip_context_params & ctx_params) {
         backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
@@ -631,7 +639,159 @@ struct clip_ctx {
     }
 };
 
-static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
+static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
+    const auto & model = ctx->vision_model;
+    const auto & hparams = model.hparams;
+
+    const int image_size = hparams.image_size;
+    int image_size_width = image_size;
+    int image_size_height = image_size;
+
+    const int patch_size = hparams.patch_size;
+    const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
+    const int hidden_size = hparams.hidden_size;
+    const int n_head = hparams.n_head;
+    const int d_head = hidden_size / n_head;
+    const int n_layer = hparams.n_layer;
+    const float eps = hparams.eps;
+
+    GGML_ASSERT(imgs->size == 1); // batch_size == 1
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ ctx->buf_compute_meta.size(),
+        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    // input raw
+    struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
+    ggml_set_name(inp_raw, "inp_raw");
+    ggml_set_input(inp_raw);
+
+    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+    inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
+    inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+    inp = ggml_add(ctx0, inp, model.patch_bias);
+
+    // position embeddings
+    struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings);
+
+    // loop over layers
+    for (int il = 0; il < n_layer; il++) {
+        struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
+
+        // layernorm1
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b);
+        }
+
+        // self-attention
+        {
+
+            struct ggml_tensor * Q =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
+
+            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
+            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+
+            struct ggml_tensor * K =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
+
+            K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
+            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+
+            struct ggml_tensor * V =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
+
+            V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
+            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf((float)d_head));
+            KQ = ggml_soft_max_inplace(ctx0, KQ);
+
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+            KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
+            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
+        }
+
+        // attention output
+        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
+
+        // re-add the layer input, e.g., residual
+        cur = ggml_add(ctx0, cur, embeddings);
+
+        embeddings = cur; // embeddings = residual, cur = hidden_states
+
+        // layernorm2
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
+        }
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+
+        // siglip uses gelu
+        cur = ggml_gelu(ctx0, cur);
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+
+        // residual 2
+        cur = ggml_add(ctx0, embeddings, cur);
+
+        embeddings = cur;
+    }
+
+    // post-layernorm
+    if (ctx->has_post_norm) {
+        embeddings = ggml_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "post_ln");
+
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
+    }
+
+    if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+        const int batch_size = 1;
+        const int mm_tokens_per_image = 256; // default value for gemma3
+        const int tokens_per_side = sqrt(mm_tokens_per_image);
+        const int patches_per_image = sqrt(num_patches);
+        const int kernel_size = patches_per_image / tokens_per_side;
+
+        embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
+        embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, hidden_size, batch_size);
+
+        // doing a pool2d to reduce the number of output tokens to 256
+        embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], hidden_size, batch_size);
+        embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
+
+        // apply norm before projection
+        embeddings = ggml_rms_norm(ctx0, embeddings, eps);
+        embeddings = ggml_mul(ctx0, embeddings, model.mm_soft_emb_norm_w);
+
+        // apply projection
+        embeddings = ggml_mul_mat(ctx0,
+            ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
+            embeddings);
+    }
+
+    // build the graph
+    ggml_build_forward_expand(gf, embeddings);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
     if (!ctx->has_vision_encoder) {
         LOG_ERR("This gguf file seems to have no vision encoder\n");
         return nullptr;
```
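The `PROJECTOR_TYPE_GEMMA3` branch above pools the SigLIP patch grid down to a fixed 256 soft tokens before the RMS norm and input projection. A standalone sketch of that arithmetic follows; the 896-pixel input resolution and 14-pixel patch size are assumed Gemma 3 SigLIP defaults for illustration, not values taken from this diff:

```cpp
// Token-reduction arithmetic from the gemma3 branch of
// clip_image_build_graph_siglip(). The hparams here (896x896 input,
// 14-pixel patches) are assumed defaults, not read from a model file.
#include <cmath>
#include <cstdio>

int main() {
    const int image_size = 896; // assumed SigLIP input resolution
    const int patch_size = 14;  // assumed SigLIP patch size

    const int num_patches = (image_size / patch_size) * (image_size / patch_size); // 64*64 = 4096

    const int mm_tokens_per_image = 256; // fixed output token count for gemma3
    const int tokens_per_side   = (int) sqrt((double) mm_tokens_per_image); // 16
    const int patches_per_image = (int) sqrt((double) num_patches);         // 64
    const int kernel_size       = patches_per_image / tokens_per_side;      // 4

    // a 4x4 average pool with stride 4 maps the 64x64 patch grid to 16x16 = 256 tokens
    printf("patches: %d -> pooled tokens: %d (kernel %dx%d)\n",
           num_patches, tokens_per_side * tokens_per_side, kernel_size, kernel_size);
    return 0;
}
```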
```diff
@@ -1177,7 +1337,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         } else {
             GGML_ABORT("fatel error");
         }
-    } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
+    }
+    else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
         embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
 
         embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
@@ -1199,6 +1360,15 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     return gf;
 }
 
+static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
+    if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+        return clip_image_build_graph_siglip(ctx, imgs);
+    } else {
+        // TODO: we should have one build_* function per model
+        return clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf);
+    }
+}
+
 // read and create ggml_context containing the tensors and their data
 struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     return clip_init(fname, clip_context_params{
@@ -1358,8 +1528,12 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
     GGML_ASSERT(new_clip->has_vision_encoder);
     GGML_ASSERT(!new_clip->has_text_encoder);
 
-    idx = get_key_idx(ctx, KEY_USE_GELU);
-    new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
+    try {
+        idx = get_key_idx(ctx, KEY_USE_GELU);
+        new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
+    } catch (std::runtime_error & /*e*/) {
+        new_clip->use_gelu = false;
+    }
 
     try {
         idx = get_key_idx(ctx, KEY_USE_SILU);
```
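The loader now treats `clip.use_gelu` as optional and defaults it to false when the key is absent, which SigLIP-based mmproj files may not carry. The sketch below expresses the same optional-key idea with the public gguf API rather than clip.cpp's internal, throwing `get_key_idx()` helper; the file name is a placeholder:

```cpp
// Optional-key read with a fallback default, mirroring the use_gelu handling
// above. "model.gguf" is a placeholder; clip.cpp itself goes through its
// internal get_key_idx() helper instead of these direct gguf_* calls.
#include <cstdio>
#include "gguf.h"

int main() {
    struct gguf_init_params params = {
        /*.no_alloc =*/ true,
        /*.ctx      =*/ nullptr,
    };
    struct gguf_context * ctx = gguf_init_from_file("model.gguf", params);
    if (ctx == nullptr) {
        return 1;
    }

    // gguf_find_key() returns a negative index instead of throwing when absent
    bool use_gelu = false;
    const int64_t idx = gguf_find_key(ctx, "clip.use_gelu");
    if (idx >= 0) {
        use_gelu = gguf_get_val_bool(ctx, idx);
    }
    printf("use_gelu: %s\n", use_gelu ? "true" : "false");

    gguf_free(ctx);
    return 0;
}
```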
```diff
@@ -1567,11 +1741,17 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
     }
 
     try {
-        vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
+        vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
+    } catch(const std::exception& /*e*/) {
+        vision_model.patch_embeddings_0 = nullptr;
+    }
+
+    try {
         vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
     } catch(const std::exception& /*e*/) {
-        LOG_ERR("%s: failed to load vision model tensors\n", __func__);
+        vision_model.position_embeddings = nullptr;
     }
+
     try {
         vision_model.patch_embeddings_1 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1);
     } catch(const std::exception& /*e*/) {
@@ -1682,6 +1862,10 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
         vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
         vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
     }
+    else if (new_clip->proj_type == PROJECTOR_TYPE_GEMMA3) {
+        vision_model.mm_input_proj_w = get_tensor(new_clip->ctx_data, TN_MM_INP_PROJ);
+        vision_model.mm_soft_emb_norm_w = get_tensor(new_clip->ctx_data, TN_MM_SOFT_EMB_N);
+    }
     else {
         std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
         throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
@@ -2223,7 +2407,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
         return true;
     }
 
-    if (ctx->has_glm_projector) {
+    if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
         res_imgs->size = 1;
         res_imgs->data = new clip_image_f32[res_imgs->size];
         clip_image_u8 resized_image;
@@ -2748,6 +2932,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
         free(positions_data);
     }
+    else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+        // do nothing
+    }
     else {
         struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
 
@@ -2960,6 +3147,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
     if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
         return ctx->vision_model.mm_1_b->ne[0];
     }
+    if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+        return ctx->vision_model.mm_input_proj_w->ne[0];
+    }
 
     std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
     throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
```
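For Gemma 3, `clip_n_mmproj_embd()` now reports the first dimension of `mm_input_proj_w`, i.e. the per-token embedding width the projector hands to the language model. Below is a hedged sketch of how a frontend like gemma3-cli might size its embedding buffer with the public clip.h API; the model path is a placeholder and error handling is elided:

```cpp
// Sizing the image-embedding buffer for one gemma3 image via the clip.h API.
// "mmproj.gguf" is a placeholder path; error handling is omitted for brevity.
#include <vector>
#include "clip.h"

int main() {
    struct clip_ctx * ctx_clip = clip_model_load("mmproj.gguf", /*verbosity=*/1);

    const int n_embd   = clip_n_mmproj_embd(ctx_clip); // per-token width expected by the LLM
    const int n_tokens = 256; // gemma3 pools every image down to 256 soft tokens (see graph above)

    // one embedding vector per soft token, laid out contiguously
    std::vector<float> image_embd((size_t) n_tokens * n_embd);

    // ... preprocess an image and clip_image_encode() into image_embd.data() ...

    clip_free(ctx_clip);
    return 0;
}
```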
