diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 9771fccf9ffc1..32d54b45f3325 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -218,6 +218,8 @@ def from_model_architecture(model_architecture): return BertModel if model_architecture == "NomicBertModel": return NomicBertModel + if model_architecture == "GemmaForCausalLM": + return GemmaModel return Model def _is_model_safetensors(self) -> bool: @@ -277,6 +279,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH: return gguf.MODEL_ARCH.BERT if arch == "NomicBertModel": return gguf.MODEL_ARCH.NOMIC_BERT + if arch == "GemmaForCausalLM": + return gguf.MODEL_ARCH.GEMMA raise NotImplementedError(f'Architecture "{arch}" not supported!') @@ -618,11 +622,6 @@ def write_tensors(self): self.gguf_writer.add_tensor(new_name, data) - # note: MPT output is tied to (same as) wte in original model; - # for easier implementation in llama.cpp it's duplicated in GGUF, though :/ - if new_name == "token_embd.weight": - self.gguf_writer.add_tensor("output.weight", data) - class OrionModel(Model): def set_vocab(self): @@ -655,6 +654,8 @@ def set_gguf_parameters(self): self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_head_count(head_count) self.gguf_writer.add_head_count_kv(head_count_kv) + # note: config provides rms norm but it is actually layer norm + # ref: https://huggingface.co/OrionStarAI/Orion-14B-Chat/blob/276a17221ce42beb45f66fac657a41540e71f4f5/modeling_orion.py#L570-L571 self.gguf_writer.add_layer_norm_eps(self.hparams["rms_norm_eps"]) def write_tensors(self): @@ -1031,7 +1032,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_head_count_kv(head_count_kv) self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"]) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) - self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) def set_vocab(self): self._set_vocab_sentencepiece() @@ -1785,6 +1785,63 @@ def get_tensors(self): yield name, data +class GemmaModel(Model): + def set_vocab(self): + self._set_vocab_sentencepiece() + + def set_gguf_parameters(self): + hparams = self.hparams + block_count = hparams["num_hidden_layers"] + + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) + self.gguf_writer.add_head_count(hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) + self.gguf_writer.add_key_length(hparams["head_dim"]) + self.gguf_writer.add_value_length(hparams["head_dim"]) + self.gguf_writer.add_file_type(self.ftype) + + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + + for name, data_torch in self.get_tensors(): + # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89 + if name.endswith("norm.weight"): + data_torch = data_torch + 1 + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + ###### CONVERSION LOGIC ###### diff --git a/examples/server/README.md b/examples/server/README.md index 4b24ee5dc3f28..4b6cd8326efa8 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -151,7 +151,7 @@ node index.js `temperature`: Adjust the randomness of the generated text (default: 0.8). - `dynatemp_range`: Dynamic temperature range (default: 0.0, 0.0 = disabled). + `dynatemp_range`: Dynamic temperature range. The final temperature will be in the range of `[temperature - dynatemp_range; temperature + dynatemp_range]` (default: 0.0, 0.0 = disabled). `dynatemp_exponent`: Dynamic temperature exponent (default: 1.0). @@ -209,7 +209,7 @@ node index.js `slot_id`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot (default: -1) - `cache_prompt`: Save the prompt and generation for avoid reprocess entire prompt if a part of this isn't change (default: false) + `cache_prompt`: Re-use previously cached prompt from the last request if possible. This may prevent re-caching the prompt from scratch. (default: false) `system_prompt`: Change the system prompt (initial prompt of all slots), this is useful for chat applications. [See more](#change-system-prompt-on-runtime) @@ -242,7 +242,7 @@ Notice that each `probs` is an array of length `n_probs`. - `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string. - `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options) -- `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model` +- `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.). - `model`: The path to the model loaded with `-m` - `prompt`: The provided `prompt` - `stopped_eos`: Indicating whether the completion has stopped because it encountered the EOS token diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 635a52603ae22..74c4719d9a870 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -401,6 +401,16 @@ struct llama_server_context return true; } + void validate_model_chat_template(server_params & sparams) { + llama_chat_message chat[] = {{"user", "test"}}; + std::vector buf(1); + int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size()); + if (res < 0) { + LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {}); + sparams.chat_template = "<|im_start|>"; // llama_chat_apply_template only checks if <|im_start|> exist in the template + } + } + void initialize() { // create slots all_slots_are_idle = true; @@ -1939,6 +1949,10 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); printf(" -spf FNAME, --system-prompt-file FNAME\n"); printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n"); + printf(" -ctk TYPE, --cache-type-k TYPE\n"); + printf(" KV cache data type for K (default: f16)\n"); + printf(" -ctv TYPE, --cache-type-v TYPE\n"); + printf(" KV cache data type for V (default: f16)\n"); printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA.\n"); printf(" --log-disable disables logging to a file.\n"); printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n"); @@ -2377,6 +2391,12 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, ); llama.process_system_prompt_data(json::parse(systm_content)); } + else if (arg == "-ctk" || arg == "--cache-type-k") { + params.cache_type_k = argv[++i]; + } + else if (arg == "-ctv" || arg == "--cache-type-v") { + params.cache_type_v = argv[++i]; + } else if(arg == "--mmproj") { if (++i >= argc) @@ -2753,6 +2773,11 @@ int main(int argc, char **argv) LOG_INFO("model loaded", {}); } + if (sparams.chat_template.empty()) { // custom chat template is not supplied + // check if the template comes with the model is supported by us + llama.validate_model_chat_template(sparams); + } + // Middleware for API key validation auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { // If API key is not set, skip validation diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 068aadbb3c91b..794182db9afbe 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1,3 +1,7 @@ +#include "ggml-cuda.h" +#include "ggml.h" +#include "ggml-backend-impl.h" + #include #include #include @@ -121,11 +125,6 @@ #endif // defined(GGML_USE_HIPBLAS) -// ggml-cuda need half type so keep ggml headers include at last -#include "ggml-cuda.h" -#include "ggml.h" -#include "ggml-backend-impl.h" - #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed) #define CC_PASCAL 600 diff --git a/ggml-impl.h b/ggml-impl.h index 19df66bceee4a..c5637e4d45d8c 100644 --- a/ggml-impl.h +++ b/ggml-impl.h @@ -53,11 +53,23 @@ extern "C" { // #include -#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x)) -#define GGML_COMPUTE_FP32_TO_FP16(x) (x) +#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) +#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) + +#define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) + +static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { + __fp16 tmp; + memcpy(&tmp, &h, sizeof(ggml_fp16_t)); + return (float)tmp; +} -#define GGML_FP16_TO_FP32(x) ((float) (x)) -#define GGML_FP32_TO_FP16(x) (x) +static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { + ggml_fp16_t res; + __fp16 tmp = f; + memcpy(&res, &tmp, sizeof(ggml_fp16_t)); + return res; +} #else @@ -214,8 +226,7 @@ extern float ggml_table_f32_f16[1 << 16]; // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. // This is also true for POWER9. -#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16) - +#if !defined(GGML_FP16_TO_FP32) inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { uint16_t s; memcpy(&s, &f, sizeof(uint16_t)); @@ -223,8 +234,10 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { } #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) -#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) +#endif +#if !defined(GGML_FP32_TO_FP16) +#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) #endif #define GGML_HASHTABLE_FULL ((size_t)-1) diff --git a/ggml-quants.c b/ggml-quants.c index 65c061078249c..30bfc9c69d7a0 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -440,6 +440,30 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) { return res; } +// NOTE: not tested +inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) { + int8x16_t res; + + res[ 0] = a[b[ 0]]; + res[ 1] = a[b[ 1]]; + res[ 2] = a[b[ 2]]; + res[ 3] = a[b[ 3]]; + res[ 4] = a[b[ 4]]; + res[ 5] = a[b[ 5]]; + res[ 6] = a[b[ 6]]; + res[ 7] = a[b[ 7]]; + res[ 8] = a[b[ 8]]; + res[ 9] = a[b[ 9]]; + res[10] = a[b[10]]; + res[11] = a[b[11]]; + res[12] = a[b[12]]; + res[13] = a[b[13]]; + res[14] = a[b[14]]; + res[15] = a[b[15]]; + + return res; +} + #else #define ggml_int16x8x2_t int16x8x2_t @@ -453,6 +477,7 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) { #define ggml_vld1q_u8_x4 vld1q_u8_x4 #define ggml_vld1q_s8_x2 vld1q_s8_x2 #define ggml_vld1q_s8_x4 vld1q_s8_x4 +#define ggml_vqtbl1q_s8 vqtbl1q_s8 #endif @@ -5631,8 +5656,8 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r for (int i = 0; i < nb; ++i) { - const float d = y[i].d * (float)x[i].d; - const float dmin = -y[i].d * (float)x[i].dmin; + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; @@ -5781,8 +5806,8 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r for (int i = 0; i < nb; ++i) { - const float d = y[i].d * (float)x[i].d; - const float dmin = -y[i].d * (float)x[i].dmin; + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; @@ -6435,7 +6460,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); - const float d = y[i].d * (float)x[i].d; + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1)); q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2)); @@ -6637,7 +6662,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); - const float d = y[i].d * (float)x[i].d; + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); @@ -7140,9 +7165,9 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r aux16[1] = (a[0] >> 4) & 0x0f0f; const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]); - sum_mins += y[i].d * (float)x[i].d[1] * summi; + sum_mins += y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * summi; - const float d = y[i].d * (float)x[i].d[0]; + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]); const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); @@ -7800,7 +7825,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r for (int i = 0; i < nb; ++i) { - const float d = y[i].d * (float)x[i].d; + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); const int8_t * sc = x[i].scales; const uint8_t * restrict q5 = x[i].qs; @@ -7942,7 +7967,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r for (int i = 0; i < nb; ++i) { - const float d = y[i].d * (float)x[i].d; + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); const int8_t * sc = x[i].scales; const uint8_t * restrict q5 = x[i].qs; @@ -8510,7 +8535,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r for (int i = 0; i < nb; ++i) { - const float d_all = (float)x[i].d; + const float d_all = GGML_FP16_TO_FP32(x[i].d); const uint8_t * restrict q6 = x[i].ql; const uint8_t * restrict qh = x[i].qh; @@ -8681,7 +8706,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r for (int i = 0; i < nb; ++i) { - const float d_all = (float)x[i].d; + const float d_all = GGML_FP16_TO_FP32(x[i].d); const uint8_t * restrict q6 = x[i].ql; const uint8_t * restrict qh = x[i].qh; @@ -9335,7 +9360,7 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const uint16_t gindex[8]; uint16x8x2_t vindex; int8x16x4_t q1b; - int8x16x4_t q8b; + ggml_int8x16x4_t q8b; uint16x8x4_t scales; int32x4x2_t sumi; int32x4x2_t dotq; @@ -9500,7 +9525,6 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * float sumf = 0; for (int ib = 0; ib < nb; ib += 2) { - q4bits.val[0] = vld1q_u8(x[ib+0].qs); q4bits.val[1] = vld1q_u8(x[ib+1].qs); q8b.val[0] = vld1q_s8(y[ib+0].qs); @@ -9508,16 +9532,17 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * q8b.val[2] = vld1q_s8(y[ib+1].qs); q8b.val[3] = vld1q_s8(y[ib+1].qs + 16); - q4b.val[0] = vqtbl1q_s8(values, vandq_u8(q4bits.val[0], m4b)); - q4b.val[1] = vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); - q4b.val[2] = vqtbl1q_s8(values, vandq_u8(q4bits.val[1], m4b)); - q4b.val[3] = vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); + q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); + q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); + q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); + q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); - sumf += (float)x[ib+0].d * (float)y[ib+0].d * vaddvq_s32(prod_1) + (float)x[ib+1].d * (float)y[ib+1].d * vaddvq_s32(prod_2); - + sumf += + GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib+0].d) * vaddvq_s32(prod_1) + + GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib+1].d) * vaddvq_s32(prod_2); } *s = sumf; diff --git a/ggml.c b/ggml.c index 8cb497b7f5d08..80729069b6201 100644 --- a/ggml.c +++ b/ggml.c @@ -323,7 +323,7 @@ float ggml_table_f32_f16[1 << 16]; // note: do not use these inside ggml.c // these are meant to be used via the ggml.h API float ggml_fp16_to_fp32(ggml_fp16_t x) { - return (float) GGML_FP16_TO_FP32(x); + return GGML_FP16_TO_FP32(x); } ggml_fp16_t ggml_fp32_to_fp16(float x) { @@ -798,7 +798,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F16x8 float16x8_t #define GGML_F16x8_ZERO vdupq_n_f16(0.0f) #define GGML_F16x8_SET1(x) vdupq_n_f16(x) - #define GGML_F16x8_LOAD vld1q_f16 + #define GGML_F16x8_LOAD(x) vld1q_f16((const __fp16 *)(x)) #define GGML_F16x8_STORE vst1q_f16 #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) #define GGML_F16x8_ADD vaddq_f16 @@ -841,7 +841,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F32Cx4 float32x4_t #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x) - #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x)) + #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const __fp16 *)(x))) #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) #define GGML_F32Cx4_ADD vaddq_f32 diff --git a/ggml.h b/ggml.h index 063f0e4130cea..1472098727544 100644 --- a/ggml.h +++ b/ggml.h @@ -322,13 +322,7 @@ extern "C" { #endif -#if defined(__ARM_NEON) && defined(__CUDACC__) - typedef half ggml_fp16_t; -#elif defined(__ARM_NEON) && !defined(_MSC_VER) - typedef __fp16 ggml_fp16_t; -#else typedef uint16_t ggml_fp16_t; -#endif // convert FP16 <-> FP32 GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x); diff --git a/llama.cpp b/llama.cpp index 859d909c09e3e..f8929ac7d6194 100644 --- a/llama.cpp +++ b/llama.cpp @@ -533,7 +533,6 @@ static std::map> LLM_TENSOR_NAMES = { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, @@ -4126,7 +4125,12 @@ static bool llm_load_tensors( // output { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, false); + + // same as tok_embd, duplicated to allow offloading + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + ml.n_created--; // artificial tensor + ml.size_data += ggml_nbytes(model.output); } for (int i = 0; i < n_layer; ++i) { @@ -4135,14 +4139,23 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, false); layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, false); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, false); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, false); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, false); + + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, false); // AWQ ScaleActivation layer layer.ffn_act = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, false); @@ -6243,7 +6256,7 @@ struct llm_build_context { attn_norm = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, - NULL, + model.layers[il].attn_norm_b, LLM_NORM, cb, il); cb(attn_norm, "attn_norm", il); @@ -6254,6 +6267,11 @@ struct llm_build_context { cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); + if (model.layers[il].bqkv){ + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + if (hparams.f_clamp_kqv > 0.0f) { cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); cb(cur, "wqkv_clamped", il); @@ -6270,7 +6288,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, NULL, + model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -6283,13 +6301,13 @@ struct llm_build_context { { cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, - NULL, + model.layers[il].ffn_norm_b, LLM_NORM, cb, il); cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_act, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); @@ -6306,7 +6324,7 @@ struct llm_build_context { cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, - NULL, + model.output_norm_b, LLM_NORM, cb, -1); cb(cur, "result_norm", -1); @@ -7506,6 +7524,7 @@ struct llm_build_context { inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); cb(inpL, "inp_embd", -1); + inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); cb(inpL, "inp_scaled", -1); @@ -7547,6 +7566,7 @@ struct llm_build_context { n_embd_head_k, 2, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); + Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); cb(Qcur, "Qcur_scaled", il); @@ -7561,6 +7581,7 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); cb(cur, "kqv_out", il); } + struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); cb(sa_out, "sa_out", il); @@ -10802,7 +10823,10 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty return std::make_pair(i_layer, n_layer); }; - if (name == tn(LLM_TENSOR_OUTPUT, "weight")) { + // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings + // with the quantization of the output tensor + if (name == tn(LLM_TENSOR_OUTPUT, "weight") || + (LLM_TENSOR_NAMES.at(arch).find(LLM_TENSOR_OUTPUT) == LLM_TENSOR_NAMES.at(arch).end() && name == "token_embd.weight")) { int nx = tensor->ne[0]; if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) { new_type = GGML_TYPE_Q8_0; @@ -13085,6 +13109,37 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>\n"; } + } else if (tmpl.find("bos_token + message['role']") != std::string::npos) { + // mlabonne/AlphaMonarch-7B template (the is included inside history) + for (auto message : chat) { + std::string bos = (message == chat.front()) ? "" : ""; // skip BOS for first message + ss << bos << message->role << "\n" << message->content << "\n"; + } + if (add_ass) { + ss << "assistant\n"; + } + } else if (tmpl.find("") != std::string::npos) { + // google/gemma-7b-it + std::string system_prompt = ""; + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken + system_prompt = trim(message->content); + continue; + } + // in gemma, "assistant" is "model" + role = role == "assistant" ? "model" : message->role; + ss << "" << role << "\n"; + if (!system_prompt.empty() && role != "model") { + ss << system_prompt << "\n\n"; + system_prompt = ""; + } + ss << trim(message->content) << "\n"; + } + if (add_ass) { + ss << "model\n"; + } } else { // template not supported return -1; diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 9830650d4f8dd..fa2eb577b6e42 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -27,12 +27,24 @@ int main(void) { "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <>\\\\n' + messages[idx]['content'] + '\\\\n<>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}", // bofenghuang/vigogne-2-70b-chat "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\\\n' + system_message + '\\\\n<>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\\\n' + content.strip() + '\\\\n<>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", + // mlabonne/AlphaMonarch-7B + "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}", + // google/gemma-7b-it + "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}", }; - std::vector expected_substr = { - "<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant", - "[/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", - "[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", - "[/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + std::vector expected_output = { + // teknium/OpenHermes-2.5-Mistral-7B + "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n", + // mistralai/Mistral-7B-Instruct-v0.2 + "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + // TheBloke/FusionNet_34Bx2_MoE-AWQ + "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + // bofenghuang/vigogne-2-70b-chat + "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + // mlabonne/AlphaMonarch-7B + "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + // google/gemma-7b-it + "user\nYou are a helpful assistant\n\nHello\nmodel\nHi there\nuser\nWho are you\nmodel\nI am an assistant\nuser\nAnother question\nmodel\n", }; std::vector formatted_chat(1024); int32_t res; @@ -43,7 +55,7 @@ int main(void) { for (size_t i = 0; i < templates.size(); i++) { std::string custom_template = templates[i]; - std::string substr = expected_substr[i]; + std::string expected = expected_output[i]; formatted_chat.resize(1024); res = llama_chat_apply_template( nullptr, @@ -57,8 +69,7 @@ int main(void) { formatted_chat.resize(res); std::string output(formatted_chat.data(), formatted_chat.size()); std::cout << output << "\n-------------------------\n"; - // expect the "formatted_chat" to contain pre-defined strings - assert(output.find(substr) != std::string::npos); + assert(output == expected); } return 0; }