diff --git a/Makefile b/Makefile index 24acb80136516..fb6098424cf31 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ # Define the default target now so that it is always the first target BUILD_TARGETS = \ main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \ - simple batched batched-bench save-load-state server gguf gguf-split eval-callback llama-bench libllava.a llava-cli baby-llama beam-search \ + simple batched batched-bench save-load-state server gguf gguf-split eval-callback llama-bench libllava.a llava-cli minicpmv-cli baby-llama beam-search \ retrieval speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o # Binaries only useful for tests @@ -842,6 +842,12 @@ llava-cli: examples/llava/llava-cli.cpp examples/llava/clip.h examples/llava/cli $(CXX) $(CXXFLAGS) -c examples/llava/llava.cpp -o $(call GET_OBJ_FILE, examples/llava/llava.cpp) $(CXX) $(CXXFLAGS) $(filter-out %.h $< examples/llava/clip.cpp examples/llava/llava.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/llava/clip.cpp) $(call GET_OBJ_FILE, examples/llava/llava.cpp) -o $@ $(LDFLAGS) +minicpmv-cli: examples/llava/minicpmv-cli.cpp examples/llava/clip.h examples/llava/clip.cpp examples/llava/llava.h examples/llava/llava.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) -c examples/llava/clip.cpp -o $(call GET_OBJ_FILE, examples/llava/clip.cpp) -Wno-cast-qual + $(CXX) $(CXXFLAGS) -c examples/llava/llava.cpp -o $(call GET_OBJ_FILE, examples/llava/llava.cpp) + $(CXX) $(CXXFLAGS) $(filter-out %.h $< examples/llava/clip.cpp examples/llava/llava.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/llava/clip.cpp) $(call GET_OBJ_FILE, examples/llava/llava.cpp) -o $@ $(LDFLAGS) + baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index e431c7f70fb0e..3851c3a6735d1 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -16,6 +16,10 @@ #include "ggml-metal.h" #endif +#ifdef GGML_USE_CANN +#include "ggml-cann.h" +#endif + #define STB_IMAGE_IMPLEMENTATION #include "stb_image.h" @@ -70,26 +74,28 @@ static std::string format(const char * fmt, ...) { // key constants // -#define KEY_FTYPE "general.file_type" -#define KEY_NAME "general.name" -#define KEY_DESCRIPTION "general.description" -#define KEY_HAS_TEXT_ENC "clip.has_text_encoder" -#define KEY_HAS_VIS_ENC "clip.has_vision_encoder" -#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" -#define KEY_USE_GELU "clip.use_gelu" -#define KEY_N_EMBD "clip.%s.embedding_length" -#define KEY_N_FF "clip.%s.feed_forward_length" -#define KEY_N_BLOCK "clip.%s.block_count" -#define KEY_N_HEAD "clip.%s.attention.head_count" -#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon" -#define KEY_PROJ_DIM "clip.%s.projection_dim" -#define KEY_TOKENS "tokenizer.ggml.tokens" -#define KEY_N_POSITIONS "clip.text.context_length" -#define KEY_IMAGE_SIZE "clip.vision.image_size" -#define KEY_PATCH_SIZE "clip.vision.patch_size" -#define KEY_IMAGE_MEAN "clip.vision.image_mean" -#define KEY_IMAGE_STD "clip.vision.image_std" -#define KEY_PROJ_TYPE "clip.projector_type" +#define KEY_FTYPE "general.file_type" +#define KEY_NAME "general.name" +#define KEY_DESCRIPTION "general.description" +#define KEY_HAS_TEXT_ENC "clip.has_text_encoder" +#define KEY_HAS_VIS_ENC "clip.has_vision_encoder" +#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" +#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector" +#define KEY_MINICPMV_VERSION "clip.minicpmv_version" +#define KEY_USE_GELU "clip.use_gelu" +#define KEY_N_EMBD "clip.%s.embedding_length" +#define KEY_N_FF "clip.%s.feed_forward_length" +#define KEY_N_BLOCK "clip.%s.block_count" +#define KEY_N_HEAD "clip.%s.attention.head_count" +#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon" +#define KEY_PROJ_DIM "clip.%s.projection_dim" +#define KEY_TOKENS "tokenizer.ggml.tokens" +#define KEY_N_POSITIONS "clip.text.context_length" +#define KEY_IMAGE_SIZE "clip.vision.image_size" +#define KEY_PATCH_SIZE "clip.vision.patch_size" +#define KEY_IMAGE_MEAN "clip.vision.image_mean" +#define KEY_IMAGE_STD "clip.vision.image_std" +#define KEY_PROJ_TYPE "clip.projector_type" #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" @@ -104,6 +110,7 @@ static std::string format(const char * fmt, ...) { #define TN_POS_EMBD "%s.position_embd.weight" #define TN_CLASS_EMBD "v.class_embd" #define TN_PATCH_EMBD "v.patch_embd.weight" +#define TN_PATCH_BIAS "v.patch_embd.bias" #define TN_ATTN_K "%s.blk.%d.attn_k.%s" #define TN_ATTN_Q "%s.blk.%d.attn_q.%s" #define TN_ATTN_V "%s.blk.%d.attn_v.%s" @@ -122,12 +129,21 @@ static std::string format(const char * fmt, ...) { #define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s" #define TN_IMAGE_NEWLINE "model.image_newline" +#define TN_MINICPMV_POS_EMBD "resampler.pos_embed" +#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k" +#define TN_MINICPMV_QUERY "resampler.query" +#define TN_MINICPMV_PROJ "resampler.proj.weight" +#define TN_MINICPMV_KV_PROJ "resampler.kv.weight" +#define TN_MINICPMV_ATTN "resampler.attn.%s.%s" +#define TN_MINICPMV_LN "resampler.ln_%s.%s" + enum projector_type { PROJECTOR_TYPE_MLP, PROJECTOR_TYPE_MLP_NORM, PROJECTOR_TYPE_LDP, PROJECTOR_TYPE_LDPV2, + PROJECTOR_TYPE_RESAMPLER, PROJECTOR_TYPE_UNKNOWN, }; @@ -135,6 +151,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_MLP, "mlp" }, { PROJECTOR_TYPE_LDP, "ldp" }, { PROJECTOR_TYPE_LDPV2, "ldpv2"}, + { PROJECTOR_TYPE_RESAMPLER, "resampler"}, }; @@ -195,17 +212,14 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int } static void replace_all(std::string & s, const std::string & search, const std::string & replace) { - std::string result; - for (size_t pos = 0; ; pos += search.length()) { - auto new_pos = s.find(search, pos); - if (new_pos == std::string::npos) { - result += s.substr(pos, s.size() - pos); - break; - } - result += s.substr(pos, new_pos - pos) + replace; - pos = new_pos; + if (search.empty()) { + return; // Avoid infinite loop if 'search' is an empty string + } + size_t pos = 0; + while ((pos = s.find(search, pos)) != std::string::npos) { + s.replace(pos, search.length(), replace); + pos += replace.length(); } - s = std::move(result); } static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { @@ -425,6 +439,7 @@ struct clip_vision_model { // embeddings struct ggml_tensor * class_embedding; struct ggml_tensor * patch_embeddings; + struct ggml_tensor * patch_bias; struct ggml_tensor * position_embeddings; struct ggml_tensor * pre_ln_w; @@ -486,12 +501,36 @@ struct clip_vision_model { struct ggml_tensor * mm_model_mlp_2_b; struct ggml_tensor * mm_model_peg_0_w; struct ggml_tensor * mm_model_peg_0_b; + + // MINICPMV projection + struct ggml_tensor * mm_model_pos_embed; + struct ggml_tensor * mm_model_pos_embed_k; + struct ggml_tensor * mm_model_query; + struct ggml_tensor * mm_model_proj; + struct ggml_tensor * mm_model_kv_proj; + struct ggml_tensor * mm_model_attn_q_w; + struct ggml_tensor * mm_model_attn_q_b; + struct ggml_tensor * mm_model_attn_k_w; + struct ggml_tensor * mm_model_attn_k_b; + struct ggml_tensor * mm_model_attn_v_w; + struct ggml_tensor * mm_model_attn_v_b; + struct ggml_tensor * mm_model_attn_o_w; + struct ggml_tensor * mm_model_attn_o_b; + struct ggml_tensor * mm_model_ln_q_w; + struct ggml_tensor * mm_model_ln_q_b; + struct ggml_tensor * mm_model_ln_kv_w; + struct ggml_tensor * mm_model_ln_kv_b; + struct ggml_tensor * mm_model_ln_post_w; + struct ggml_tensor * mm_model_ln_post_b; }; struct clip_ctx { bool has_text_encoder = false; bool has_vision_encoder = false; bool has_llava_projector = false; + bool has_minicpmv_projector = false; + int minicpmv_version = 2; + int max_slice_nums = 9; struct clip_vision_model vision_model; projector_type proj_type = PROJECTOR_TYPE_MLP; @@ -501,6 +540,11 @@ struct clip_ctx { bool use_gelu = false; int32_t ftype = 1; + bool has_class_embedding = true; + bool has_pre_norm = true; + bool has_post_norm = false; + bool has_patch_bias = false; + struct gguf_context * ctx_gguf; struct ggml_context * ctx_data; @@ -511,9 +555,11 @@ struct clip_ctx { ggml_backend_t backend = NULL; ggml_gallocr_t compute_alloc = NULL; + + struct clip_image_size * load_image_size; }; -static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs) { +static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) { if (!ctx->has_vision_encoder) { LOG_TEE("This gguf file seems to have no vision encoder\n"); return nullptr; @@ -522,20 +568,33 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 const auto & model = ctx->vision_model; const auto & hparams = model.hparams; - const int image_size = hparams.image_size; + const int image_size = hparams.image_size; + int image_size_width = image_size; + int image_size_height = image_size; + if (ctx->has_minicpmv_projector) { + if (load_image_size == nullptr) { + load_image_size = clip_image_size_init(); + } + LOG_TEE("%s: %d %d\n", __func__, load_image_size->width, load_image_size->height); + image_size_width = load_image_size->width; + image_size_height = load_image_size->height; + if (is_inf) { + image_size_width = imgs->data->nx; + image_size_height = imgs->data->ny; + } + } const int patch_size = hparams.patch_size; - const int num_patches = ((image_size / patch_size) * (image_size / patch_size)); - const int num_patches_per_side = image_size / patch_size; GGML_UNUSED(num_patches_per_side); - const int num_positions = num_patches + 1; + const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); + const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0); const int hidden_size = hparams.hidden_size; const int n_head = hparams.n_head; const int d_head = hidden_size / n_head; - const int n_layer = hparams.n_layer; + int n_layer = hparams.n_layer; const float eps = hparams.eps; const int batch_size = imgs->size; - if (ctx->has_llava_projector) { + if (ctx->has_llava_projector || ctx->has_minicpmv_projector) { GGML_ASSERT(batch_size == 1); } @@ -548,7 +607,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 struct ggml_context * ctx0 = ggml_init(params); struct ggml_cgraph * gf = ggml_new_graph(ctx0); - struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size); + struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size); ggml_set_name(inp_raw, "inp_raw"); ggml_set_input(inp_raw); @@ -557,16 +616,25 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size); inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); - // concat class_embeddings and patch_embeddings - struct ggml_tensor * embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size); - ggml_set_name(embeddings, "embeddings"); - ggml_set_input(embeddings); - - embeddings = ggml_acc(ctx0, embeddings, model.class_embedding, - embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0); + if (ctx->has_patch_bias) { + // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp)); + inp = ggml_add(ctx0, inp, model.patch_bias); + } + struct ggml_tensor * embeddings = inp; + struct ggml_tensor * pos_embed = nullptr; - embeddings = ggml_acc(ctx0, embeddings, inp, - embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]); + if (ctx->has_llava_projector) { + // concat class_embeddings and patch_embeddings + if (ctx->has_class_embedding) { + embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size); + ggml_set_name(embeddings, "embeddings"); + ggml_set_input(embeddings); + embeddings = ggml_acc(ctx0, embeddings, model.class_embedding, + embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0); + embeddings = ggml_acc(ctx0, embeddings, inp, + embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]); + } + } struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions); ggml_set_name(positions, "positions"); @@ -575,8 +643,27 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions)); + if (ctx->has_minicpmv_projector) { + int pos_w = image_size_width/patch_size; + int pos_h = image_size_height/patch_size; + if (ctx->minicpmv_version == 1) { + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 2304, pos_w * pos_h, 1); + } + else if (ctx->minicpmv_version == 2) { + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1); + } + else if (ctx->minicpmv_version == 3) { + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); + } + else if (ctx->minicpmv_version == 4) { + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1); + } + ggml_set_name(pos_embed, "pos_embed"); + ggml_set_input(pos_embed); + } + // pre-layernorm - { + if (ctx->has_pre_norm) { embeddings = ggml_norm(ctx0, embeddings, eps); ggml_set_name(embeddings, "pre_ln"); @@ -584,6 +671,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } // loop over layers + if (ctx->has_minicpmv_projector) { + n_layer += 1; + } for (int il = 0; il < n_layer - 1; il++) { struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states @@ -664,8 +754,16 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = cur; } + // post-layernorm + if (ctx->has_post_norm) { + embeddings = ggml_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "post_ln"); + + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); + } + // llava projector - { + if (ctx->has_llava_projector) { embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches); @@ -686,8 +784,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_gelu(ctx0, embeddings); embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); - - } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { + } + else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false); @@ -846,6 +944,87 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 GGML_ASSERT(false); } } + // minicpmv projector + else if (ctx->has_minicpmv_projector) + { + if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { + struct ggml_tensor * q = model.mm_model_query; + { // layernorm + q = ggml_norm(ctx0, q, eps); + q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b); + } + struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings); + { // layernorm + v = ggml_norm(ctx0, v, eps); + v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b); + } + struct ggml_tensor * k; + { // position + if (ctx->minicpmv_version == 1) { + q = ggml_add(ctx0, q, model.mm_model_pos_embed); + } + k = ggml_add(ctx0, v, pos_embed); + } + + { // attention + int hidden_size = 4096; + const int d_head = 128; + int n_head = hidden_size/d_head; + int num_query = 96; + if (ctx->minicpmv_version == 1) { + hidden_size = 2304; + n_head = hidden_size/d_head; + num_query = 64; + } + else if (ctx->minicpmv_version == 2) { + hidden_size = 4096; + n_head = hidden_size/d_head; + num_query = 96; + } + else if (ctx->minicpmv_version == 3) { + hidden_size = 3584; + n_head = hidden_size/d_head; + num_query = 64; + } + else if (ctx->minicpmv_version == 4) { + hidden_size = 4096; + n_head = hidden_size/d_head; + num_query = 96; + } + + struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); + Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); + struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b); + struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b); + // permute + Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size); + Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); + Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size); + K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); + K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); + K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); + V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); + V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); + V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + KQ = ggml_soft_max_inplace(ctx0, KQ); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size); + + embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b); + } + { // layernorm + embeddings = ggml_norm(ctx0, embeddings, eps); + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b); + } + embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings); + } + else { + GGML_ASSERT(false); + } + } // build the graph ggml_build_forward_expand(gf, embeddings); @@ -979,6 +1158,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_TEE("%s: CLIP using Metal backend\n", __func__); #endif +#ifdef GGML_USE_CANN + new_clip->backend = ggml_backend_cann_init(0); + LOG_TEE("%s: CLIP using CANN backend\n", __func__); +#endif + if (!new_clip->backend) { new_clip->backend = ggml_backend_cpu_init(); @@ -998,7 +1182,18 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->has_llava_projector = gguf_get_val_bool(ctx, idx); } - GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search + idx = gguf_find_key(ctx, KEY_HAS_MINICPMV_PROJ); + if (idx != -1) { + new_clip->has_minicpmv_projector = gguf_get_val_bool(ctx, idx); + } + + idx = gguf_find_key(ctx, KEY_MINICPMV_VERSION); + if (idx != -1) { + new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx); + } + + // GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search + GGML_ASSERT(new_clip->has_vision_encoder); GGML_ASSERT(!new_clip->has_text_encoder); @@ -1009,6 +1204,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_TEE("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder); LOG_TEE("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder); LOG_TEE("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector); + LOG_TEE("%s: minicpmv_projector: %d\n", __func__, new_clip->has_minicpmv_projector); LOG_TEE("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0); LOG_TEE("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0); } @@ -1099,20 +1295,20 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } if (n < 32) hparams.image_grid_pinpoints[n] = 0; - } catch (std::runtime_error & e) { + } catch (std::runtime_error & /*e*/) { hparams.image_grid_pinpoints[0]=0; } try { int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE); strcpy(hparams.mm_patch_merge_type, gguf_get_val_str(ctx, idx)); - } catch (std::runtime_error & e) { + } catch (std::runtime_error & /*e*/) { strcpy(hparams.mm_patch_merge_type, "flat"); } try { hparams.image_crop_resolution = get_u32(ctx, KEY_IMAGE_CROP_RESOLUTION); // llava-1.6 - } catch(const std::exception& e) { + } catch(const std::exception& /*e*/) { hparams.image_crop_resolution = hparams.image_size; } @@ -1148,13 +1344,40 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } + try { + vision_model.class_embedding = get_tensor(new_clip->ctx_data, TN_CLASS_EMBD); + new_clip->has_class_embedding = true; + } catch (const std::exception& /*e*/) { + new_clip->has_class_embedding = false; + } + + try { + vision_model.pre_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight")); + vision_model.pre_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias")); + new_clip->has_pre_norm = true; + } catch (std::exception & /*e*/) { + new_clip->has_pre_norm = false; + } + + try { + vision_model.post_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "weight")); + vision_model.post_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "bias")); + new_clip->has_post_norm = true; + } catch (std::exception & /*e*/) { + new_clip->has_post_norm = false; + } + + try { + vision_model.patch_bias = get_tensor(new_clip->ctx_data, TN_PATCH_BIAS); + new_clip->has_patch_bias = true; + } catch (std::exception & /*e*/) { + new_clip->has_patch_bias = false; + } + try { vision_model.patch_embeddings = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD); - vision_model.class_embedding = get_tensor(new_clip->ctx_data, TN_CLASS_EMBD); vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v")); - vision_model.pre_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight")); - vision_model.pre_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias")); - } catch(const std::exception& e) { + } catch(const std::exception& /*e*/) { LOG_TEE("%s: failed to load vision model tensors\n", __func__); } @@ -1166,26 +1389,26 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { // Yi-type llava vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 1, "weight")); vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 1, "bias")); - } catch (std::runtime_error & e) { } + } catch (std::runtime_error & /*e*/) { } try { // missing in Yi-type llava vision_model.mm_2_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight")); vision_model.mm_2_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias")); - } catch (std::runtime_error & e) { } + } catch (std::runtime_error & /*e*/) { } try { // Yi-type llava vision_model.mm_3_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 3, "weight")); vision_model.mm_3_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 3, "bias")); - } catch (std::runtime_error & e) { } + } catch (std::runtime_error & /*e*/) { } try { // Yi-type llava vision_model.mm_4_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "weight")); vision_model.mm_4_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "bias")); - } catch (std::runtime_error & e) { } + } catch (std::runtime_error & /*e*/) { } try { vision_model.image_newline = get_tensor(new_clip->ctx_data, TN_IMAGE_NEWLINE); // LOG_TEE("%s: image_newline tensor (llava-1.6) found\n", __func__); - } catch (std::runtime_error & e) { } + } catch (std::runtime_error & /*e*/) { } } else if (new_clip->proj_type == PROJECTOR_TYPE_LDP) { // MobileVLM projection vision_model.mm_model_mlp_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 1, "weight")); @@ -1223,6 +1446,29 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { vision_model.mm_model_peg_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "weight")); vision_model.mm_model_peg_0_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "bias")); } + else if (new_clip->proj_type == PROJECTOR_TYPE_RESAMPLER) { + if (new_clip->minicpmv_version == 1) { + vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD); + } + vision_model.mm_model_pos_embed_k = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD_K); + vision_model.mm_model_query = get_tensor(new_clip->ctx_data, TN_MINICPMV_QUERY); + vision_model.mm_model_proj = get_tensor(new_clip->ctx_data, TN_MINICPMV_PROJ); + vision_model.mm_model_kv_proj = get_tensor(new_clip->ctx_data, TN_MINICPMV_KV_PROJ); + vision_model.mm_model_attn_q_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "q", "weight")); + vision_model.mm_model_attn_k_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "k", "weight")); + vision_model.mm_model_attn_v_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "v", "weight")); + vision_model.mm_model_attn_q_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "q", "bias")); + vision_model.mm_model_attn_k_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "k", "bias")); + vision_model.mm_model_attn_v_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "v", "bias")); + vision_model.mm_model_attn_o_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "out", "weight")); + vision_model.mm_model_attn_o_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "out", "bias")); + vision_model.mm_model_ln_q_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "q", "weight")); + vision_model.mm_model_ln_q_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "q", "bias")); + vision_model.mm_model_ln_kv_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "kv", "weight")); + vision_model.mm_model_ln_kv_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "kv", "bias")); + vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight")); + vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias")); + } else { std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type]; throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); @@ -1261,7 +1507,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend)); clip_image_f32_batch batch; batch.size = 1; - ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch); + ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, nullptr, false); ggml_gallocr_reserve(new_clip->compute_alloc, gf); size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0); LOG_TEE("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0); @@ -1270,6 +1516,17 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { return new_clip; } +void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) { + ctx_clip->load_image_size = load_image_size; +} + +struct clip_image_size * clip_image_size_init() { + struct clip_image_size * load_image_size = new struct clip_image_size(); + load_image_size->width = 448; + load_image_size->height = 448; + return load_image_size; +} + struct clip_image_u8 * clip_image_u8_init() { return new clip_image_u8(); } @@ -1325,7 +1582,7 @@ bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length } // Linear interpolation between two points -inline float lerp(float s, float e, float t) { +inline float clip_lerp(float s, float e, float t) { return s + (e - s) * t; } // Bilinear resize function @@ -1347,17 +1604,17 @@ static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int ta float y_lerp = py - y_floor; for (int c = 0; c < 3; c++) { - float top = lerp( + float top = clip_lerp( static_cast(src.buf[3 * (y_floor * src.nx + x_floor) + c]), static_cast(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]), x_lerp ); - float bottom = lerp( + float bottom = clip_lerp( static_cast(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]), static_cast(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]), x_lerp ); - dst.buf[3 * (y * target_width + x) + c] = static_cast(lerp(top, bottom, y_lerp)); + dst.buf[3 * (y * target_width + x) + c] = static_cast(clip_lerp(top, bottom, y_lerp)); } } } @@ -1540,9 +1797,211 @@ static std::vector divide_to_patches_u8(const clip_image_u8 & im return patches; } +static int ensure_divide(int length, int patch_size) { + return std::max(static_cast(std::round(static_cast(length) / patch_size) * patch_size), patch_size); +} + +static std::pair uhd_find_best_resize(std::pair original_size, int scale_resolution, int patch_size, bool allow_upscale = false) { + int width = original_size.first; + int height = original_size.second; + if ((width * height > scale_resolution * scale_resolution) || allow_upscale) { + float r = static_cast(width) / height; + height = static_cast(scale_resolution / std::sqrt(r)); + width = static_cast(height * r); + } + int best_width = ensure_divide(width, patch_size); + int best_height = ensure_divide(height, patch_size); + return std::make_pair(best_width, best_height); +} + +static std::pair uhd_get_refine_size(std::pair original_size, std::pair grid, int scale_resolution, int patch_size, bool allow_upscale = false) { + int width, height; + std::tie(width, height) = original_size; + int grid_x, grid_y; + std::tie(grid_x, grid_y) = grid; + + int refine_width = ensure_divide(width, grid_x); + int refine_height = ensure_divide(height, grid_y); + + int grid_width = refine_width / grid_x; + int grid_height = refine_height / grid_y; + + // auto best_grid_size = find_best_resize(std::make_tuple(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); (old line) + auto best_grid_size = uhd_find_best_resize(std::make_pair(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); // (new line) => fixes conversion for make_tuple to make_pair + int best_grid_width, best_grid_height; + std::tie(best_grid_width, best_grid_height) = best_grid_size; + + // std::pair refine_size = std::make_tuple(best_grid_width * grid_x, best_grid_height * grid_y); (old line) + std::pair refine_size = std::make_pair(best_grid_width * grid_x, best_grid_height * grid_y); // (new line) + return refine_size; +} + +inline int clip(int x, int lower, int upper) { + return std::max(lower, std::min(x, upper)); +} + +static std::pair uhd_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) { + std::vector candidate_split_grids_nums; + for (int i : {multiple - 1, multiple, multiple + 1}) { + if (i == 1 || i > max_slice_nums) { + continue; + } + candidate_split_grids_nums.push_back(i); + } + + std::vector> candidate_grids; + for (int split_grids_nums : candidate_split_grids_nums) { + int m = 1; + while (m <= split_grids_nums) { + if (split_grids_nums % m == 0) { + candidate_grids.emplace_back(m, split_grids_nums / m); + } + ++m; + } + } + + std::pair best_grid{1, 1}; + float min_error = std::numeric_limits::infinity(); + for (const auto& grid : candidate_grids) { + float error = std::abs(log_ratio - std::log(1.0 * grid.first / grid.second)); + if (error < min_error) { + best_grid = grid; + min_error = error; + } + } + return best_grid; +} + +// inspired from LLaVA-UHD: +// -> https://arxiv.org/pdf/2403.11703 +// -> https://github.com/thunlp/LLaVA-UHD +// -> https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118 +static std::vector> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) { + const std::pair original_size={img->nx,img->ny}; + const int original_width = img->nx; + const int original_height = img->ny; + const float log_ratio = log(1.0*original_width/original_height); + const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution); + const int multiple = fmin(ceil(ratio), max_slice_nums); + + std::vector> images; + LOG_TEE("%s: multiple %d\n", __func__, multiple); + images.push_back(std::vector()); + + if (multiple <= 1) { + auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true); + clip_image_u8 * source_image = clip_image_u8_init(); + bicubic_resize(*img, *source_image, best_size.first, best_size.second); + // source_image = image.resize(best_size, Image.Resampling.BICUBIC) + images[images.size()-1].push_back(source_image); + } + else if (multiple > 1) { + auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size); + clip_image_u8 * source_image = clip_image_u8_init(); + bicubic_resize(*img, *source_image, best_size.first, best_size.second); + // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) + LOG_TEE("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second); + images[images.size()-1].push_back(source_image); + + std::pair best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio); + LOG_TEE("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second); + + auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true); + clip_image_u8 * refine_image = clip_image_u8_init(); + bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second); + + LOG_TEE("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second); + + // split_to_patches + int width = refine_image->nx; + int height = refine_image->ny; + int grid_x = int(width / best_grid.first); + int grid_y = int(height / best_grid.second); + for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){ + images.push_back(std::vector()); + for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){ + clip_image_u8 * patch = clip_image_u8_init(); + patch->nx = grid_x; + patch->ny = grid_y; + patch->buf.resize(3 * patch->nx * patch->ny); + for (int y = patches_i; y < patches_i + grid_y; ++y) { + for (int x = patches_j; x < patches_j + grid_x; ++x) { + const int i = 3 * (y * refine_image->nx + x); + const int j = 3 * ((y-patches_i) * patch->nx + (x-patches_j)); + patch->buf[j] = refine_image->buf[i]; + patch->buf[j+1] = refine_image->buf[i+1]; + patch->buf[j+2] = refine_image->buf[i+2]; + } + } + images[images.size()-1].push_back(patch); + } + } + } + return images; +} + +int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip){ + const int max_slice_nums=ctx_clip->max_slice_nums; + const int scale_resolution=448; + const int original_width = ctx_clip->load_image_size->width; + const int original_height = ctx_clip->load_image_size->height; + const float log_ratio = log(1.0*original_width/original_height); + const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution); + const int multiple = fmin(ceil(ratio), max_slice_nums); + std::pair best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio); + return best_grid.first; +} + // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector // res_imgs memory is being allocated here, previous allocations will be freed if found bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) { + + if(clip_is_minicpmv(ctx)){ + if (ctx->minicpmv_version >1) + { + std::vector> imgs = uhd_slice_image(img, ctx->max_slice_nums); + res_imgs->size = 0; + for (size_t i = 0; i < imgs.size(); ++i){ + res_imgs->size += imgs[i].size(); + } + res_imgs->data = new clip_image_f32[res_imgs->size]; + int idx = 0; + for (size_t i = 0; i < imgs.size(); ++i){ + for (size_t j = 0; j < imgs[i].size(); ++j) { + LOG_TEE("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny); + clip_image_f32 * res = clip_image_f32_init(); + normalize_image_u8_to_f32(imgs[i][j], res, ctx->image_mean, ctx->image_std); + res_imgs->data[idx++] = *res; + clip_image_f32_free(res); + } + } + return true; + } + else { + if (res_imgs->size == 0){ + std::vector> imgs = uhd_slice_image(img, ctx->max_slice_nums); + res_imgs->size = 0; + for (size_t i = 0; i < imgs.size(); ++i){ + res_imgs->size += imgs[i].size(); + } + res_imgs->data = new clip_image_f32[res_imgs->size]; + int idx = 0; + + for (size_t i = 0; i < imgs.size(); ++i){ + for (size_t j = 0; j < imgs[i].size(); ++j) { + LOG_TEE("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny); + clip_image_f32_batch img_res_v_batch; + img_res_v_batch.size = 1; + img_res_v_batch.data = nullptr; + clip_image_preprocess(ctx, imgs[i][j], &img_res_v_batch); + res_imgs->data[idx++] = img_res_v_batch.data[0]; + } + } + return true; + } + } + } + bool pad_to_square = true; if (!ctx->has_vision_encoder) { LOG_TEE("This gguf file seems to have no vision encoder\n"); @@ -1758,11 +2217,110 @@ int clip_n_patches(const struct clip_ctx * ctx) { if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) { n_patches /= 4; + } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { + if (ctx->minicpmv_version == 1) { + n_patches = 64; + } + else if (ctx->minicpmv_version == 2) { + n_patches = 96; + } + else if (ctx->minicpmv_version == 3) { + n_patches = 64; + } + else if (ctx->minicpmv_version == 4) { + n_patches = 96; + } } return n_patches; } +static std::vector>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector> & pos) { + assert(embed_dim % 2 == 0); + int H = pos.size(); + int W = pos[0].size(); + + std::vector omega(embed_dim / 2); + for (int i = 0; i < embed_dim / 2; ++i) { + omega[i] = 1.0 / pow(10000.0, static_cast(i) / (embed_dim / 2)); + } + + std::vector>> emb(H, std::vector>(W, std::vector(embed_dim))); + for (int h = 0; h < H; ++h) { + for (int w = 0; w < W; ++w) { + for (int d = 0; d < embed_dim / 2; ++d) { + float out_value = pos[h][w] * omega[d]; + emb[h][w][d] = sin(out_value); + emb[h][w][d + embed_dim / 2] = cos(out_value); + } + } + } + + return emb; +} + +static std::vector>> get_2d_sincos_pos_embed_from_grid(int embed_dim, const std::vector>> & grid) { + assert(embed_dim % 2 == 0); + std::vector>> emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[0]); // (H, W, D/2) + std::vector>> emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[1]); // (H, W, D/2) + + int H = emb_h.size(); + int W = emb_h[0].size(); + std::vector>> emb(H, std::vector>(W, std::vector(embed_dim))); + + for (int h = 0; h < H; ++h) { + for (int w = 0; w < W; ++w) { + for (int d = 0; d < embed_dim / 2; ++d) { + emb[h][w][d] = emb_h[h][w][d]; + emb[h][w][d + embed_dim / 2] = emb_w[h][w][d]; + } + } + } + return emb; +} + +static std::vector> get_2d_sincos_pos_embed(int embed_dim, const std::pair image_size) { + int grid_h_size = image_size.first; + int grid_w_size = image_size.second; + + std::vector grid_h(grid_h_size); + std::vector grid_w(grid_w_size); + + for (int i = 0; i < grid_h_size; ++i) { + grid_h[i] = static_cast(i); + } + for (int i = 0; i < grid_w_size; ++i) { + grid_w[i] = static_cast(i); + } + + std::vector> grid(grid_h_size, std::vector(grid_w_size)); + for (int h = 0; h < grid_h_size; ++h) { + for (int w = 0; w < grid_w_size; ++w) { + grid[h][w] = grid_w[w]; + } + } + std::vector>> grid_2d = {grid, grid}; + for (int h = 0; h < grid_h_size; ++h) { + for (int w = 0; w < grid_w_size; ++w) { + grid_2d[0][h][w] = grid_h[h]; + grid_2d[1][h][w] = grid_w[w]; + } + } + + std::vector>> pos_embed_3d = get_2d_sincos_pos_embed_from_grid(embed_dim, grid_2d); + + int H = image_size.first; + int W = image_size.second; + std::vector> pos_embed_2d(H * W, std::vector(embed_dim)); + for (int h = 0; h < H; ++h) { + for (int w = 0; w < W; ++w) { + pos_embed_2d[w * H + h] = pos_embed_3d[h][w]; + } + } + + return pos_embed_2d; +} + bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) { if (!ctx->has_vision_encoder) { LOG_TEE("This gguf file seems to have no vision encoder\n"); @@ -1785,20 +2343,34 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima if (ctx->has_llava_projector) { GGML_ASSERT(batch_size == 1); // TODO: support multiple images } + if (ctx->has_minicpmv_projector) { + GGML_ASSERT(batch_size == 1); + } // build the inference graph - ggml_cgraph * gf = clip_image_build_graph(ctx, imgs); + ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true); ggml_gallocr_alloc_graph(ctx->compute_alloc, gf); // set inputs const auto & model = ctx->vision_model; const auto & hparams = model.hparams; - const int image_size = hparams.image_size; + const int image_size = hparams.image_size; + int image_size_width = image_size; + int image_size_height = image_size; + if (ctx->has_minicpmv_projector) { + image_size_width = imgs->data[0].nx; + image_size_height = imgs->data[0].ny; + } const int patch_size = hparams.patch_size; - const int num_patches = ((image_size / patch_size) * (image_size / patch_size)); - const int num_positions = num_patches + 1; - + const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); + const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0); + if(ctx->load_image_size==nullptr){ + ctx->load_image_size= clip_image_size_init(); + } + const int pos_w = ctx->load_image_size->width/patch_size; + const int pos_h = ctx->load_image_size->height/patch_size; + { struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw"); float * data = (float *)malloc(ggml_nbytes(inp_raw)); @@ -1806,7 +2378,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima for (size_t i = 0; i < imgs->size; i++) { const int nx = imgs->data[i].nx; const int ny = imgs->data[i].ny; - GGML_ASSERT(nx == image_size && ny == image_size); + if (!ctx->has_minicpmv_projector) { + GGML_ASSERT(nx == image_size && ny == image_size); + } const int n = nx * ny; @@ -1823,35 +2397,104 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw)); free(data); } + if (ctx->minicpmv_version) { + if (ctx->minicpmv_version == 1) + { + struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); - { - struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings"); + int* positions_data = (int*)malloc(ggml_nbytes(positions)); + for (int i = 0; i < num_positions; i++) { + positions_data[i] = i; + } + ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); + free(positions_data); + } + else { + // inspired from siglip: + // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit + // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 + struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); + int* positions_data = (int*)malloc(ggml_nbytes(positions)); + int bucket_coords_h[70]; + int bucket_coords_w[70]; + for (int i = 0; i < pos_h; i++){ + bucket_coords_h[i] = std::floor(70.0*i/pos_h); + } + for (int i = 0; i < pos_w; i++){ + bucket_coords_w[i] = std::floor(70.0*i/pos_w); + } + for (int i = 0, id = 0; i < pos_h; i++){ + for (int j = 0; j < pos_w; j++){ + positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j]; + } + } + ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); + free(positions_data); + } + + { + // inspired from resampler of Qwen-VL: + // -> https://huggingface.co/Qwen/Qwen-VL/tree/main + // -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23 + struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed"); + int embed_dim = 4096; + if (ctx->minicpmv_version == 1) { + embed_dim = 2304; + } + else if (ctx->minicpmv_version == 2) { + embed_dim = 4096; + } + else if (ctx->minicpmv_version == 3) { + embed_dim = 3584; + } + else if (ctx->minicpmv_version == 4) { + embed_dim = 4096; + } + auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); - void* zero_mem = malloc(ggml_nbytes(embeddings)); - memset(zero_mem, 0, ggml_nbytes(embeddings)); - ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings)); - free(zero_mem); + float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed)); + for(int i=0;ihas_class_embedding) { + struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings"); - { - struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); + void* zero_mem = malloc(ggml_nbytes(embeddings)); + memset(zero_mem, 0, ggml_nbytes(embeddings)); + ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings)); + free(zero_mem); + } + } - int* positions_data = (int*)malloc(ggml_nbytes(positions)); - for (int i = 0; i < num_positions; i++) { - positions_data[i] = i; + { + struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); + + int* positions_data = (int*)malloc(ggml_nbytes(positions)); + for (int i = 0; i < num_positions; i++) { + positions_data[i] = i; + } + ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); + free(positions_data); } - ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); - free(positions_data); - } - { - struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); - int* patches_data = (int*)malloc(ggml_nbytes(patches)); - for (int i = 0; i < num_patches; i++) { - patches_data[i] = i + 1; + { + struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); + int* patches_data = (int*)malloc(ggml_nbytes(patches)); + for (int i = 0; i < num_patches; i++) { + patches_data[i] = i + 1; + } + ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); + free(patches_data); } - ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); - free(patches_data); } if (ggml_backend_is_cpu(ctx->backend)) { @@ -2021,7 +2664,32 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { return ctx->vision_model.mm_3_b->ne[0]; } + if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { + if (ctx->minicpmv_version == 1) { + return 2304; + } + else if (ctx->minicpmv_version == 2) { + return 4096; + } + else if (ctx->minicpmv_version == 3) { + return 3584; + } + else if (ctx->minicpmv_version == 4) { + return 4096; + } + } std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type]; throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); } + +int clip_is_minicpmv(const struct clip_ctx * ctx) { + if (ctx->has_minicpmv_projector) { + return ctx->minicpmv_version; + } + return 0; +} + +void clip_uhd_max_slice_nums(struct clip_ctx * ctx, int max_slice_nums) { + ctx->max_slice_nums = max_slice_nums; +} \ No newline at end of file diff --git a/examples/llava/clip.h b/examples/llava/clip.h index 45bdad6897658..7a0f479261e32 100644 --- a/examples/llava/clip.h +++ b/examples/llava/clip.h @@ -18,14 +18,17 @@ # define CLIP_API #endif -struct clip_ctx; - #ifdef __cplusplus extern "C" { #endif struct clip_ctx; +struct clip_image_size { + int width; + int height; +}; + struct clip_image_u8_batch { struct clip_image_u8 * data; size_t size; @@ -55,6 +58,10 @@ CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx); CLIP_API int clip_n_patches (const struct clip_ctx * ctx); CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx); +CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip); +CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size); + +CLIP_API struct clip_image_size * clip_image_size_init(); CLIP_API struct clip_image_u8 * clip_image_u8_init (); CLIP_API struct clip_image_f32 * clip_image_f32_init(); @@ -68,7 +75,7 @@ CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 /** interpret bytes as an image file with length bytes_length, and use the result to populate img */ CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img); -/** preprocess img and store the result in res_imgs, pad_to_square may be overriden to false depending on model configuration */ +/** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */ CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs ); CLIP_API struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx); @@ -78,6 +85,9 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype); +CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx); +CLIP_API void clip_uhd_max_slice_nums(struct clip_ctx * ctx, int max_slice_nums); + #ifdef __cplusplus } #endif diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 9a990bb182f35..ff054bb42fd8c 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -88,7 +88,6 @@ static struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair< // Take the image segments in a grid configuration and return the embeddings and the number of embeddings into preallocated memory (image_embd_out) static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out) { struct { - struct ggml_tensor * newline; struct ggml_context * ctx; } model; @@ -150,20 +149,6 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector model.ctx = ggml_init(params); - ggml_tensor * newline_tmp = clip_get_newline_tensor(ctx_clip); - model.newline = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, newline_tmp->ne[0]); - if (newline_tmp->backend != GGML_BACKEND_TYPE_CPU) { - if (newline_tmp->buffer == NULL) { - LOG_TEE("newline_tmp tensor buffer is NULL\n"); - } - ggml_backend_tensor_get(newline_tmp, model.newline->data, 0, ggml_nbytes(newline_tmp)); - } else { - model.newline->data = newline_tmp->data; - if (model.newline->data == NULL) { - LOG_TEE("newline_tmp tensor data is NULL\n"); - } - } - struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_patches(ctx_clip), num_images - 1); // example: 4096 x 576 x 4 // ggml_tensor_printf(image_features,"image_features",__LINE__,false,false); // fill it with the image embeddings, ignoring the base @@ -217,6 +202,33 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector return true; } +static clip_image_f32 * only_v2_5_reshape_by_patch(clip_image_f32 * image, int patch_size) { + int width = image->nx; + int height = image->ny; + int num_patches = (height / patch_size) * (width / patch_size); + clip_image_f32 * patch = clip_image_f32_init(); + patch->nx = patch_size * num_patches; + patch->ny = patch_size; + patch->buf.resize(3 * patch->nx * patch->ny); + + int patch_index = 0; + + for (int i = 0; i < height; i += patch_size) { + for (int j = 0; j < width; j += patch_size) { + for (int pi = 0; pi < patch_size; ++pi) { + for (int pj = 0; pj < patch_size; ++pj) { + int input_index = ((i + pi) * width + (j + pj)) * 3; + int output_index = (pi * patch_size * num_patches + patch_index * patch_size + pj) * 3; + patch->buf[output_index] = image->buf[input_index]; + patch->buf[output_index+1] = image->buf[input_index+1]; + patch->buf[output_index+2] = image->buf[input_index+2]; + } + } + patch_index++; + } + } + return patch; +} static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) { // std::vector img_res_v; // format VectN x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336 @@ -233,7 +245,51 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip); - if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { + if (clip_is_minicpmv(ctx_clip)) { + std::vector image_embd_v; + image_embd_v.resize(img_res_v.size); + struct clip_image_size * load_image_size = clip_image_size_init(); + for (size_t i = 0; i < img_res_v.size; i++) { + const int64_t t_img_enc_step_start_us = ggml_time_us(); + image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); + int patch_size = 14; + load_image_size->width = img_res_v.data[i].nx; + load_image_size->height = img_res_v.data[i].ny; + clip_add_load_image_size(ctx_clip, load_image_size); + bool encoded = false; + int has_minicpmv_projector = clip_is_minicpmv(ctx_clip); + if (has_minicpmv_projector == 2) { + encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); + } + else { + encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); + } + if (!encoded) { + LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size); + return false; + } + const int64_t t_img_enc_steop_batch_us = ggml_time_us(); + LOG_TEE("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)img_res_v.size, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0); + } + const int64_t t_img_enc_batch_us = ggml_time_us(); + LOG_TEE("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); + + int n_img_pos_out = 0; + for (size_t i = 0; i < image_embd_v.size(); i++) { + std::memcpy(image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip), image_embd_v[i], clip_embd_nbytes(ctx_clip)); + n_img_pos_out += clip_n_patches(ctx_clip); + } + *n_img_pos = n_img_pos_out; + for (size_t i = 0; i < image_embd_v.size(); i++) { + free(image_embd_v[i]); + } + image_embd_v.clear(); + load_image_size->width = img->nx; + load_image_size->height = img->ny; + clip_add_load_image_size(ctx_clip, load_image_size); + LOG_TEE("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height); + } + else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { // flat / default llava-1.5 type embedding *n_img_pos = clip_n_patches(ctx_clip); bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); // image_embd shape is 576 x 4096 @@ -243,7 +299,8 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli return false; } - } else { + } + else { // spatial_unpad llava-1.6 type embedding // TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working std::vector image_embd_v; @@ -312,7 +369,11 @@ bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * } bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) { - float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*6); // TODO: base on gridsize/llava model + int num_max_patches = 6; + if (clip_is_minicpmv(ctx_clip)) { + num_max_patches = 10; + } + float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*num_max_patches); // TODO: base on gridsize/llava model if (!image_embd) { LOG_TEE("Unable to allocate memory for image embeddings\n"); return false; diff --git a/examples/llava/llava.h b/examples/llava/llava.h index 19212f6e9e9c5..b6feb3027b2da 100644 --- a/examples/llava/llava.h +++ b/examples/llava/llava.h @@ -17,12 +17,11 @@ # define LLAVA_API #endif -struct clip_ctx; - #ifdef __cplusplus extern "C" { #endif +struct clip_ctx; struct llava_image_embed { float * embed; int n_image_pos; @@ -37,8 +36,8 @@ LLAVA_API bool llava_image_embed_make_with_clip_img(struct clip_ctx * ctx_clip, LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length); /** build an image embed from a path to an image filename */ LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path); -LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed); /** free an embedding made with llava_image_embed_make_* */ +LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed); /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */ LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past); diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp new file mode 100644 index 0000000000000..6436e99c08523 --- /dev/null +++ b/examples/llava/minicpmv-cli.cpp @@ -0,0 +1,371 @@ +#include "ggml.h" +#include "log.h" +#include "common.h" +#include "clip.h" +#include "llava.h" +#include "llama.h" + +#include +#include +#include + +// extern "C" { +// #include +// #include +// #include +// #include +// } + +struct llava_context { + struct clip_ctx * ctx_clip = NULL; + struct llama_context * ctx_llama = NULL; + struct llama_model * model = NULL; +}; + +struct clip_image_u8 { + int nx; + int ny; + std::vector buf; +}; + +static void show_additional_info(int /*argc*/, char ** argv) { + LOG_TEE("\n example usage: %s -m --mmproj [--video ] [--image ] [--image ] [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]); + LOG_TEE(" note: a lower temperature value like 0.1 is recommended for better quality.\n"); +} + +static void llama_log_callback_logTee(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + LOG_TEE("%s", text); +} + +static struct llama_model * llava_init(gpt_params * params) { + llama_backend_init(); + llama_numa_init(params->numa); + + llama_model_params model_params = llama_model_params_from_gpt_params(*params); + + llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params); + if (model == NULL) { + LOG_TEE("%s: error: unable to load model\n" , __func__); + return NULL; + } + return model; +} + +static struct llava_context * llava_init_context(gpt_params * params) { + auto model = llava_init(params); + if (model == NULL) { + fprintf(stderr, "%s: error: failed to init minicpmv model\n", __func__); + return NULL; + } + + const char * clip_path = params->mmproj.c_str(); + auto prompt = params->prompt; + if (prompt.empty()) { + prompt = "describe the image in detail."; + } + auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); + + llama_context_params ctx_params = llama_context_params_from_gpt_params(*params); + if (params->n_ctx < 2048) { + // warn user here, "Image processing requires at least 2048 context, setting context to 2048" + LOG_TEE("%s: warn: Image processing requires at least 2048 context, setting context to 2048\n" , __func__); + ctx_params.n_ctx = 2048; + } else { + ctx_params.n_ctx = params->n_ctx; + } + + llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); + + if (ctx_llama == NULL) { + LOG_TEE("%s: error: failed to create the llama_context\n" , __func__); + return NULL; + } + + auto ctx_llava = (struct llava_context *)malloc(sizeof(llava_context)); + + ctx_llava->ctx_llama = ctx_llama; + ctx_llava->ctx_clip = ctx_clip; + ctx_llava->model = model; + return ctx_llava; +} + +static void llava_free(struct llava_context * ctx_llava) { + if (ctx_llava->ctx_clip) { + clip_free(ctx_llava->ctx_clip); + ctx_llava->ctx_clip = NULL; + } + + llama_free(ctx_llava->ctx_llama); + llama_free_model(ctx_llava->model); + llama_backend_free(); +} + +static bool eval_tokens(struct llama_context * ctx_llama, std::vector tokens, int n_batch, int * n_past) { + int N = (int) tokens.size(); + for (int i = 0; i < N; i += n_batch) { + int n_eval = (int) tokens.size() - i; + if (n_eval > n_batch) { + n_eval = n_batch; + } + if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) { + LOG_TEE("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); + return false; + } + *n_past += n_eval; + } + return true; +} + +static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) { + std::vector tokens; + tokens.push_back(id); + return eval_tokens(ctx_llama, tokens, 1, n_past); +} + +static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){ + std::string str2 = str; + std::vector embd_inp = ::llama_tokenize(ctx_llama, str2, add_bos, true); + return eval_tokens(ctx_llama, embd_inp, n_batch, n_past); +} + +static void process_eval_image_embed(struct llava_context * ctx_llava, const struct llava_image_embed * embeds, int n_batch, int * n_past, int idx) { + float * image_embed = (float *)malloc(clip_embd_nbytes(ctx_llava->ctx_clip)); + std::memcpy(image_embed, embeds->embed + idx * clip_n_patches(ctx_llava->ctx_clip) * clip_n_mmproj_embd(ctx_llava->ctx_clip), clip_embd_nbytes(ctx_llava->ctx_clip)); + + auto slice_embed = (llava_image_embed*)malloc(sizeof(llava_image_embed)); + slice_embed->embed = image_embed; + slice_embed->n_image_pos = clip_n_patches(ctx_llava->ctx_clip); + llava_eval_image_embed(ctx_llava->ctx_llama, slice_embed, n_batch, n_past); + llava_image_embed_free(slice_embed); +} + +static int process_image(struct llava_context * ctx_llava, struct llava_image_embed * embeds, gpt_params * params, int &n_past) { + std::string system_prompt; + bool res = false; + int idx = 0; + int num_image_embeds = embeds->n_image_pos / clip_n_patches(ctx_llava->ctx_clip); + LOG_TEE("%s: image token past: %d\n", __func__, n_past); + eval_string(ctx_llava->ctx_llama, (system_prompt+"").c_str(), params->n_batch, &n_past, false); + process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++); + res = eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); + if (num_image_embeds > 1) { + size_t num_image_embeds_col = clip_uhd_num_image_embeds_col(ctx_llava->ctx_clip); + eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); + for (size_t i = 0; i < (num_image_embeds-1)/num_image_embeds_col; ++i) { + for (size_t j = 0; j < num_image_embeds_col; ++j) { + eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); + process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++); + eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); + if (j == num_image_embeds_col - 1) { + eval_string(ctx_llava->ctx_llama, std::string("\n").c_str(), params->n_batch, &n_past, false); + } + } + } + res = eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); + } + res = eval_string(ctx_llava->ctx_llama, std::string("\n").c_str(), params->n_batch, &n_past, false); + LOG_TEE("%s: image token past: %d\n", __func__, n_past); + if(!res) return 0; + return n_past; +} + +static bool process_prompt(int type, struct llava_context * ctx_llava, gpt_params * params, int &n_past, std::string prompt = ""){ + int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip); + if (type==0) { + std::string system_prompt; + if (has_minicpmv_projector == 1) { + system_prompt = "<用户>"; + } + else if (has_minicpmv_projector == 2) { + system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"; + } + else if (has_minicpmv_projector == 3) { + system_prompt = "<|im_start|>user\n"; + } + else if (has_minicpmv_projector == 4) { + system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"; + } + return eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, false); + } + else if (type==1) { + std::string user_prompt = prompt; + return eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false); + } + else if (type==2) { + if (has_minicpmv_projector == 1) { + return eval_string(ctx_llava->ctx_llama, "\n", params->n_batch, &n_past, false); + } + else if (has_minicpmv_projector == 2) { + return eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false); + } + else if (has_minicpmv_projector == 3) { + return eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false); + } + else if (has_minicpmv_projector == 4) { + return eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false); + } + } + return 0; +} + +static const char * sample(struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_llama, + int * n_past) { + const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL); + llama_sampling_accept(ctx_sampling, ctx_llama, id, true); + static std::string ret; + if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { + ret = ""; + } else { + ret = llama_token_to_piece(ctx_llama, id); + } + eval_id(ctx_llama, id, n_past); + return ret.c_str(); +} + +static struct llava_context * minicpmv_init(gpt_params * params, const std::string & fname, int &n_past){ + auto ctx_llava = llava_init_context(params); + auto embeds = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, fname.c_str()); + if (!embeds) { + LOG_TEE("error: failed to load image %s. Terminating\n\n", fname.c_str()); + return NULL; + } + + // process the prompt + if (params->prompt.empty() && params->interactive == false) { + LOG_TEE("prompt should be given or interactive mode should be on"); + return NULL; + } + + const int64_t t_llava_init_start_us = ggml_time_us(); + + const int64_t t_llava_init_end_us = ggml_time_us(); + float t_llava_init_ms = (t_llava_init_end_us - t_llava_init_start_us) / 1000.0; + LOG_TEE("\n%s: llava init in %8.2f ms.\n", __func__, t_llava_init_ms); + + const int64_t t_process_image_start_us = ggml_time_us(); + process_prompt(0, ctx_llava, params, n_past); + process_image(ctx_llava, embeds, params, n_past); + const int64_t t_process_image_end_us = ggml_time_us(); + float t_process_image_ms = (t_process_image_end_us - t_process_image_start_us) / 1000.0; + LOG_TEE("\n%s: llama process image in %8.2f ms.\n", __func__, t_process_image_ms); + + llava_image_embed_free(embeds); + return ctx_llava; +} + +static int process_input(struct llava_context * ctx_llava, gpt_params * params, int type, std::string prompt, int &n_past, struct llava_image_embed * embeds = nullptr){ + if (type==0) { + if (process_prompt(1, ctx_llava, params, n_past, prompt)) return 1; + } + else if (type == 1) { + if(embeds != NULL){ + return (process_image(ctx_llava, embeds, params, n_past)); + } + } + return 0; +} + +static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){ + std::string user_prompt = prompt; + if(is_first)process_prompt(0, ctx_llava, params, n_past); + process_prompt(1, ctx_llava, params, n_past, prompt); + process_prompt(2, ctx_llava, params, n_past); + + // generate the response + LOG_TEE("\n"); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams); + return ctx_sampling; +} + +static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling_context * ctx_sampling, int &n_past){ + + const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past); + return tmp; +} + +int main(int argc, char ** argv) { + ggml_time_init(); + + gpt_params params; + + if (!gpt_params_parse(argc, argv, params)) { + show_additional_info(argc, argv); + return 1; + } + +#ifndef LOG_DISABLE_LOGS + log_set_target(log_filename_generator("llava", "log")); + LOG_TEE("Log start\n"); + log_dump_cmdline(argc, argv); + llama_log_set(llama_log_callback_logTee, nullptr); +#endif // LOG_DISABLE_LOGS + + if (params.mmproj.empty()) { + gpt_print_usage(argc, argv, params); + show_additional_info(argc, argv); + return 1; + } + + int n_past = 0; + struct llava_context * ctx_llava = nullptr; + + if (params.image.size() > 0) { + auto image = params.image; + ctx_llava = minicpmv_init(¶ms, image, n_past); + + if (!params.prompt.empty()) { + LOG_TEE("%s\n", params.prompt.c_str()); + LOG_TEE(""); + auto ctx_sampling = llama_init(ctx_llava, ¶ms, params.prompt.c_str(), n_past, false); + const int max_tgt_len = params.n_predict < 0 ? 8192 : params.n_predict; + std::string response = ""; + bool have_tmp = false; + for (int i = 0; i < max_tgt_len; i++) { + auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past); + response += tmp; + if (strcmp(tmp, "") == 0){ + if(!have_tmp)continue; + else break; + } + have_tmp = true; + printf("%s", tmp); + if (strstr(response.c_str(), "")) break; // minicpm-v + + fflush(stdout); + } + llama_sampling_free(ctx_sampling); + } + else { + while (true) { + LOG_TEE(""); + std::string prompt; + std::getline(std::cin, prompt); + LOG_TEE(""); + auto ctx_sampling = llama_init(ctx_llava, ¶ms, prompt, n_past, true); + const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict; + std::string response = ""; + for (int i = 0; i < max_tgt_len; i++) { + auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past); + response += tmp; + if (strcmp(tmp, "") == 0) break; + if (strstr(tmp, "###")) break; // Yi-VL behavior + printf("%s", tmp);// mistral llava-1.6 + if (strstr(response.c_str(), "")) break; // minicpm-v + fflush(stdout); + } + llama_sampling_free(ctx_sampling); + } + } + } + printf("\n"); + llama_print_timings(ctx_llava->ctx_llama); + + ctx_llava->model = NULL; + llava_free(ctx_llava); + // } + + return 0; +}