This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

[model]Enable qwen2 #281

Merged 10 commits on Jun 4, 2024
Changes from 3 commits
7 changes: 5 additions & 2 deletions neural_speed/convert/convert_quantized_qwen.py
@@ -65,7 +65,8 @@ def main(args_in: Optional[List[str]] = None) -> None:
f.write(struct.pack("i", hparams["hidden_size"]))
f.write(struct.pack("i", hparams["intermediate_size"])) # dummy data
f.write(struct.pack("i", hparams["num_attention_heads"]))
f.write(struct.pack("i", 0)) # multi-query attention
f.write(struct.pack("i", hparams["num_key_value_heads"] if "num_key_value_heads" in hparams
else ["num_attention_heads"])) # multi-query attention
f.write(struct.pack("i", hparams["num_hidden_layers"]))
f.write(
struct.pack(
@@ -89,7 +90,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
f.write(struct.pack("i", 0)) # n_expert_used
f.write(struct.pack("i", 0)) # n_embd_head_k for gemma
f.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
f.write(struct.pack("f", 10000.0)) # freq_base
f.write(struct.pack("f", hparams.get("rope_theta", 10000.0))) # freq_base
f.write(struct.pack("f", 1.0)) # rope_factor

f.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
@@ -186,6 +187,8 @@ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):
print(f"Success! saved as {out_path}")
elif hparams['model_type'] == 'qwen2':
# 3. write tensors
if hparams['tie_word_embeddings']:
list_vars["lm_head.weight"] = list_vars["model.embed_tokens.weight"]
convert_qwen_to_fp32_tensor("model.embed_tokens.weight", "model.embed_tokens.weight", list_vars, f)
convert_qwen_to_fp32_tensor("model.norm.weight", "model.norm.weight", list_vars, f)
convert_qwen_to_fp32_tensor("lm_head.weight", "lm_head.weight", list_vars, f)
5 changes: 3 additions & 2 deletions neural_speed/convert/convert_qwen.py
@@ -103,7 +103,8 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", hparams["hidden_size"]))
fout.write(struct.pack("i", hparams["intermediate_size"])) # dummy data
fout.write(struct.pack("i", hparams["num_attention_heads"]))
fout.write(struct.pack("i", 0)) # multi-query attention
fout.write(struct.pack("i", hparams["num_key_value_heads"] if "num_key_value_heads" in hparams
else ["num_attention_heads"])) # multi-query attention
fout.write(struct.pack("i", hparams["num_hidden_layers"]))
fout.write(
struct.pack(
@@ -128,7 +129,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0)) # n_expert_used
fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", hparams.get("rope_theta", 10000.0))) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor

fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
86 changes: 45 additions & 41 deletions neural_speed/models/qwen/qwen.cpp
@@ -102,6 +102,8 @@
const int n_vocab = hparams.n_vocab;
const int n_rot = hparams.n_rot;
const int head_dim = n_embd / n_head;
const int n_head_kv = hparams.n_head_kv;
const int n_embd_gqa = head_dim * n_head_kv;
int qwen_version = 0;
if (hparams.max_seq_len == 8192) {
qwen_version = 1;
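qwen.cpp now reads n_head_kv from the hyperparameters and uses n_embd_gqa = head_dim * n_head_kv wherever the old code assumed n_embd, i.e. one KV head per query head. The arithmetic, with illustrative numbers rather than any particular Qwen2 config:

# Grouped-query attention sizes; the numbers are examples, not a real config.
n_embd, n_head, n_head_kv = 3584, 28, 4

head_dim   = n_embd // n_head        # 128
n_embd_gqa = head_dim * n_head_kv    # 512: width of the K/V activations and cache rows
group      = n_head // n_head_kv     # 7 query heads share each KV head

print(head_dim, n_embd_gqa, group)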
@@ -132,7 +134,7 @@
attn_shape_t attn_shape = {
/* .batch_size = */ 1,
/* .head_num = */ n_head,
/* .heads_kv = */ n_head,
/* .heads_kv = */ n_head_kv,
/* .head_size = */ head_dim,
/* .sl_q = */ N, // Note: make sure that bestla reordered attn supports next token inference
/* .sl_kv = */ n_past + N,
@@ -141,7 +143,7 @@
NE_ASSERT(("bestla managed kv-cache not supported; use `--memory-f16 / --memory-f32` instead",
bestla_reordered_attn_fp32_support(&attn_shape)));
kv_shape_t kv_shape{
/* .heads_kv = */ static_cast<uint32_t>(n_head),
/* .heads_kv = */ static_cast<uint32_t>(n_head_kv),
/* .head_size = */ static_cast<uint32_t>(head_dim),
/* .sl_kv_max = */ static_cast<uint32_t>(n_ctx),
};
@@ -194,11 +196,11 @@

Kcur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
Kcur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[3], Kcur), Kcur);
Kcur = ne_reshape_3d(ctx0, Kcur, head_dim, n_head, N);
Kcur = ne_reshape_3d(ctx0, Kcur, head_dim, n_head_kv, N);

Vcur = ne_mul_mat(ctx0, model.layers[il].attn[4], cur);
Vcur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[5], Vcur), Vcur);
Vcur = ne_reshape_3d(ctx0, Vcur, head_dim, n_head, N);
Vcur = ne_reshape_3d(ctx0, Vcur, head_dim, n_head_kv, N);
}

// using mode = 2 for GPT-NeoX mode
@@ -216,29 +218,31 @@
std::vector<ne_tensor*> v_bs(batch_size);
for (int i = 0; i < batch_size; ++i) {
// batch K
Kcur_bs[i] = ne_permute(ctx0,
ne_view_4d(ctx0, Kcur, head_dim, n_head, N, 1, ne_element_size(Kcur) * head_dim,
ne_element_size(Kcur) * n_embd, ne_element_size(Kcur) * n_embd * N,
i * ne_element_size(Kcur) * n_embd * N),
0, 2, 1, 3);
Kcur_bs[i] =
ne_permute(ctx0,
ne_view_4d(ctx0, Kcur, head_dim, n_head_kv, N, 1, ne_element_size(Kcur) * head_dim,
ne_element_size(Kcur) * n_embd_gqa, ne_element_size(Kcur) * n_embd_gqa * N,
i * ne_element_size(Kcur) * n_embd_gqa * N),
0, 2, 1, 3);
k_bs[i] = ne_view_4d(
ctx0, kv_self.k, head_dim, N, n_head, 1, ne_element_size(kv_self.k) * head_dim,
ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx,
((il * n_ctx) * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block +
i * n_ctx * n_embd * ne_element_size(kv_self.k) + head_dim * n_past * ne_element_size(kv_self.k)));
ctx0, kv_self.k, head_dim, N, n_head_kv, 1, ne_element_size(kv_self.k) * head_dim,
ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd_gqa * n_ctx,
((il * n_ctx) * ne_element_size(kv_self.k) * n_embd_gqa * kv_n_ctx_block +
i * n_ctx * n_embd_gqa * ne_element_size(kv_self.k) + head_dim * n_past * ne_element_size(kv_self.k)));

// batch V
Vcur_bs[i] = ne_permute(ctx0,
ne_reshape_4d(ctx0,
ne_view_2d(ctx0, Vcur, n_embd, N, ne_element_size(Vcur) * n_embd,
i * ne_element_size(Vcur) * n_embd * N),
head_dim, n_head, N, 1),
1, 2, 0, 3);
v_bs[i] =
ne_view_4d(ctx0, kv_self.v, N, head_dim, n_head, 1, n_ctx * ne_element_size(kv_self.v),
n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd,
((il * n_ctx) * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block +
i * n_ctx * n_embd * ne_element_size(kv_self.v) + n_past * ne_element_size(kv_self.v)));
Vcur_bs[i] =
ne_permute(ctx0,
ne_reshape_4d(ctx0,
ne_view_2d(ctx0, Vcur, n_embd_gqa, N, ne_element_size(Vcur) * n_embd_gqa,
i * ne_element_size(Vcur) * n_embd_gqa * N),
head_dim, n_head_kv, N, 1),
1, 2, 0, 3);
v_bs[i] = ne_view_4d(
ctx0, kv_self.v, N, head_dim, n_head_kv, 1, n_ctx * ne_element_size(kv_self.v),
n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd_gqa,
((il * n_ctx) * ne_element_size(kv_self.v) * n_embd_gqa * kv_n_ctx_block +
i * n_ctx * n_embd_gqa * ne_element_size(kv_self.v) + n_past * ne_element_size(kv_self.v)));
ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs[i], k_bs[i]));
ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs[i], v_bs[i]));
}
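In this batched (non-flash) path, the views that copy Kcur/Vcur into kv_self now stride by n_embd_gqa instead of n_embd, since the cache stores only n_head_kv heads per layer. A sketch of the K-cache offset computed by the ne_view_4d call above, in element units rather than bytes; the function and argument names are ours, not the library's:

def k_cache_offset(il, i, n_past, n_ctx, head_dim, n_head_kv, kv_n_ctx_block):
    n_embd_gqa = head_dim * n_head_kv
    return (il * n_ctx * n_embd_gqa * kv_n_ctx_block  # skip earlier layers
            + i * n_ctx * n_embd_gqa                  # skip earlier sequences in the batch
            + n_past * head_dim)                      # append after already-cached tokens

print(k_cache_offset(il=1, i=0, n_past=16, n_ctx=512, head_dim=128, n_head_kv=4, kv_n_ctx_block=1))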
@@ -247,10 +251,10 @@
struct ne_tensor* Q = ne_permute(ctx0, ne_reshape_4d(ctx0, Qcur, head_dim, n_head, N, batch_size), 0, 2, 1, 3);

// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
struct ne_tensor* K =
ne_view_4d(ctx0, kv_self.k, head_dim, n_past + N, n_head, batch_size, ne_element_size(kv_self.k) * head_dim,
ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx,
il * n_ctx * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block);
struct ne_tensor* K = ne_view_4d(
ctx0, kv_self.k, head_dim, n_past + N, n_head_kv, batch_size, ne_element_size(kv_self.k) * head_dim,
ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd_gqa * n_ctx,
il * n_ctx * ne_element_size(kv_self.k) * n_embd_gqa * kv_n_ctx_block);

// K * Q
struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q);
@@ -267,9 +271,9 @@

// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
struct ne_tensor* V =
ne_view_4d(ctx0, kv_self.v, n_past + N, head_dim, n_head, batch_size, n_ctx * ne_element_size(kv_self.v),
n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd,
il * n_ctx * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block);
ne_view_4d(ctx0, kv_self.v, n_past + N, head_dim, n_head_kv, batch_size, n_ctx * ne_element_size(kv_self.v),
n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd_gqa,
il * n_ctx * ne_element_size(kv_self.v) * n_embd_gqa * kv_n_ctx_block);

// KQV = transpose(V) * KQ_soft_max
struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max);
@@ -286,15 +290,15 @@

// store key and value to memory
{
const auto k_cache = ne_view_3d(ctx0, kv_self.k, // tensor
head_dim, n_ctx, n_head, // ne
0, 0, // nb (bestla managed)
il * k_size); // offset
const auto k_cache = ne_view_3d(ctx0, kv_self.k, // tensor
head_dim, n_ctx, n_head_kv, // ne
0, 0, // nb (bestla managed)
il * k_size); // offset
ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past, false));
const auto v_cache = ne_view_3d(ctx0, kv_self.v, // tensor
head_dim, n_ctx, n_head, // ne
0, 0, // nb (bestla managed)
il * v_size); // offset
const auto v_cache = ne_view_3d(ctx0, kv_self.v, // tensor
head_dim, n_ctx, n_head_kv, // ne
0, 0, // nb (bestla managed)
il * v_size); // offset
ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur, n_past, false));
}

@@ -303,14 +307,14 @@

struct ne_tensor* K =
ne_view_3d(ctx0, kv_self.k, // tensor
head_dim, seq_kv, n_head, // ne
head_dim, seq_kv, n_head_kv, // ne
kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num, // nb (bestla managed)
il * k_size); // offset
*reinterpret_cast<ATTN_FWD_LAYOUT*>(&K->nb[0]) = kv_cache_info.k_layout; // us nb0 for layout
ne_set_name(K, "K");
struct ne_tensor* V =
ne_view_3d(ctx0, kv_self.v, // tensor
seq_kv, head_dim, n_head, // ne
seq_kv, head_dim, n_head_kv, // ne
kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num, // nb (bestla managed)
il * v_size); // offset
*reinterpret_cast<ATTN_FWD_LAYOUT*>(&V->nb[0]) = kv_cache_info.v_layout; // us nb0 for layout
8 changes: 7 additions & 1 deletion neural_speed/models/qwen/qwen.h
@@ -44,6 +44,12 @@ static const model_scratch qwen_mem_req(int n_layers, float scratch_size_ratio =
static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
};
case 28:
return {
static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
};
default:
MODEL_ASSERT(false);
}
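qwen_mem_req gains a case 28 so 28-layer Qwen2 checkpoints get scratch buffers instead of falling through to MODEL_ASSERT(false). For reference, the buffer arithmetic behind the new entry; treating MB as 2^20 bytes and a ratio of 1.0 are assumptions about the surrounding code:

MB = 1024 * 1024                 # assumed granularity
scratch_size_ratio = 1.0         # assumed default ratio
sizes = [int(scratch_size_ratio * n) * MB for n in (4096, 2048, 4096)]
print([s // MB for s in sizes])  # [4096, 2048, 4096]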
@@ -53,7 +59,7 @@ class QWEN : public IModel {
private:
model_archs arch = MODEL_QWEN;
std::unique_ptr<model_model_loader> ml;
uint32_t n_layer, n_embd, n_ff, n_vocab;
uint32_t n_layer, n_embd, n_ff, n_vocab, n_head, n_head_kv;
int n_gpu_layer;
bool use_mmap, use_mlock, vocab_only;
model_scratch scratch;
12 changes: 8 additions & 4 deletions neural_speed/models/qwen/qwen_utils.cpp
@@ -65,6 +65,8 @@ void QWEN::init(const char* path_model, model_context* ctx, int n_gpu_layer_, bo
fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size());
n_embd = hparams.n_embd;
n_head_kv = hparams.n_head_kv;
n_head = hparams.n_head;
n_vocab = hparams.n_vocab;
n_layer = hparams.n_layer;
scratch = qwen_mem_req(n_layer, lctx.scratch_size_ratio);
@@ -181,10 +183,12 @@ void QWEN::load(model_context* ctx, model_progress_callback progress_callback, v
// qkv GEMM + out proj GEMM
layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.q_proj.weight", {n_embd, n_embd}, backend);
layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.q_proj.bias", {n_embd}, backend);
layer.attn[2] = ml->get_tensor(layers_i + ".self_attn.k_proj.weight", {n_embd, n_embd}, backend);
layer.attn[3] = ml->get_tensor(layers_i + ".self_attn.k_proj.bias", {n_embd}, backend);
layer.attn[4] = ml->get_tensor(layers_i + ".self_attn.v_proj.weight", {n_embd, n_embd}, backend);
layer.attn[5] = ml->get_tensor(layers_i + ".self_attn.v_proj.bias", {n_embd}, backend);
layer.attn[2] =
ml->get_tensor(layers_i + ".self_attn.k_proj.weight", {n_embd, n_embd * n_head_kv / n_head}, backend);
layer.attn[3] = ml->get_tensor(layers_i + ".self_attn.k_proj.bias", {n_embd * n_head_kv / n_head}, backend);
layer.attn[4] =
ml->get_tensor(layers_i + ".self_attn.v_proj.weight", {n_embd, n_embd * n_head_kv / n_head}, backend);
layer.attn[5] = ml->get_tensor(layers_i + ".self_attn.v_proj.bias", {n_embd * n_head_kv / n_head}, backend);
layer.attn[6] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd, n_embd}, backend);

// ffn GEMM
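With grouped-query attention, the loader now expects k_proj/v_proj weights of width n_embd * n_head_kv / n_head (equivalently head_dim * n_head_kv) plus matching bias sizes, while q_proj and o_proj keep the full n_embd. A toy sketch of those shapes and of the per-head reshape used earlier in qwen.cpp; sizes are made up, and note that torch stores Linear weights as (out_features, in_features):

import torch

n_embd, n_head, n_head_kv, N = 64, 8, 2, 3
head_dim = n_embd // n_head
kv_dim = n_embd * n_head_kv // n_head      # == head_dim * n_head_kv

k_proj = torch.nn.Linear(n_embd, kv_dim, bias=True)
x = torch.randn(N, n_embd)
k = k_proj(x)                              # (N, kv_dim)
k = k.view(N, n_head_kv, head_dim)         # per-head layout, cf. ne_reshape_3d earlier
print(k_proj.weight.shape, k_proj.bias.shape, k.shape)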