diff --git a/.github/workflows/cpp-graph-test.yml b/.github/workflows/cpp-graph-test.yml
index 42c63f95f..56f4a15f7 100644
--- a/.github/workflows/cpp-graph-test.yml
+++ b/.github/workflows/cpp-graph-test.yml
@@ -63,7 +63,7 @@ jobs:
 
     - name: Env build
       run: |
-        bash ${{ github.workspace }}/.github/workflows/scripts/prepare_env_with_conda.sh "cpp-graph-test-neural-speed" "3.8"
+        bash ${{ github.workspace }}/.github/workflows/scripts/prepare_env_with_conda.sh "cpp-graph-test-neural-speed" "3.9"
 
     - name: BF16 Benchmark
       run: |
diff --git a/.github/workflows/scripts/models/cpp_graph_inference.sh b/.github/workflows/scripts/models/cpp_graph_inference.sh
index 6f159cdca..59ab8b4f6 100644
--- a/.github/workflows/scripts/models/cpp_graph_inference.sh
+++ b/.github/workflows/scripts/models/cpp_graph_inference.sh
@@ -68,7 +68,6 @@ function main() {
     if [[ "${compiler_version}" != "12.1.0" ]]; then
         conda install --update-deps -c conda-forge gxx==${compiler_version} gcc==${compiler_version} gxx_linux-64==${compiler_version} libstdcxx-ng sysroot_linux-64 -y
     fi
-    export LD_LIBRARY_PATH=${HOME}/miniconda3/envs/${conda_env}/lib/:$LD_LIBRARY_PATH
     # compile binary
     cd ${working_dir}
     mkdir build
@@ -81,7 +80,7 @@ function main() {
 
     ## prepare example requirement
     pip install -r neural_speed/models/requirements/common.txt
-
+    export LD_LIBRARY_PATH=${HOME}/miniconda3/envs/${conda_env}/lib/:$LD_LIBRARY_PATH
     ## prepare fp32 bin
     python ${convert_script} --outtype f32 --outfile ${working_dir}/${model}-fp32.bin ${input_model}
 
diff --git a/neural_speed/convert/convert_quantized_qwen.py b/neural_speed/convert/convert_quantized_qwen.py
index 676f4423e..619ac5adc 100644
--- a/neural_speed/convert/convert_quantized_qwen.py
+++ b/neural_speed/convert/convert_quantized_qwen.py
@@ -65,7 +65,8 @@ def main(args_in: Optional[List[str]] = None) -> None:
     f.write(struct.pack("i", hparams["hidden_size"]))
     f.write(struct.pack("i", hparams["intermediate_size"]))  # dummy data
     f.write(struct.pack("i", hparams["num_attention_heads"]))
-    f.write(struct.pack("i", 0))  # multi-query attention
+    f.write(struct.pack("i", hparams["num_key_value_heads"] if "num_key_value_heads" in hparams
+                        else hparams["num_attention_heads"]))  # multi-query attention
     f.write(struct.pack("i", hparams["num_hidden_layers"]))
     f.write(
         struct.pack(
@@ -89,7 +90,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
     f.write(struct.pack("i", 0))  # n_expert_used
     f.write(struct.pack("i", 0))  # n_embd_head_k for gemma
     f.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6)))  # rms norm eps
-    f.write(struct.pack("f", 10000.0))  # freq_base
+    f.write(struct.pack("f", hparams.get("rope_theta", 10000.0)))  # freq_base
     f.write(struct.pack("f", 1.0))  # rope_factor
 
     f.write(struct.pack("f", 0.0))  # config.json "rope_scaling.factor", not enabled
@@ -186,6 +187,8 @@ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):
         print(f"Success! saved as {out_path}")
     elif hparams['model_type'] == 'qwen2':
         # 3. write tensors
+        if hparams['tie_word_embeddings']:
+            list_vars["lm_head.weight"] = list_vars["model.embed_tokens.weight"]
         convert_qwen_to_fp32_tensor("model.embed_tokens.weight", "model.embed_tokens.weight", list_vars, f)
         convert_qwen_to_fp32_tensor("model.norm.weight", "model.norm.weight", list_vars, f)
         convert_qwen_to_fp32_tensor("lm_head.weight", "lm_head.weight", list_vars, f)
diff --git a/neural_speed/convert/convert_qwen.py b/neural_speed/convert/convert_qwen.py
index f268b37e6..28f1d9341 100644
--- a/neural_speed/convert/convert_qwen.py
+++ b/neural_speed/convert/convert_qwen.py
@@ -103,7 +103,8 @@ def main(args_in: Optional[List[str]] = None) -> None:
     fout.write(struct.pack("i", hparams["hidden_size"]))
     fout.write(struct.pack("i", hparams["intermediate_size"]))  # dummy data
     fout.write(struct.pack("i", hparams["num_attention_heads"]))
-    fout.write(struct.pack("i", 0))  # multi-query attention
+    fout.write(struct.pack("i", hparams["num_key_value_heads"] if "num_key_value_heads" in hparams
+                           else hparams["num_attention_heads"]))  # multi-query attention
     fout.write(struct.pack("i", hparams["num_hidden_layers"]))
     fout.write(
         struct.pack(
@@ -128,7 +129,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
     fout.write(struct.pack("i", 0))  # n_expert_used
     fout.write(struct.pack("i", 0))  # n_embd_head_k for gemma
     fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6)))  # rms_norm_eps or layer_norm_eps
-    fout.write(struct.pack("f", 10000.0))  # freq_base
+    fout.write(struct.pack("f", hparams.get("rope_theta", 10000.0)))  # freq_base
     fout.write(struct.pack("f", 1.0))  # rope_factor
 
     fout.write(struct.pack("f", 0.0))  # config.json "rope_scaling.factor", not enabled
diff --git a/neural_speed/models/qwen/qwen.cpp b/neural_speed/models/qwen/qwen.cpp
index 44ba22670..f70365997 100644
--- a/neural_speed/models/qwen/qwen.cpp
+++ b/neural_speed/models/qwen/qwen.cpp
@@ -102,6 +102,8 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
   const int n_vocab = hparams.n_vocab;
   const int n_rot = hparams.n_rot;
   const int head_dim = n_embd / n_head;
+  const int n_head_kv = hparams.n_head_kv;
+  const int n_embd_gqa = head_dim * n_head_kv;
   int qwen_version = 0;
   if (hparams.max_seq_len == 8192) {
     qwen_version = 1;
@@ -132,7 +134,7 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
   attn_shape_t attn_shape = {
       /* .batch_size = */ 1,
      /* .head_num = */ n_head,
-      /* .heads_kv = */ n_head,
+      /* .heads_kv = */ n_head_kv,
       /* .head_size = */ head_dim,
       /* .sl_q = */ N,  // Note: make sure that bestla reordered attn supports next token inference
       /* .sl_kv = */ n_past + N,
@@ -141,7 +143,7 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
   NE_ASSERT(("bestla managed kv-cache not supported; use `--memory-f16 / --memory-f32` instead",
             bestla_reordered_attn_fp32_support(&attn_shape)));
   kv_shape_t kv_shape{
-      /* .heads_kv = */ static_cast<uint32_t>(n_head),
+      /* .heads_kv = */ static_cast<uint32_t>(n_head_kv),
       /* .head_size = */ static_cast<uint32_t>(head_dim),
       /* .sl_kv_max = */ static_cast<uint32_t>(n_ctx),
   };
@@ -194,11 +196,11 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
 
       Kcur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
       Kcur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[3], Kcur), Kcur);
-      Kcur = ne_reshape_3d(ctx0, Kcur, head_dim, n_head, N);
+      Kcur = ne_reshape_3d(ctx0, Kcur, head_dim, n_head_kv, N);
 
       Vcur = ne_mul_mat(ctx0, model.layers[il].attn[4], cur);
      Vcur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[5], Vcur), Vcur);
-      Vcur = ne_reshape_3d(ctx0, Vcur, head_dim, n_head, N);
+      Vcur = ne_reshape_3d(ctx0, Vcur, head_dim, n_head_kv, N);
     }
 
     // using mode = 2 for GPT-NeoX mode
@@ -216,29 +218,31 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
       std::vector<ne_tensor*> v_bs(batch_size);
       for (int i = 0; i < batch_size; ++i) {
         // batch K
-        Kcur_bs[i] = ne_permute(ctx0,
-                                ne_view_4d(ctx0, Kcur, head_dim, n_head, N, 1, ne_element_size(Kcur) * head_dim,
-                                           ne_element_size(Kcur) * n_embd, ne_element_size(Kcur) * n_embd * N,
-                                           i * ne_element_size(Kcur) * n_embd * N),
-                                0, 2, 1, 3);
+        Kcur_bs[i] =
+            ne_permute(ctx0,
+                       ne_view_4d(ctx0, Kcur, head_dim, n_head_kv, N, 1, ne_element_size(Kcur) * head_dim,
+                                  ne_element_size(Kcur) * n_embd_gqa, ne_element_size(Kcur) * n_embd_gqa * N,
+                                  i * ne_element_size(Kcur) * n_embd_gqa * N),
+                       0, 2, 1, 3);
         k_bs[i] = ne_view_4d(
-            ctx0, kv_self.k, head_dim, N, n_head, 1, ne_element_size(kv_self.k) * head_dim,
-            ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx,
-            ((il * n_ctx) * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block +
-             i * n_ctx * n_embd * ne_element_size(kv_self.k) + head_dim * n_past * ne_element_size(kv_self.k)));
+            ctx0, kv_self.k, head_dim, N, n_head_kv, 1, ne_element_size(kv_self.k) * head_dim,
+            ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd_gqa * n_ctx,
+            ((il * n_ctx) * ne_element_size(kv_self.k) * n_embd_gqa * kv_n_ctx_block +
+             i * n_ctx * n_embd_gqa * ne_element_size(kv_self.k) + head_dim * n_past * ne_element_size(kv_self.k)));
 
         // batch V
-        Vcur_bs[i] = ne_permute(ctx0,
-                                ne_reshape_4d(ctx0,
-                                              ne_view_2d(ctx0, Vcur, n_embd, N, ne_element_size(Vcur) * n_embd,
-                                                         i * ne_element_size(Vcur) * n_embd * N),
-                                              head_dim, n_head, N, 1),
-                                1, 2, 0, 3);
-        v_bs[i] =
-            ne_view_4d(ctx0, kv_self.v, N, head_dim, n_head, 1, n_ctx * ne_element_size(kv_self.v),
-                       n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd,
-                       ((il * n_ctx) * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block +
-                        i * n_ctx * n_embd * ne_element_size(kv_self.v) + n_past * ne_element_size(kv_self.v)));
+        Vcur_bs[i] =
+            ne_permute(ctx0,
+                       ne_reshape_4d(ctx0,
+                                     ne_view_2d(ctx0, Vcur, n_embd_gqa, N, ne_element_size(Vcur) * n_embd_gqa,
+                                                i * ne_element_size(Vcur) * n_embd_gqa * N),
+                                     head_dim, n_head_kv, N, 1),
+                       1, 2, 0, 3);
+        v_bs[i] = ne_view_4d(
+            ctx0, kv_self.v, N, head_dim, n_head_kv, 1, n_ctx * ne_element_size(kv_self.v),
+            n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd_gqa,
+            ((il * n_ctx) * ne_element_size(kv_self.v) * n_embd_gqa * kv_n_ctx_block +
+             i * n_ctx * n_embd_gqa * ne_element_size(kv_self.v) + n_past * ne_element_size(kv_self.v)));
         ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs[i], k_bs[i]));
         ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs[i], v_bs[i]));
       }
@@ -247,10 +251,10 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
      struct ne_tensor* Q = ne_permute(ctx0, ne_reshape_4d(ctx0, Qcur, head_dim, n_head, N, batch_size), 0, 2, 1, 3);
 
       // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
-      struct ne_tensor* K =
-          ne_view_4d(ctx0, kv_self.k, head_dim, n_past + N, n_head, batch_size, ne_element_size(kv_self.k) * head_dim,
-                     ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx,
-                     il * n_ctx * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block);
+      struct ne_tensor* K = ne_view_4d(
+          ctx0, kv_self.k, head_dim, n_past + N, n_head_kv, batch_size, ne_element_size(kv_self.k) * head_dim,
+          ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd_gqa * n_ctx,
+          il * n_ctx * ne_element_size(kv_self.k) * n_embd_gqa * kv_n_ctx_block);
 
       // K * Q
       struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q);
@@ -267,9 +271,9 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
 
       // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
       struct ne_tensor* V =
-          ne_view_4d(ctx0, kv_self.v, n_past + N, head_dim, n_head, batch_size, n_ctx * ne_element_size(kv_self.v),
-                     n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd,
-                     il * n_ctx * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block);
+          ne_view_4d(ctx0, kv_self.v, n_past + N, head_dim, n_head_kv, batch_size, n_ctx * ne_element_size(kv_self.v),
+                     n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd_gqa,
+                     il * n_ctx * ne_element_size(kv_self.v) * n_embd_gqa * kv_n_ctx_block);
 
       // KQV = transpose(V) * KQ_soft_max
       struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max);
@@ -286,15 +290,15 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
 
      // store key and value to memory
      {
-        const auto k_cache = ne_view_3d(ctx0, kv_self.k,          // tensor
-                                        head_dim, n_ctx, n_head,  // ne
-                                        0, 0,                     // nb (bestla managed)
-                                        il * k_size);             // offset
+        const auto k_cache = ne_view_3d(ctx0, kv_self.k,             // tensor
+                                        head_dim, n_ctx, n_head_kv,  // ne
+                                        0, 0,                        // nb (bestla managed)
+                                        il * k_size);                // offset
         ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past, false));
-        const auto v_cache = ne_view_3d(ctx0, kv_self.v,          // tensor
-                                        head_dim, n_ctx, n_head,  // ne
-                                        0, 0,                     // nb (bestla managed)
-                                        il * v_size);             // offset
+        const auto v_cache = ne_view_3d(ctx0, kv_self.v,             // tensor
+                                        head_dim, n_ctx, n_head_kv,  // ne
+                                        0, 0,                        // nb (bestla managed)
+                                        il * v_size);                // offset
         ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur, n_past, false));
       }
 
@@ -303,14 +307,14 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
 
      struct ne_tensor* K =
          ne_view_3d(ctx0, kv_self.k,                                             // tensor
-                     head_dim, seq_kv, n_head,                                    // ne
+                     head_dim, seq_kv, n_head_kv,                                 // ne
                      kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num,  // nb (bestla managed)
                      il * k_size);                                                // offset
       *reinterpret_cast<ATTN_FWD_LAYOUT*>(&K->nb[0]) = kv_cache_info.k_layout;    // us nb0 for layout
       ne_set_name(K, "K");
       struct ne_tensor* V =
          ne_view_3d(ctx0, kv_self.v,                                                    // tensor
-                     seq_kv, head_dim, n_head,                                           // ne
+                     seq_kv, head_dim, n_head_kv,                                        // ne
                      kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num,  // nb (bestla managed)
                      il * v_size);                                                       // offset
       *reinterpret_cast<ATTN_FWD_LAYOUT*>(&V->nb[0]) = kv_cache_info.v_layout;           // us nb0 for layout
diff --git a/neural_speed/models/qwen/qwen.h b/neural_speed/models/qwen/qwen.h
index 3fb54b7c6..e6e8ac0f3 100644
--- a/neural_speed/models/qwen/qwen.h
+++ b/neural_speed/models/qwen/qwen.h
@@ -44,6 +44,18 @@ static const model_scratch qwen_mem_req(int n_layers, float scratch_size_ratio =
           static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
           static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
       };
+    case 28:
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+      };
+    case 80:
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 10 * 4096) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 10 * 2048) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 10 * 4096) * MB,
+      };
     default:
       MODEL_ASSERT(false);
   }
@@ -53,7 +65,7 @@ class QWEN : public IModel {
  private:
   model_archs arch = MODEL_QWEN;
   std::unique_ptr<model_model_loader> ml;
-  uint32_t n_layer, n_embd, n_ff, n_vocab;
+  uint32_t n_layer, n_embd, n_ff, n_vocab, n_head, n_head_kv;
   int n_gpu_layer;
   bool use_mmap, use_mlock, vocab_only;
   model_scratch scratch;
diff --git a/neural_speed/models/qwen/qwen_utils.cpp b/neural_speed/models/qwen/qwen_utils.cpp
index 4dd25cec4..28e49f675 100644
--- a/neural_speed/models/qwen/qwen_utils.cpp
+++ b/neural_speed/models/qwen/qwen_utils.cpp
@@ -65,6 +65,8 @@ void QWEN::init(const char* path_model, model_context* ctx, int n_gpu_layer_, bo
   fprintf(stderr, "%s: n_ff      = %u\n", __func__, n_ff);
   fprintf(stderr, "%s: n_parts   = %zu\n", __func__, ml->file_loaders.size());
   n_embd = hparams.n_embd;
+  n_head_kv = hparams.n_head_kv;
+  n_head = hparams.n_head;
   n_vocab = hparams.n_vocab;
   n_layer = hparams.n_layer;
   scratch = qwen_mem_req(n_layer, lctx.scratch_size_ratio);
@@ -181,10 +183,12 @@ void QWEN::load(model_context* ctx, model_progress_callback progress_callback, v
     // qkv GEMM + out proj GEMM
     layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.q_proj.weight", {n_embd, n_embd}, backend);
     layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.q_proj.bias", {n_embd}, backend);
-    layer.attn[2] = ml->get_tensor(layers_i + ".self_attn.k_proj.weight", {n_embd, n_embd}, backend);
-    layer.attn[3] = ml->get_tensor(layers_i + ".self_attn.k_proj.bias", {n_embd}, backend);
-    layer.attn[4] = ml->get_tensor(layers_i + ".self_attn.v_proj.weight", {n_embd, n_embd}, backend);
-    layer.attn[5] = ml->get_tensor(layers_i + ".self_attn.v_proj.bias", {n_embd}, backend);
+    layer.attn[2] =
+        ml->get_tensor(layers_i + ".self_attn.k_proj.weight", {n_embd, n_embd * n_head_kv / n_head}, backend);
+    layer.attn[3] = ml->get_tensor(layers_i + ".self_attn.k_proj.bias", {n_embd * n_head_kv / n_head}, backend);
+    layer.attn[4] =
+        ml->get_tensor(layers_i + ".self_attn.v_proj.weight", {n_embd, n_embd * n_head_kv / n_head}, backend);
+    layer.attn[5] = ml->get_tensor(layers_i + ".self_attn.v_proj.bias", {n_embd * n_head_kv / n_head}, backend);
     layer.attn[6] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd, n_embd}, backend);
 
     // ffn GEMM
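
Note: the sketch below is not part of the patch; it is a minimal illustration of the converter-side fallbacks introduced above, assuming hparams is the parsed config.json of a Qwen/Qwen2 checkpoint. The helper name write_gqa_header_fields is hypothetical, and only the three GQA-related header fields are shown; the real scripts write many more fields in a fixed order.

    import struct

    def write_gqa_header_fields(fout, hparams):
        # Query head count always comes from the config.
        n_head = hparams["num_attention_heads"]
        # GQA fallback: without num_key_value_heads, treat the checkpoint as
        # plain multi-head attention (n_head_kv == n_head).
        n_head_kv = hparams.get("num_key_value_heads", n_head)
        fout.write(struct.pack("i", n_head))
        fout.write(struct.pack("i", n_head_kv))  # multi-query attention field
        # RoPE base frequency, defaulting to the historical 10000.0.
        fout.write(struct.pack("f", hparams.get("rope_theta", 10000.0)))

The same n_head_kv / n_head ratio drives the loader change in qwen_utils.cpp: k_proj and v_proj are sized {n_embd, n_embd * n_head_kv / n_head}, so a checkpoint without GQA keeps its original square projections and the KV cache shrinks only when n_head_kv < n_head.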