diff --git a/.gitignore b/.gitignore index a0c16e880b719..6c3a0abf09264 100644 --- a/.gitignore +++ b/.gitignore @@ -71,7 +71,6 @@ models-mnt !models/ggml-vocab-*.gguf* # Zig - zig-out/ zig-cache/ diff --git a/Makefile b/Makefile index 3aad77394c5ac..2ffee6338d6ac 100644 --- a/Makefile +++ b/Makefile @@ -19,6 +19,7 @@ BUILD_TARGETS = \ llama-imatrix \ llama-infill \ llama-llava-cli \ + llama-minicpmv-cli\ llama-lookahead \ llama-lookup \ llama-lookup-create \ @@ -949,6 +950,13 @@ llama-llava-cli: examples/llava/llava-cli.cpp examples/llava/clip.h examples/lla $(CXX) $(CXXFLAGS) -c examples/llava/llava.cpp -o $(call GET_OBJ_FILE, examples/llava/llava.cpp) $(CXX) $(CXXFLAGS) $(filter-out %.h $< examples/llava/clip.cpp examples/llava/llava.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/llava/clip.cpp) $(call GET_OBJ_FILE, examples/llava/llava.cpp) -o $@ $(LDFLAGS) +llama-minicpmv-cli: examples/llava/minicpmv-cli.cpp examples/llava/clip.h examples/llava/clip.cpp examples/llava/llava.h examples/llava/llava.cpp examples/llava/minicpmv_wrapper.h examples/llava/minicpmv_wrapper.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) -c examples/llava/clip.cpp -o $(call GET_OBJ_FILE, examples/llava/clip.cpp) -Wno-cast-qual + $(CXX) $(CXXFLAGS) -c examples/llava/llava.cpp -o $(call GET_OBJ_FILE, examples/llava/llava.cpp) + $(CXX) $(CXXFLAGS) -c examples/llava/minicpmv_wrapper.cpp -o $(call GET_OBJ_FILE, examples/llava/minicpmv_wrapper.cpp) + $(CXX) $(CXXFLAGS) $(filter-out %.h $< examples/llava/clip.cpp examples/llava/llava.cpp examples/llava/minicpmv_wrapper.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/llava/clip.cpp) $(call GET_OBJ_FILE, examples/llava/llava.cpp) $(call GET_OBJ_FILE, examples/llava/minicpmv_wrapper.cpp) -o $@ $(LDFLAGS) + llama-baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 0b51c44c05e4e..12173746a37d6 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -30,6 +30,7 @@ else() add_subdirectory(infill) add_subdirectory(llama-bench) add_subdirectory(llava) + add_subdirectory(minicpmv) add_subdirectory(lookahead) add_subdirectory(lookup) add_subdirectory(main) diff --git a/examples/llava/CMakeLists.txt b/examples/llava/CMakeLists.txt index e9fa73acb097b..74eee5e6b036d 100644 --- a/examples/llava/CMakeLists.txt +++ b/examples/llava/CMakeLists.txt @@ -36,3 +36,8 @@ set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-cli) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) + +add_library(minicpmv_wrapper OBJECT + minicpmv_wrapper.cpp +) +target_link_libraries(minicpmv_wrapper PRIVATE llava ${CMAKE_THREAD_LIBS_INIT}) diff --git a/examples/llava/README_minicpmv2.5.md b/examples/llava/README_minicpmv2.5.md new file mode 100644 index 0000000000000..b661b26466a36 --- /dev/null +++ b/examples/llava/README_minicpmv2.5.md @@ -0,0 +1,104 @@ +## MiniCPM-Llama3-V 2.5 + +### Usage + +Download [MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5) PyTorch model from huggingface to "MiniCPM-Llama3-V-2_5" folder. + +Clone llama.cpp and checkout to branch `minicpm-v2.5`: +```bash +git clone -b minicpm-v2.5 https://github.com/OpenBMB/llama.cpp.git +cd llama.cpp +``` + +Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf) by us) + +```bash +python ./examples/minicpmv/minicpmv-surgery.py -m ../MiniCPM-Llama3-V-2_5 +python ./examples/minicpmv/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-Llama3-V-2_5 --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector --output-dir ../MiniCPM-Llama3-V-2_5/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 +python ./convert.py ../MiniCPM-Llama3-V-2_5/model --outtype f16 --vocab-type bpe + +# quantize int4 version +./quantize ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf Q4_K_M +``` + +Build for Linux or Mac + +```bash +make +make minicpmv-cli +``` + +Inference on Linux or Mac +``` +# run f16 version +./minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" + +# run quantized int4 version +./minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" + +# or run in interactive mode +./minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i +``` + +### Android + +#### Build on Android device using Termux +We found that build on Android device would bring better runtime performance, so we recommend to build on device. + +[Termux](https://github.com/termux/termux-app#installation) is a terminal app on Android device (no root required). + +Install tools in Termux: +``` +apt update && apt upgrade -y +apt install git make cmake +``` + +It's recommended to move your model inside the `~/` directory for best performance: +``` +cd storage/downloads +mv model.gguf ~/ +``` + +#### Building the Project using Android NDK +Obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake. + +Execute the following commands on your computer to avoid downloading the NDK to your mobile. Alternatively, you can also do this in Termux: + +```bash +mkdir build-android +cd build-android +export NDK=/your_ndk_path +cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod .. +make +``` + +Install [termux](https://github.com/termux/termux-app#installation) on your device and run `termux-setup-storage` to get access to your SD card (if Android 11+ then run the command twice). + +Finally, copy these built `llama` binaries and the model file to your device storage. Because the file permissions in the Android sdcard cannot be changed, you can copy the executable files to the `/data/data/com.termux/files/home/bin` path, and then execute the following commands in Termux to add executable permission: + +(Assumed that you have pushed the built executable files to the /sdcard/llama.cpp/bin path using `adb push`) +``` +$cp -r /sdcard/llama.cpp/bin /data/data/com.termux/files/home/ +$cd /data/data/com.termux/files/home/bin +$chmod +x ./* +``` + +Download models and push them to `/sdcard/llama.cpp/`, then move it to `/data/data/com.termux/files/home/model/` + +``` +$mv /sdcard/llama.cpp/ggml-model-Q4_K_M.gguf /data/data/com.termux/files/home/model/ +$mv /sdcard/llama.cpp/mmproj-model-f16.gguf /data/data/com.termux/files/home/model/ +``` + +Now, you can start chatting: +``` +$cd /data/data/com.termux/files/home/bin +$./minicpmv-cli -m ../model/ggml-model-Q4_K_M.gguf --mmproj ../model/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" +``` + +### result +We use this command on Xiaomi 14 Pro, and the measured results. +``` +$./minicpmv-cli -m ../model/ggml-model-Q4_K_M.gguf --mmproj ../model/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 -t 6 --image xx.jpg -p "What is in the image?" +``` +![alt text](assets/xiaomi14pro_test.jpeg) \ No newline at end of file diff --git a/examples/llava/assets/xiaomi14pro_test.jpeg b/examples/llava/assets/xiaomi14pro_test.jpeg new file mode 100644 index 0000000000000..8762c9c7b1495 Binary files /dev/null and b/examples/llava/assets/xiaomi14pro_test.jpeg differ diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 95fbe3d0216c4..9353f5a02283f 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -3,6 +3,7 @@ // I'll gradually clean and extend it // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch #include "clip.h" +#include "common.h" #include "log.h" #include "ggml.h" #include "ggml-alloc.h" @@ -70,26 +71,27 @@ static std::string format(const char * fmt, ...) { // key constants // -#define KEY_FTYPE "general.file_type" -#define KEY_NAME "general.name" -#define KEY_DESCRIPTION "general.description" -#define KEY_HAS_TEXT_ENC "clip.has_text_encoder" -#define KEY_HAS_VIS_ENC "clip.has_vision_encoder" -#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" -#define KEY_USE_GELU "clip.use_gelu" -#define KEY_N_EMBD "clip.%s.embedding_length" -#define KEY_N_FF "clip.%s.feed_forward_length" -#define KEY_N_BLOCK "clip.%s.block_count" -#define KEY_N_HEAD "clip.%s.attention.head_count" -#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon" -#define KEY_PROJ_DIM "clip.%s.projection_dim" -#define KEY_TOKENS "tokenizer.ggml.tokens" -#define KEY_N_POSITIONS "clip.text.context_length" -#define KEY_IMAGE_SIZE "clip.vision.image_size" -#define KEY_PATCH_SIZE "clip.vision.patch_size" -#define KEY_IMAGE_MEAN "clip.vision.image_mean" -#define KEY_IMAGE_STD "clip.vision.image_std" -#define KEY_PROJ_TYPE "clip.projector_type" +#define KEY_FTYPE "general.file_type" +#define KEY_NAME "general.name" +#define KEY_DESCRIPTION "general.description" +#define KEY_HAS_TEXT_ENC "clip.has_text_encoder" +#define KEY_HAS_VIS_ENC "clip.has_vision_encoder" +#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" +#define KEY_HAS_MiniCPMV_PROJ "clip.has_minicpmv_projector" +#define KEY_USE_GELU "clip.use_gelu" +#define KEY_N_EMBD "clip.%s.embedding_length" +#define KEY_N_FF "clip.%s.feed_forward_length" +#define KEY_N_BLOCK "clip.%s.block_count" +#define KEY_N_HEAD "clip.%s.attention.head_count" +#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon" +#define KEY_PROJ_DIM "clip.%s.projection_dim" +#define KEY_TOKENS "tokenizer.ggml.tokens" +#define KEY_N_POSITIONS "clip.text.context_length" +#define KEY_IMAGE_SIZE "clip.vision.image_size" +#define KEY_PATCH_SIZE "clip.vision.patch_size" +#define KEY_IMAGE_MEAN "clip.vision.image_mean" +#define KEY_IMAGE_STD "clip.vision.image_std" +#define KEY_PROJ_TYPE "clip.projector_type" #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" @@ -122,6 +124,14 @@ static std::string format(const char * fmt, ...) { #define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s" #define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s" #define TN_IMAGE_NEWLINE "model.image_newline" +// MINICPMV +// #define TN_MINICPMV_POS_EMBD "resampler.pos_embed" +#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k" +#define TN_MINICPMV_QUERY "resampler.query" +#define TN_MINICPMV_PROJ "resampler.proj.weight" +#define TN_MINICPMV_KV_PROJ "resampler.kv.weight" +#define TN_MINICPMV_ATTN "resampler.attn.%s.%s" +#define TN_MINICPMV_LN "resampler.ln_%s.%s" enum projector_type { @@ -129,6 +139,7 @@ enum projector_type { PROJECTOR_TYPE_MLP_NORM, PROJECTOR_TYPE_LDP, PROJECTOR_TYPE_LDPV2, + PROJECTOR_TYPE_RESAMPLER, PROJECTOR_TYPE_UNKNOWN, }; @@ -136,6 +147,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_MLP, "mlp" }, { PROJECTOR_TYPE_LDP, "ldp" }, { PROJECTOR_TYPE_LDPV2, "ldpv2"}, + { PROJECTOR_TYPE_RESAMPLER, "resampler"}, }; @@ -488,12 +500,34 @@ struct clip_vision_model { struct ggml_tensor * mm_model_mlp_2_b; struct ggml_tensor * mm_model_peg_0_w; struct ggml_tensor * mm_model_peg_0_b; + + // MINICPMV projection + // struct ggml_tensor * mm_model_pos_embed; + struct ggml_tensor * mm_model_pos_embed_k; + struct ggml_tensor * mm_model_query; + struct ggml_tensor * mm_model_proj; + struct ggml_tensor * mm_model_kv_proj; + struct ggml_tensor * mm_model_attn_q_w; + struct ggml_tensor * mm_model_attn_q_b; + struct ggml_tensor * mm_model_attn_k_w; + struct ggml_tensor * mm_model_attn_k_b; + struct ggml_tensor * mm_model_attn_v_w; + struct ggml_tensor * mm_model_attn_v_b; + struct ggml_tensor * mm_model_attn_o_w; + struct ggml_tensor * mm_model_attn_o_b; + struct ggml_tensor * mm_model_ln_q_w; + struct ggml_tensor * mm_model_ln_q_b; + struct ggml_tensor * mm_model_ln_kv_w; + struct ggml_tensor * mm_model_ln_kv_b; + struct ggml_tensor * mm_model_ln_post_w; + struct ggml_tensor * mm_model_ln_post_b; }; struct clip_ctx { bool has_text_encoder = false; bool has_vision_encoder = false; bool has_llava_projector = false; + bool has_minicpmv_projector = false; struct clip_vision_model vision_model; projector_type proj_type = PROJECTOR_TYPE_MLP; @@ -520,7 +554,7 @@ struct clip_ctx { ggml_gallocr_t compute_alloc = NULL; }; -static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs) { +static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, std::pair load_image_size = {448, 448}) { if (!ctx->has_vision_encoder) { LOG_TEE("This gguf file seems to have no vision encoder\n"); return nullptr; @@ -529,10 +563,15 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 const auto & model = ctx->vision_model; const auto & hparams = model.hparams; - const int image_size = hparams.image_size; - const int patch_size = hparams.patch_size; - const int num_patches = ((image_size / patch_size) * (image_size / patch_size)); - const int num_patches_per_side = image_size / patch_size; GGML_UNUSED(num_patches_per_side); + const int image_size = hparams.image_size; + int image_size_width = image_size; + int image_size_height = image_size; + if (ctx->has_minicpmv_projector) { + image_size_width = load_image_size.first; + image_size_height = load_image_size.second; + } + const int patch_size = hparams.patch_size; + const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0); const int hidden_size = hparams.hidden_size; const int n_head = hparams.n_head; @@ -542,7 +581,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 const int batch_size = imgs->size; - if (ctx->has_llava_projector) { + if (ctx->has_llava_projector || ctx->has_minicpmv_projector) { GGML_ASSERT(batch_size == 1); } @@ -555,7 +594,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 struct ggml_context * ctx0 = ggml_init(params); struct ggml_cgraph * gf = ggml_new_graph(ctx0); - struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size); + struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size); ggml_set_name(inp_raw, "inp_raw"); ggml_set_input(inp_raw); @@ -563,25 +602,27 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size); inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); + struct ggml_tensor * embeddings = inp; + struct ggml_tensor * pos_embed; if (ctx->has_patch_bias) { // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp)); inp = ggml_add(ctx0, inp, model.patch_bias); } - // concat class_embeddings and patch_embeddings - struct ggml_tensor * embeddings = inp; - if (ctx->has_class_embedding) { - embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size); - ggml_set_name(embeddings, "embeddings"); - ggml_set_input(embeddings); - embeddings = ggml_acc(ctx0, embeddings, model.class_embedding, - embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0); - embeddings = ggml_acc(ctx0, embeddings, inp, - embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]); + if(ctx->has_llava_projector){ + // concat class_embeddings and patch_embeddings + if (ctx->has_class_embedding) { + embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size); + ggml_set_name(embeddings, "embeddings"); + ggml_set_input(embeddings); + embeddings = ggml_acc(ctx0, embeddings, model.class_embedding, + embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0); + embeddings = ggml_acc(ctx0, embeddings, inp, + embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]); + } } - struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions); ggml_set_name(positions, "positions"); ggml_set_input(positions); @@ -589,6 +630,14 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions)); + if(ctx->has_minicpmv_projector){ + int pos_w = image_size_width/patch_size; + int pos_h = image_size_height/patch_size; + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1); + ggml_set_name(pos_embed, "pos_embed"); + ggml_set_input(pos_embed); + } + // pre-layernorm if (ctx->has_pre_norm) { embeddings = ggml_norm(ctx0, embeddings, eps); @@ -687,6 +736,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } // llava projector + if(ctx->has_llava_projector) { embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); @@ -864,6 +914,65 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]); embeddings = peg_0; } + + else { + GGML_ASSERT(false); + } + } + // minicpmv projector + else if(ctx->has_minicpmv_projector) + { + if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { + struct ggml_tensor * q = model.mm_model_query; + { // layernorm + q = ggml_norm(ctx0, q, eps); + q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b); + } + struct ggml_tensor *k, *v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings); + { // layernorm + v = ggml_norm(ctx0, v, eps); + v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b); + } + { // position + // q = ggml_add(ctx0, q, model.mm_model_pos_embed); + k = ggml_add(ctx0, v, pos_embed); + } + + { // attention + const int hidden_size = 4096; + const int d_head = 128; + const int n_head = hidden_size/d_head; + const int num_query = 96; + + struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); + Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); + struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b); + struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b); + // permute + Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size); + Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); + Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size); + K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); + K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); + K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); + V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); + V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); + V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + KQ = ggml_soft_max_inplace(ctx0, KQ); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size); + + embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b); + } + { // layernorm + embeddings = ggml_norm(ctx0, embeddings, eps); + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b); + } + embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings); + } else { GGML_ASSERT(false); } @@ -878,7 +987,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } // read and create ggml_context containing the tensors and their data -struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { +struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1, std::pair load_image_size) { struct ggml_context * meta = NULL; struct gguf_init_params params = { @@ -1020,7 +1129,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->has_llava_projector = gguf_get_val_bool(ctx, idx); } - GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search + idx = gguf_find_key(ctx, KEY_HAS_MiniCPMV_PROJ); + if (idx != -1) { + new_clip->has_minicpmv_projector = gguf_get_val_bool(ctx, idx); + } + + // GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search + GGML_ASSERT(new_clip->has_vision_encoder); GGML_ASSERT(!new_clip->has_text_encoder); @@ -1031,6 +1146,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_TEE("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder); LOG_TEE("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder); LOG_TEE("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector); + LOG_TEE("%s: minicpmv_projector: %d\n", __func__, new_clip->has_minicpmv_projector); LOG_TEE("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0); LOG_TEE("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0); } @@ -1272,6 +1388,27 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { vision_model.mm_model_peg_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "weight")); vision_model.mm_model_peg_0_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "bias")); } + else if (new_clip->proj_type == PROJECTOR_TYPE_RESAMPLER) { + // vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD); + vision_model.mm_model_pos_embed_k = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD_K); + vision_model.mm_model_query = get_tensor(new_clip->ctx_data, TN_MINICPMV_QUERY); + vision_model.mm_model_proj = get_tensor(new_clip->ctx_data, TN_MINICPMV_PROJ); + vision_model.mm_model_kv_proj = get_tensor(new_clip->ctx_data, TN_MINICPMV_KV_PROJ); + vision_model.mm_model_attn_q_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "q", "weight")); + vision_model.mm_model_attn_k_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "k", "weight")); + vision_model.mm_model_attn_v_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "v", "weight")); + vision_model.mm_model_attn_q_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "q", "bias")); + vision_model.mm_model_attn_k_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "k", "bias")); + vision_model.mm_model_attn_v_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "v", "bias")); + vision_model.mm_model_attn_o_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "out", "weight")); + vision_model.mm_model_attn_o_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "out", "bias")); + vision_model.mm_model_ln_q_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "q", "weight")); + vision_model.mm_model_ln_q_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "q", "bias")); + vision_model.mm_model_ln_kv_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "kv", "weight")); + vision_model.mm_model_ln_kv_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "kv", "bias")); + vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight")); + vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias")); + } else { std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type]; throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); @@ -1310,7 +1447,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend)); clip_image_f32_batch batch; batch.size = 1; - ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch); + ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, load_image_size); ggml_gallocr_reserve(new_clip->compute_alloc, gf); size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0); LOG_TEE("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0); @@ -1424,6 +1561,19 @@ static void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* } } +void uhd_normalize_image_u8_to_f32(struct clip_ctx * ctx, const clip_image_u8* src, clip_image_f32* dst) { + dst->nx = src->nx; + dst->ny = src->ny; + dst->buf.resize(src->buf.size()); + const auto & m3 = ctx->image_mean; + const auto & s3 = ctx->image_std; + + for (size_t i = 0; i < src->buf.size(); ++i) { + int c = i % 3; // rgb + dst->buf[i] = (static_cast(src->buf[i]) / 255.0f - m3[c]) / s3[c]; + } +} + inline float clip(float x, float lower, float upper) { return std::max(lower, std::min(x, upper)); } @@ -1807,12 +1957,100 @@ int clip_n_patches(const struct clip_ctx * ctx) { if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) { n_patches /= 4; + } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { + n_patches = 96; } return n_patches; } -bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) { +static std::vector>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector>& pos) { + assert(embed_dim % 2 == 0); + int H = pos.size(); + int W = pos[0].size(); + + std::vector omega(embed_dim / 2); + for (int i = 0; i < embed_dim / 2; ++i) { + omega[i] = 1.0 / pow(10000.0, static_cast(i) / (embed_dim / 2)); + } + + std::vector>> emb(H, std::vector>(W, std::vector(embed_dim))); + for (int h = 0; h < H; ++h) { + for (int w = 0; w < W; ++w) { + for (int d = 0; d < embed_dim / 2; ++d) { + float out_value = pos[h][w] * omega[d]; + emb[h][w][d] = sin(out_value); + emb[h][w][d + embed_dim / 2] = cos(out_value); + } + } + } + + return emb; +} + +static std::vector>> get_2d_sincos_pos_embed_from_grid(int embed_dim, const std::vector>>& grid) { + assert(embed_dim % 2 == 0); + std::vector>> emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[0]); // (H, W, D/2) + std::vector>> emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[1]); // (H, W, D/2) + + int H = emb_h.size(); + int W = emb_h[0].size(); + std::vector>> emb(H, std::vector>(W, std::vector(embed_dim))); + + for (int h = 0; h < H; ++h) { + for (int w = 0; w < W; ++w) { + for (int d = 0; d < embed_dim / 2; ++d) { + emb[h][w][d] = emb_h[h][w][d]; + emb[h][w][d + embed_dim / 2] = emb_w[h][w][d]; + } + } + } + return emb; +} + +static std::vector> get_2d_sincos_pos_embed(int embed_dim, const std::pair image_size) { + int grid_h_size = image_size.first; + int grid_w_size = image_size.second; + + std::vector grid_h(grid_h_size); + std::vector grid_w(grid_w_size); + + for (int i = 0; i < grid_h_size; ++i) { + grid_h[i] = static_cast(i); + } + for (int i = 0; i < grid_w_size; ++i) { + grid_w[i] = static_cast(i); + } + + std::vector> grid(grid_h_size, std::vector(grid_w_size)); + for (int h = 0; h < grid_h_size; ++h) { + for (int w = 0; w < grid_w_size; ++w) { + grid[h][w] = grid_w[w]; + } + } + std::vector>> grid_2d = {grid, grid}; + for (int h = 0; h < grid_h_size; ++h) { + for (int w = 0; w < grid_w_size; ++w) { + grid_2d[0][h][w] = grid_h[h]; + grid_2d[1][h][w] = grid_w[w]; + } + } + + std::vector>> pos_embed_3d = get_2d_sincos_pos_embed_from_grid(embed_dim, grid_2d); + + int H = image_size.first; + int W = image_size.second; + std::vector> pos_embed_2d(H * W, std::vector(embed_dim)); + for (int h = 0; h < H; ++h) { + for (int w = 0; w < W; ++w) { + pos_embed_2d[w * H + h] = pos_embed_3d[h][w]; + } + } + + return pos_embed_2d; +} + +bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec, std::pair load_image_size) { if (!ctx->has_vision_encoder) { LOG_TEE("This gguf file seems to have no vision encoder\n"); return false; @@ -1821,10 +2059,10 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3 clip_image_f32_batch imgs{}; imgs.size = 1; imgs.data = img; - return clip_image_batch_encode(ctx, n_threads, &imgs, vec); + return clip_image_batch_encode(ctx, n_threads, &imgs, vec, load_image_size); } -bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec) { +bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec, std::pair load_image_size) { if (!ctx->has_vision_encoder) { LOG_TEE("This gguf file seems to have no vision encoder\n"); return false; @@ -1834,6 +2072,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima if (ctx->has_llava_projector) { GGML_ASSERT(batch_size == 1); // TODO: support multiple images } + if (ctx->has_minicpmv_projector) { + GGML_ASSERT(batch_size == 1); + } // build the inference graph ggml_cgraph * gf = clip_image_build_graph(ctx, imgs); @@ -1844,8 +2085,14 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const auto & hparams = model.hparams; const int image_size = hparams.image_size; + int image_size_width = image_size; + int image_size_height = image_size; + if (ctx->has_minicpmv_projector) { + image_size_width = load_image_size.first; + image_size_height = load_image_size.second; + } const int patch_size = hparams.patch_size; - const int num_patches = ((image_size / patch_size) * (image_size / patch_size)); + const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0); { @@ -1855,7 +2102,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima for (size_t i = 0; i < imgs->size; i++) { const int nx = imgs->data[i].nx; const int ny = imgs->data[i].ny; - GGML_ASSERT(nx == image_size && ny == image_size); + if (!ctx->has_minicpmv_projector) { + GGML_ASSERT(nx == image_size && ny == image_size); + } const int n = nx * ny; @@ -1872,37 +2121,74 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw)); free(data); } + if (ctx->has_minicpmv_projector) { + { + // inspired from siglip: + // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit + // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 + struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); + + int* positions_data = (int*)malloc(ggml_nbytes(positions)); + for (int i = 0; i < num_positions; i++) { + positions_data[i] = std::floor(70.0*i/num_positions); + } + ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); + free(positions_data); + } - { - if (ctx->has_class_embedding) { - struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings"); + { + // inspired from resampler of Qwen-VL: + // -> https://huggingface.co/Qwen/Qwen-VL/tree/main + // -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23 + struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed"); + int pos_w = image_size_width/patch_size; + int pos_h = image_size_height/patch_size; + int embed_dim = 4096; + auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); + + float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed)); + for(int i=0;ihas_class_embedding) { + struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings"); - { - struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); + void* zero_mem = malloc(ggml_nbytes(embeddings)); + memset(zero_mem, 0, ggml_nbytes(embeddings)); + ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings)); + free(zero_mem); + } + } - int* positions_data = (int*)malloc(ggml_nbytes(positions)); - for (int i = 0; i < num_positions; i++) { - positions_data[i] = i; + { + struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); + + int* positions_data = (int*)malloc(ggml_nbytes(positions)); + for (int i = 0; i < num_positions; i++) { + positions_data[i] = i; + } + ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); + free(positions_data); } - ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); - free(positions_data); - } - { - struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); - int* patches_data = (int*)malloc(ggml_nbytes(patches)); - for (int i = 0; i < num_patches; i++) { - patches_data[i] = i + 1; + { + struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); + int* patches_data = (int*)malloc(ggml_nbytes(patches)); + for (int i = 0; i < num_patches; i++) { + patches_data[i] = i + 1; + } + ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); + free(patches_data); } - ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); - free(patches_data); } if (ggml_backend_is_cpu(ctx->backend)) { @@ -2072,6 +2358,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { return ctx->vision_model.mm_3_b->ne[0]; } + if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { + return 4096; + } std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type]; throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); diff --git a/examples/llava/clip.h b/examples/llava/clip.h index ca36313844c13..232fe50a81aae 100644 --- a/examples/llava/clip.h +++ b/examples/llava/clip.h @@ -3,6 +3,7 @@ #include #include +#include #ifdef LLAMA_SHARED # if defined(_WIN32) && !defined(__MINGW32__) @@ -36,7 +37,7 @@ struct clip_image_f32_batch { size_t size; }; -CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity); +CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity, std::pair load_image_size = {448, 448}); CLIP_API struct clip_ctx * clip_model_load_cpu(const char * fname, int verbosity); CLIP_API void clip_free(struct clip_ctx * ctx); @@ -71,10 +72,12 @@ CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t byt /** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */ CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs ); +CLIP_API void uhd_normalize_image_u8_to_f32(struct clip_ctx * ctx, const clip_image_u8* src, clip_image_f32* dst); + CLIP_API struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx); -CLIP_API bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec); -CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec); +CLIP_API bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec, std::pair load_image_size = {448, 448}); +CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec, std::pair load_image_size = {448, 448}); CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype); diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 63878d176b0bb..0d7324037833c 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -409,3 +409,342 @@ void llava_image_embed_free(struct llava_image_embed * embed) { free(embed->embed); free(embed); } + +static bool encode_image_with_clip_uhd(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) { + // std::vector img_res_v; + // format VectN x H x W x RGB (N x 448 x 448 x 3) + clip_image_f32 * img_res_v = clip_image_f32_init(); + std::pair load_image_size; + load_image_size.first = img->nx; + load_image_size.second = img->ny; + uhd_normalize_image_u8_to_f32(ctx_clip, img, img_res_v); + + const int64_t t_img_enc_start_us = ggml_time_us(); + + const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip); + LOG_TEE("\n%s: mm_patch_merge_type is %s.\n", __func__, mm_patch_merge_type); + + *n_img_pos = clip_n_patches(ctx_clip); + bool encoded = clip_image_encode(ctx_clip, n_threads, img_res_v, image_embd, load_image_size); // image_embd shape is 96 x 4096 + if (!encoded) { + LOG_TEE("Unable to encode image\n"); + return false; + } + LOG_TEE("%s: image embedding created: %d tokens\n", __func__, *n_img_pos); + + const int64_t t_img_enc_end_us = ggml_time_us(); + float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0; + LOG_TEE("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / *n_img_pos); + + return true; +} + +static int ensure_divide(int length, int patch_size) { + return std::max(static_cast(std::round(static_cast(length) / patch_size) * patch_size), patch_size); +} + +static std::pair uhd_find_best_resize(std::pair original_size, int scale_resolution, int patch_size, bool allow_upscale = false) { + int width = original_size.first; + int height = original_size.second; + if ((width * height > scale_resolution * scale_resolution) || allow_upscale) { + float r = static_cast(width) / height; + height = static_cast(scale_resolution / std::sqrt(r)); + width = static_cast(height * r); + } + int best_width = ensure_divide(width, patch_size); + int best_height = ensure_divide(height, patch_size); + return std::make_pair(best_width, best_height); +} + +static std::pair uhd_get_refine_size(std::pair original_size, std::pair grid, int scale_resolution, int patch_size, bool allow_upscale = false) { + int width, height; + std::tie(width, height) = original_size; + int grid_x, grid_y; + std::tie(grid_x, grid_y) = grid; + + int refine_width = ensure_divide(width, grid_x); + int refine_height = ensure_divide(height, grid_y); + + int grid_width = refine_width / grid_x; + int grid_height = refine_height / grid_y; + + // auto best_grid_size = find_best_resize(std::make_tuple(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); (old line) + auto best_grid_size = uhd_find_best_resize(std::make_pair(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); // (new line) => fixes conversion for make_tuple to make_pair + int best_grid_width, best_grid_height; + std::tie(best_grid_width, best_grid_height) = best_grid_size; + + // std::pair refine_size = std::make_tuple(best_grid_width * grid_x, best_grid_height * grid_y); (old line) + std::pair refine_size = std::make_pair(best_grid_width * grid_x, best_grid_height * grid_y); // (new line) + return refine_size; +} + +inline int clip(int x, int lower, int upper) { + return std::max(lower, std::min(x, upper)); +} + +static bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_width, int target_height) { + const int nx = img.nx; + const int ny = img.ny; + + dst.nx = target_width; + dst.ny = target_height; + dst.buf.resize(3 * target_width * target_height); + + float Cc; + float C[5]; + float d0, d2, d3, a0, a1, a2, a3; + int i, j, k, jj; + int x, y; + float dx, dy; + float tx, ty; + + tx = (float)nx / (float)target_width; + ty = (float)ny / (float)target_height; + + // Bicubic interpolation; adapted from ViT.cpp, inspired from : + // -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36 + // -> https://en.wikipedia.org/wiki/Bicubic_interpolation + + for (i = 0; i < target_height; i++) { + for (j = 0; j < target_width; j++) { + x = (int)(tx * j); + y = (int)(ty * i); + + dx = tx * j - x; + dy = ty * i - y; + + for (k = 0; k < 3; k++) { + for (jj = 0; jj <= 3; jj++) { + d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + + a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; + a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; + a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; + + C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx; + + d0 = C[0] - C[1]; + d2 = C[2] - C[1]; + d3 = C[3] - C[1]; + a0 = C[1]; + a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; + a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; + a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; + Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy; + + const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f); + dst.buf[(i * target_width + j) * 3 + k] = float(Cc2); + } + } + } + } + + return true; +} + +// inspired from LLaVA-UHD: +// -> https://arxiv.org/pdf/2403.11703 +// -> https://github.com/thunlp/LLaVA-UHD +// -> https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118 +static std::vector> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) { + const std::pair original_size={img->nx,img->ny}; + const int original_width = img->nx; + const int original_height = img->ny; + const float log_ratio = log(1.0*original_width/original_height); // + const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution); + const int multiple = fmin(ceil(ratio), max_slice_nums); + + std::vector> images; + LOG_TEE("%s: multiple %d\n", __func__, multiple); + images.push_back(std::vector()); + + if(multiple <= 1){ + auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true); + clip_image_u8 *source_image = clip_image_u8_init(); + bicubic_resize(*img, *source_image, best_size.first, best_size.second); + // source_image = image.resize(best_size, Image.Resampling.BICUBIC) + images[images.size()-1].push_back(source_image); + } + else if(multiple > 1){ + + std::vector candidate_split_grids_nums; + for (int i : {multiple - 1, multiple, multiple + 1}) { + if (i == 1 || i > max_slice_nums) { + continue; + } + candidate_split_grids_nums.push_back(i); + } + + auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size); + clip_image_u8 *source_image = clip_image_u8_init(); + bicubic_resize(*img, *source_image, best_size.first, best_size.second); + // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) + LOG_TEE("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second); + images[images.size()-1].push_back(source_image); + + std::vector> candidate_grids; + + for (int split_grids_nums : candidate_split_grids_nums) { + int m = 1; + while (m <= split_grids_nums) { + if (split_grids_nums % m == 0) { + candidate_grids.emplace_back(m, split_grids_nums / m); + } + ++m; + } + } + + std::pair best_grid{1, 1}; + float min_error = std::numeric_limits::infinity(); + + for (const auto& grid : candidate_grids) { + float error = std::abs(log_ratio - std::log(1.0 * grid.first / grid.second)); + if (error < min_error) { + best_grid = grid; + min_error = error; + } + } + LOG_TEE("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second); + + auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true); + clip_image_u8 *refine_image = clip_image_u8_init(); + bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second); + + LOG_TEE("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second); + + // split_to_patches + int width = refine_image->nx; + int height = refine_image->ny; + int grid_x = int(width / best_grid.first); + int grid_y = int(height / best_grid.second); + for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){ + images.push_back(std::vector()); + for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){ + clip_image_u8 * patch = clip_image_u8_init(); + patch->nx = grid_x; + patch->ny = grid_y; + patch->buf.resize(3 * patch->nx * patch->ny); + for (int y = patches_i; y < patches_i + grid_y; ++y) { + for (int x = patches_j; x < patches_j + grid_x; ++x) { + const int i = 3 * (y * refine_image->nx + x); + const int j = 3 * ((y-patches_i) * patch->nx + (x-patches_j)); + patch->buf[j] = refine_image->buf[i]; + patch->buf[j+1] = refine_image->buf[i+1]; + patch->buf[j+2] = refine_image->buf[i+2]; + } + } + images[images.size()-1].push_back(patch); + } + } + } + return images; +} + +struct uhd_image_embed * llava_image_embed_make_with_bytes_uhd(struct clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img) { + std::vector> imgs = uhd_slice_image(img); + for (size_t i = 0; i < imgs.size(); ++i){ + for (size_t j = 0; j < imgs[i].size(); ++j) { + LOG_TEE("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny); + } + } + struct uhd_image_embed * results = new uhd_image_embed(); + + for (size_t i = 0; i < imgs.size(); ++i){ + results->image_embeds.push_back(std::vector()); + for (size_t j = 0; j < imgs[i].size(); ++j) { + float* image_embed = NULL; + int n_image_pos = 0; + bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, imgs[i][j], &image_embed, &n_image_pos); + if (!image_embed_result) { + LOG_TEE("%s: coulnd't embed the image\n", __func__); + return NULL; + } + + auto result = (llava_image_embed*)malloc(sizeof(llava_image_embed)); + result->embed = image_embed; + result->n_image_pos = n_image_pos; + results->image_embeds[i].push_back(result); + } + } + return results; +} + +bool llava_image_embed_make_with_clip_img_ollama(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) { + auto embeds = llava_image_embed_make_with_bytes_uhd(ctx_clip, n_threads, img); + auto image_embed_slices = embeds->image_embeds; + if (!image_embed_slices[0][0]){ + LOG_TEE("%s: failed to embeding image\n", __func__); + return false; + } + std::string fname = "./examples/minicpm-v2.5/slice_token_for_ollama.raw"; + unsigned char* slice_token; + long image_bytes_length; + auto loaded = load_file_to_bytes(fname.c_str(), &slice_token, &image_bytes_length); + if (!loaded) { + LOG_TEE("%s: failed to load %s\n", __func__, fname.c_str()); + return false; + } + + float * all_image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*61); + int all_n_img_pos=0; + int token_len = clip_n_mmproj_embd(ctx_clip)*sizeof(float); + + std::memcpy(all_image_embd+token_len*all_n_img_pos++, slice_token, token_len); + std::memcpy(all_image_embd+token_len*all_n_img_pos, image_embed_slices[0][0]->embed, 96*token_len); + all_n_img_pos+=clip_n_patches(ctx_clip); + std::memcpy(all_image_embd+token_len*all_n_img_pos++, slice_token+token_len, token_len); + if (image_embed_slices.size() > 1) { + std::memcpy(all_image_embd+token_len*all_n_img_pos++, slice_token+token_len*2, token_len); + for (size_t i = 1; i < image_embed_slices.size(); ++i) { + for (size_t j = 0; j < image_embed_slices[i].size(); ++j) { + std::memcpy(all_image_embd+token_len*all_n_img_pos++, slice_token, token_len); + std::memcpy(all_image_embd+token_len*all_n_img_pos, image_embed_slices[i][j]->embed, 96*token_len); + all_n_img_pos+=clip_n_patches(ctx_clip); + std::memcpy(all_image_embd+token_len*all_n_img_pos++, slice_token+token_len, token_len); + if (j == image_embed_slices[i].size() - 1) { + std::memcpy(all_image_embd+token_len*all_n_img_pos++, slice_token+token_len*4, token_len); + } + } + } + std::memcpy(all_image_embd+token_len*all_n_img_pos++, slice_token+token_len*3, token_len); + } + *image_embd_out = all_image_embd; + *n_img_pos_out = all_n_img_pos; + return true; +} + +struct uhd_image_embed * llava_image_embed_make_with_filename_uhd(struct clip_ctx * ctx_clip, int n_threads, const char * image_path) { + unsigned char* image_bytes; + long image_bytes_length; + auto loaded = load_file_to_bytes(image_path, &image_bytes, &image_bytes_length); + if (!loaded) { + LOG_TEE("%s: failed to load %s\n", __func__, image_path); + return NULL; + } + clip_image_u8 * img = clip_image_u8_init(); + if (!clip_image_load_from_bytes(image_bytes, image_bytes_length, img)) { + clip_image_u8_free(img); + LOG_TEE("%s: can't load image from bytes, is it a valid image?", __func__); + return NULL; + } + + struct uhd_image_embed * embeds = llava_image_embed_make_with_bytes_uhd(ctx_clip, n_threads, img); + + clip_image_u8_free(img); + free(image_bytes); + return embeds; +} + +void llava_image_embed_free_uhd(struct uhd_image_embed * embed) { + for (size_t i = 0; i < embed->image_embeds.size(); ++i){ + for (size_t j = 0; j < embed->image_embeds[i].size(); ++j){ + free(embed->image_embeds[i][j]->embed); + free(embed->image_embeds[i][j]); + } + embed->image_embeds[i] = std::vector(); + } + embed->image_embeds = std::vector>(); +} \ No newline at end of file diff --git a/examples/llava/llava.h b/examples/llava/llava.h index 19212f6e9e9c5..420ae15d641f2 100644 --- a/examples/llava/llava.h +++ b/examples/llava/llava.h @@ -19,6 +19,10 @@ struct clip_ctx; +struct uhd_image_embed { + std::vector> image_embeds; +}; + #ifdef __cplusplus extern "C" { #endif @@ -40,6 +44,13 @@ LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed); /** free an embedding made with llava_image_embed_make_* */ +/** build an image embed from image file bytes */ +LLAVA_API struct uhd_image_embed * llava_image_embed_make_with_bytes_uhd(struct clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img); +/** build an image embed from a path to an image filename */ +LLAVA_API bool llava_image_embed_make_with_clip_img_ollama(struct clip_ctx * ctx_clip, int n_threads, const struct clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out); +LLAVA_API struct uhd_image_embed * llava_image_embed_make_with_filename_uhd(struct clip_ctx * ctx_clip, int n_threads, const char * image_path); +LLAVA_API void llava_image_embed_free_uhd(struct uhd_image_embed * embed); + /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */ LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past); diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp new file mode 100644 index 0000000000000..befaec8bc6ec2 --- /dev/null +++ b/examples/llava/minicpmv-cli.cpp @@ -0,0 +1,158 @@ +#include "ggml.h" +#include "log.h" +#include "common.h" +#include "clip.h" +#include "llava.h" +#include "minicpmv_wrapper.h" +#include "llama.h" + +#include +#include +#include + +static void show_additional_info(int /*argc*/, char ** argv) { + LOG_TEE("\n example usage: %s -m --mmproj --image --image [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]); + LOG_TEE(" note: a lower temperature value like 0.1 is recommended for better quality.\n"); +} + +static void llama_log_callback_logTee(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + LOG_TEE("%s", text); +} + +static struct minicpmv_context * minicpmv_init(gpt_params * params, const std::string & fname, int &n_past){ + auto embeds = minicpmv_image_embed(params, fname); + auto image_embed_slices = embeds->image_embeds; + if (!image_embed_slices[0][0]) { + std::cerr << "error: failed to load image " << fname << ". Terminating\n\n"; + return NULL; + } + + // process the prompt + if (params->prompt.empty() && params->interactive == false) { + LOG_TEE("prompt should be given or interactive mode should be on"); + return NULL; + } + + auto model = llava_init(params); + if (model == NULL) { + fprintf(stderr, "%s: error: failed to init minicpmv model\n", __func__); + return NULL; + } + const int64_t t_llava_init_start_us = ggml_time_us(); + auto ctx_llava = llava_init_context(params, model); + + const int64_t t_llava_init_end_us = ggml_time_us(); + float t_llava_init_ms = (t_llava_init_end_us - t_llava_init_start_us) / 1000.0; + LOG_TEE("\n%s: llava init in %8.2f ms.\n", __func__, t_llava_init_ms); + + const int64_t t_process_image_start_us = ggml_time_us(); + process_image(ctx_llava, image_embed_slices, params, n_past); + const int64_t t_process_image_end_us = ggml_time_us(); + float t_process_image_ms = (t_process_image_end_us - t_process_image_start_us) / 1000.0; + LOG_TEE("\n%s: llama process image in %8.2f ms.\n", __func__, t_process_image_ms); + + llava_image_embed_free_uhd(embeds); + return ctx_llava; +} + +static struct llama_sampling_context * llama_init(struct minicpmv_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){ + std::string user_prompt = prompt; + if (!is_first) user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt; + + eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false); + eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false); + // generate the response + + LOG_TEE("\n"); + + struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams); + return ctx_sampling; +} + +static const char * llama_loop(struct minicpmv_context * ctx_llava,struct llama_sampling_context * ctx_sampling, int &n_past){ + + const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past); + return tmp; +} + +int main(int argc, char ** argv) { + ggml_time_init(); + + gpt_params params; + + if (!gpt_params_parse(argc, argv, params)) { + show_additional_info(argc, argv); + return 1; + } + +#ifndef LOG_DISABLE_LOGS + log_set_target(log_filename_generator("llava", "log")); + LOG_TEE("Log start\n"); + log_dump_cmdline(argc, argv); + llama_log_set(llama_log_callback_logTee, nullptr); +#endif // LOG_DISABLE_LOGS + + if (params.mmproj.empty() || (params.image.empty())) { + gpt_params_print_usage(argc, argv, params); + show_additional_info(argc, argv); + return 1; + } + + for (auto & image : params.image) { + int n_past = 0; + auto ctx_llava = minicpmv_init(¶ms, image, n_past); + + if (!params.prompt.empty()) { + LOG_TEE("%s\n", params.prompt.c_str()); + LOG_TEE(""); + auto ctx_sampling = llama_init(ctx_llava, ¶ms, params.prompt.c_str(), n_past, true); + const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict; + std::string response = ""; + bool have_tmp = false; + for (int i = 0; i < max_tgt_len; i++) { + auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past); + response += tmp; + if (strcmp(tmp, "") == 0){ + if(!have_tmp)continue; + else break; + } + if (strstr(tmp, "###")) break; // Yi-VL behavior + have_tmp = true; + printf("%s", tmp); + if (strstr(response.c_str(), "")) break; // minicpm-v + + fflush(stdout); + } + llama_sampling_free(ctx_sampling); + }else { + while (true) { + LOG_TEE(""); + std::string prompt; + std::getline(std::cin, prompt); + LOG_TEE(""); + auto ctx_sampling = llama_init(ctx_llava, ¶ms, prompt, n_past, true); + const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict; + std::string response = ""; + for (int i = 0; i < max_tgt_len; i++) { + auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past); + response += tmp; + if (strcmp(tmp, "") == 0) break; + if (strstr(tmp, "###")) break; // Yi-VL behavior + printf("%s", tmp);// mistral llava-1.6 + if (strstr(response.c_str(), "")) break; // minicpm-v + fflush(stdout); + } + llama_sampling_free(ctx_sampling); + } + } + printf("\n"); + llama_print_timings(ctx_llava->ctx_llama); + + ctx_llava->model = NULL; + llava_free(ctx_llava); + } + + return 0; +} \ No newline at end of file diff --git a/examples/llava/minicpmv-convert-image-encoder-to-gguf.py b/examples/llava/minicpmv-convert-image-encoder-to-gguf.py new file mode 100644 index 0000000000000..71af0b4f1ccb6 --- /dev/null +++ b/examples/llava/minicpmv-convert-image-encoder-to-gguf.py @@ -0,0 +1,382 @@ +import argparse +import os +import json +import re + +import torch +import numpy as np +from gguf import * +import timm +from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer, Idefics2VisionConfig + +TEXT = "clip.text" +VISION = "clip.vision" + + +def k(raw_key: str, arch: str) -> str: + return raw_key.format(arch=arch) + + +def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_minicpmv: bool) -> bool: + if name in ( + "logit_scale", + "text_model.embeddings.position_ids", + "vision_model.embeddings.position_ids", + ): + return True + + if has_minicpmv and name in ["visual_projection.weight"]: + return True + + if name.startswith("v") and not has_vision: + return True + + if name.startswith("t") and not has_text: + return True + + return False + + +def get_tensor_name(name: str) -> str: + if "projection" in name: + return name + if "mm_projector" in name: + name = name.replace("model.mm_projector", "mm") + name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1) + name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1) + return name + + return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln") + + +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +ap = argparse.ArgumentParser() +ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True) +ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16") +ap.add_argument("--text-only", action="store_true", required=False, + help="Save a text-only model. It can't be used to encode images") +ap.add_argument("--vision-only", action="store_true", required=False, + help="Save a vision-only model. It can't be used to encode texts") +ap.add_argument("--clip-model-is-vision", action="store_true", required=False, + help="The clip model is a pure vision model (ShareGPT4V vision extract for example)") +ap.add_argument("--clip-model-is-openclip", action="store_true", required=False, + help="The clip model is from openclip (for ViT-SO400M type))") +ap.add_argument("--minicpmv-projector", help="Path to minicpmv.projector file. If specified, save an image encoder for MiniCPM-V models.") +ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp") +ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None) +# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711 +# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5 +default_image_mean = [0.48145466, 0.4578275, 0.40821073] +default_image_std = [0.26862954, 0.26130258, 0.27577711] +ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) +ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) + +# with proper +args = ap.parse_args() + + +if args.text_only and args.vision_only: + print("--text-only and --image-only arguments cannot be specified at the same time.") + exit(1) + +if args.use_f32: + print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.") + +# output in the same directory as the model if output_dir is None +dir_model = args.model_dir + +if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip: + vocab = None + tokens = None +else: + with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f: + vocab = json.load(f) + tokens = [key for key in vocab] + +# possible data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 +# +# map from ftype to string +ftype_str = ["f32", "f16"] + +ftype = 1 +if args.use_f32: + ftype = 0 + +# if args.clip_model_is_vision or args.clip_model_is_openclip: +# model = CLIPVisionModel.from_pretrained(dir_model) +# processor = None +# else: +# model = CLIPModel.from_pretrained(dir_model) +# processor = CLIPProcessor.from_pretrained(dir_model) + +default_vision_config = { + "hidden_size": 1152, + "image_size": 980, + "intermediate_size": 4304, + "model_type": "idefics2", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 14, + } +vision_config = Idefics2VisionConfig(**default_vision_config) +model = Idefics2VisionTransformer(vision_config) + +processor = None +# if model.attn_pool is not None: +# model.attn_pool = torch.nn.Identity() + +# model.blocks = model.blocks[:-1] +model.load_state_dict(torch.load(os.path.join(dir_model, "minicpmv.clip"))) + +fname_middle = None +has_text_encoder = True +has_vision_encoder = True +has_minicpmv_projector = False +if args.text_only: + fname_middle = "text-" + has_vision_encoder = False +elif args.minicpmv_projector is not None: + fname_middle = "mmproj-" + has_text_encoder = False + has_minicpmv_projector = True +elif args.vision_only: + fname_middle = "vision-" + has_text_encoder = False +else: + fname_middle = "" + +output_dir = args.output_dir if args.output_dir is not None else dir_model +os.makedirs(output_dir, exist_ok=True) +output_prefix = os.path.basename(output_dir).replace("ggml_", "") +fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf") +fout = GGUFWriter(path=fname_out, arch="clip") + +fout.add_bool("clip.has_text_encoder", has_text_encoder) +fout.add_bool("clip.has_vision_encoder", has_vision_encoder) +fout.add_bool("clip.has_minicpmv_projector", has_minicpmv_projector) +fout.add_file_type(ftype) +if args.text_only: + fout.add_description("text-only CLIP model") +elif args.vision_only and not has_minicpmv_projector: + fout.add_description("vision-only CLIP model") +elif has_minicpmv_projector: + fout.add_description("image encoder for MiniCPM-V") + # add projector type + fout.add_string("clip.projector_type", "resampler") +else: + fout.add_description("two-tower CLIP model") + +if has_vision_encoder: + # vision_model hparams + fout.add_uint32("clip.vision.image_size", 448) + fout.add_uint32("clip.vision.patch_size", 14) + fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), 1152) + fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 4304) + fout.add_uint32("clip.vision.projection_dim", 0) + fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), 16) + fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) + block_count = 26 + fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count) + + if processor is not None: + image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean + image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std + else: + image_mean = args.image_mean if args.image_mean is not None else default_image_mean + image_std = args.image_std if args.image_std is not None else default_image_std + fout.add_array("clip.vision.image_mean", image_mean) + fout.add_array("clip.vision.image_std", image_std) + +use_gelu = True +fout.add_bool("clip.use_gelu", use_gelu) + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_h_size, grid_w_size = grid_size, grid_size + else: + grid_h_size, grid_w_size = grid_size[0], grid_size[1] + + grid_h = np.arange(grid_h_size, dtype=np.float32) + grid_w = np.arange(grid_w_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + +def _replace_name_resampler(s, v): + if re.match("resampler.pos_embed", s): + return { + s: v, + re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))), + } + if re.match("resampler.proj", s): + return { + re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))), + re.sub("proj", "proj.weight", s): v.transpose(-1, -2).contiguous(), + } + if re.match("resampler.attn.in_proj_.*", s): + return { + re.sub("attn.in_proj_", "attn.q.", s): v.chunk(3, dim=0)[0], + re.sub("attn.in_proj_", "attn.k.", s): v.chunk(3, dim=0)[1], + re.sub("attn.in_proj_", "attn.v.", s): v.chunk(3, dim=0)[2], + } + return {s: v} + +if has_minicpmv_projector: + projector = torch.load(args.minicpmv_projector) + new_state_dict = {} + for k, v in projector.items(): + kvs = _replace_name_resampler(k, v) + for nk, nv in kvs.items(): + new_state_dict[nk] = nv + projector = new_state_dict + for name, data in projector.items(): + name = get_tensor_name(name) + data = data.squeeze().numpy() + + n_dims = len(data.shape) + if ftype == 1: + if name[-7:] == ".weight" and n_dims == 2: + print(" Converting to float16") + data = data.astype(np.float16) + ftype_cur = 1 + else: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + else: + if data.dtype != np.float32: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + + fout.add_tensor(name, data) + print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}") + + print("Projector tensors added\n") + +def _replace_name(s, v): + s = "vision_model." + s + if re.match("vision_model.embeddings.position_embedding", s): + v = v.unsqueeze(0) + return {s: v} + + return {s: v} + +state_dict = model.state_dict() +new_state_dict = {} +for k, v in state_dict.items(): + kvs = _replace_name(k, v) + for nk, nv in kvs.items(): + new_state_dict[nk] = nv +state_dict = new_state_dict +for name, data in state_dict.items(): + if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_minicpmv_projector): + # we don't need this + print(f"skipping parameter: {name}") + continue + + name = get_tensor_name(name) + data = data.squeeze().numpy() + + n_dims = len(data.shape) + + # ftype == 0 -> float32, ftype == 1 -> float16 + ftype_cur = 0 + if n_dims == 4: + print(f"tensor {name} is always saved in f16") + data = data.astype(np.float16) + ftype_cur = 1 + elif ftype == 1: + if name[-7:] == ".weight" and n_dims == 2: + print(" Converting to float16") + data = data.astype(np.float16) + ftype_cur = 1 + else: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + else: + if data.dtype != np.float32: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + + print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}") + fout.add_tensor(name, data) + + +fout.write_header_to_file() +fout.write_kv_data_to_file() +fout.write_tensors_to_file() +fout.close() + +print("Done. Output file: " + fname_out) diff --git a/examples/llava/minicpmv-surgery.py b/examples/llava/minicpmv-surgery.py new file mode 100644 index 0000000000000..51b4d48598440 --- /dev/null +++ b/examples/llava/minicpmv-surgery.py @@ -0,0 +1,48 @@ +import argparse +import glob +import os +import torch +from transformers import AutoModel, AutoTokenizer + +ap = argparse.ArgumentParser() +ap.add_argument("-m", "--model", help="Path to MiniCPM-V-2.5 model") +args = ap.parse_args() + +# find the model part that includes the the multimodal projector weights +model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True) +checkpoint = model.state_dict() + +# get a list of mm tensor names +mm_tensors = [k for k, v in checkpoint.items() if k.startswith("resampler")] + +# store these tensors in a new dictionary and torch.save them +projector = {name: checkpoint[name].float() for name in mm_tensors} +torch.save(projector, f"{args.model}/minicpmv.projector") + +clip_tensors = [k for k, v in checkpoint.items() if k.startswith("vpm")] +if len(clip_tensors) > 0: + clip = {name.replace("vpm.", ""): checkpoint[name].float() for name in clip_tensors} + torch.save(clip, f"{args.model}/minicpmv.clip") + + # added tokens should be removed to be able to convert Mistral models + if os.path.exists(f"{args.model}/added_tokens.json"): + with open(f"{args.model}/added_tokens.json", "w") as f: + f.write("{}\n") + +config = model.llm.config +config._name_or_path = "openbmb/MiniCPM-Llama3-V-2.5" +config.auto_map = { + "AutoConfig": "configuration_minicpm.MiniCPMConfig", + "AutoModel": "modeling_minicpm.MiniCPMModel", + "AutoModelForCausalLM": "modeling_minicpm.MiniCPMForCausalLM", + "AutoModelForSeq2SeqLM": "modeling_minicpm.MiniCPMForCausalLM", + "AutoModelForSequenceClassification": "modeling_minicpm.MiniCPMForSequenceClassification" +} +model.llm.save_pretrained(f"{args.model}/model") +tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) +tok.save_pretrained(f"{args.model}/model") +# os.system(f"cp {args.model}/modeling_minicpm.py {args.model}/MiniCPM_l3/modeling_minicpm.py") + +print("Done!") +print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.") +print(f"Also, use {args.model}/minicpmv.projector to prepare a minicpmv-encoder.gguf file.") diff --git a/examples/llava/minicpmv_wrapper.cpp b/examples/llava/minicpmv_wrapper.cpp new file mode 100644 index 0000000000000..5e1d9b134ea67 --- /dev/null +++ b/examples/llava/minicpmv_wrapper.cpp @@ -0,0 +1,145 @@ +#include "ggml.h" +#include "common.h" +#include "clip.h" +#include "llava.h" +#include "minicpmv_wrapper.h" +#include "llama.h" +#include +#include +#include + +struct llama_model * llava_init(gpt_params * params) { + llama_backend_init(); + llama_numa_init(params->numa); + + llama_model_params model_params = llama_model_params_from_gpt_params(*params); + + llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params); + if (model == NULL) { + LOG_TEE("%s: error: unable to load model\n" , __func__); + return NULL; + } + return model; +} + +struct minicpmv_context * llava_init_context(gpt_params * params, llama_model * model) { + auto prompt = params->prompt; + if (prompt.empty()) { + prompt = "describe the image in detail."; + } + + llama_context_params ctx_params = llama_context_params_from_gpt_params(*params); + ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings + + llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); + + if (ctx_llama == NULL) { + LOG_TEE("%s: error: failed to create the llama_context\n" , __func__); + return NULL; + } + + auto ctx_llava = (struct minicpmv_context *)malloc(sizeof(minicpmv_context)); + + ctx_llava->ctx_llama = ctx_llama; + ctx_llava->model = model; + return ctx_llava; +} + +void llava_free(struct minicpmv_context * ctx_llava) { + llama_free(ctx_llava->ctx_llama); + llama_free_model(ctx_llava->model); + llama_backend_free(); +} + +struct clip_ctx * clip_init_context(gpt_params * params) { + const char * clip_path = params->mmproj.c_str(); + + auto prompt = params->prompt; + if (prompt.empty()) { + prompt = "describe the image in detail."; + } + std::pair load_image_size = std::make_pair(448, 448); + auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1, load_image_size); + return ctx_clip; +} + +struct uhd_image_embed * minicpmv_image_embed(gpt_params * params, const std::string & fname){ + auto ctx_clip = clip_init_context(params); + auto image_embed_and_slices = llava_image_embed_make_with_filename_uhd(ctx_clip, params->n_threads, fname.c_str()); + if (ctx_clip) { + clip_free(ctx_clip); + ctx_clip = NULL; + } + return image_embed_and_slices; +} + + +bool eval_tokens(struct llama_context * ctx_llama, std::vector tokens, int n_batch, int * n_past) { + int N = (int) tokens.size(); + for (int i = 0; i < N; i += n_batch) { + int n_eval = (int) tokens.size() - i; + if (n_eval > n_batch) { + n_eval = n_batch; + } + if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) { + LOG_TEE("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); + return false; + } + *n_past += n_eval; + } + return true; +} + +bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) { + std::vector tokens; + tokens.push_back(id); + return eval_tokens(ctx_llama, tokens, 1, n_past); +} + +bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){ + std::string str2 = str; + std::vector embd_inp = ::llama_tokenize(ctx_llama, str2, add_bos, true); + eval_tokens(ctx_llama, embd_inp, n_batch, n_past); + return true; +} + +void process_image(struct minicpmv_context * ctx_llava, std::vector> image_embed_slices, gpt_params * params, int &n_past) { + std::string system_prompt; + + system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"; + LOG_TEE("%s: image token past: %d\n", __func__, n_past); + eval_string(ctx_llava->ctx_llama, (system_prompt+"").c_str(), params->n_batch, &n_past, false); + llava_eval_image_embed(ctx_llava->ctx_llama, image_embed_slices[0][0], params->n_batch, &n_past); + eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); + if (image_embed_slices.size() > 1) { + eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); + for (size_t i = 1; i < image_embed_slices.size(); ++i) { + for (size_t j = 0; j < image_embed_slices[i].size(); ++j) { + eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); + llava_eval_image_embed(ctx_llava->ctx_llama, image_embed_slices[i][j], params->n_batch, &n_past); + eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); + if (j == image_embed_slices[i].size() - 1) { + eval_string(ctx_llava->ctx_llama, std::string("\n").c_str(), params->n_batch, &n_past, false); + } + } + } + eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); + + } + LOG_TEE("%s: image token past: %d\n", __func__, n_past); +} + +const char * sample(struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_llama, + int * n_past) { + const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL); + llama_sampling_accept(ctx_sampling, ctx_llama, id, true); + static std::string ret; + if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { + ret = ""; + } else { + ret = llama_token_to_piece(ctx_llama, id); + } + eval_id(ctx_llama, id, n_past); + return ret.c_str(); +} \ No newline at end of file diff --git a/examples/llava/minicpmv_wrapper.h b/examples/llava/minicpmv_wrapper.h new file mode 100644 index 0000000000000..7bd7b44fe2b47 --- /dev/null +++ b/examples/llava/minicpmv_wrapper.h @@ -0,0 +1,49 @@ +#ifndef MINICPMV_H +#define MINICPMV_H + +#include "common.h" +#include "clip.h" +#include "llava.h" +#include "llama.h" + +#ifdef LLAMA_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef LLAMA_BUILD +# define MINICPMV_API __declspec(dllexport) +# else +# define MINICPMV_API __declspec(dllimport) +# endif +# else +# define MINICPMV_API __attribute__ ((visibility ("default"))) +# endif +#else +# define MINICPMV_API +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +struct minicpmv_context { + struct llama_context * ctx_llama = NULL; + struct llama_model * model = NULL; +}; + +MINICPMV_API struct llama_model * llava_init(gpt_params * params); +MINICPMV_API struct minicpmv_context * llava_init_context(gpt_params * params, llama_model * model); +MINICPMV_API void llava_free(struct minicpmv_context * ctx_llava); + +MINICPMV_API struct clip_ctx * clip_init_context(gpt_params * params); +MINICPMV_API struct uhd_image_embed * minicpmv_image_embed(gpt_params * params, const std::string & fname); + +MINICPMV_API bool eval_tokens(struct llama_context * ctx_llama, std::vector tokens, int n_batch, int * n_past); +MINICPMV_API bool eval_id(struct llama_context * ctx_llama, int id, int * n_past); +MINICPMV_API bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos); +MINICPMV_API void process_image(struct minicpmv_context * ctx_llava, std::vector> image_embed_slices, gpt_params * params, int &n_past); +MINICPMV_API const char * sample(struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_llama, int * n_past); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/examples/llava/requirements.txt b/examples/llava/requirements.txt index 17cb4d5e5ee8e..5c478e6932d9b 100644 --- a/examples/llava/requirements.txt +++ b/examples/llava/requirements.txt @@ -1,3 +1,4 @@ -r ../../requirements/requirements-convert-legacy-llama.txt pillow~=10.2.0 torch~=2.1.1 +torchvision==0.16.2 \ No newline at end of file