
Commit 8f64e54: add back f32 model support
Author: luoyu-intel, committed Jun 7, 2024
Parent: 609e559
Showing 5 changed files with 19 additions and 16 deletions.
bestla/bestla/bestla.h (1 change: 0 additions & 1 deletion)

@@ -34,7 +34,6 @@ enum class BTLA_ISA : uint8_t {
   AVX512_BF16,
   AMX_FP16,
   ISA_COUNT,
-  SYCL_XVE,
 };
 enum class BTLA_DTYPE : uint32_t {
   EleBitsMask = 0xff,
bestla/bestla/ut/bestla_benchmark.cpp (8 changes: 4 additions & 4 deletions)

@@ -581,7 +581,7 @@ class UTWOQ_CompFp32 {
     }
     auto psize = (size_t)m * n * k * 2;
     int blks = k / blocksize;
-    int nbits = utils::bestla_dtype_bits(qtype);
+    size_t nbits = utils::bestla_dtype_bits(qtype);
     auto memsize = (size_t)(n * k * nbits / 8 + n * blks * sizeof(Scale_T)) + (m * k + m * n) * sizeof(float);
     tm.start();
     while (tm.stop() < timems) {

@@ -700,7 +700,7 @@ class UTWOQ_CompBf16 {
     }
     auto psize = (size_t)m * n * k * 2;
     int blks = k / blocksize;
-    int nbits = utils::bestla_dtype_bits(qtype);
+    size_t nbits = utils::bestla_dtype_bits(qtype);
     auto memsize = (size_t)(n * k * nbits / 8 + n * blks * sizeof(Scale_T)) + (m * k + m * n) * sizeof(float);
     tm.start();
     while (tm.stop() < timems) {

@@ -816,7 +816,7 @@ class UTWOQ_CompInt8 {
     quanA.assign(bufferA.data());
     auto psize = (size_t)m * n * k * 2;
     int blks = k / blocksize;
-    int nbits = utils::bestla_dtype_bits(qtype);
+    auto nbits = utils::bestla_dtype_bits(qtype);
     auto memsize = (size_t)(n * k * nbits / 8 + n * blks * sizeof(Scale_T)) + (m * k + m * n) * sizeof(float);
     if (isasym) {
       memsize += n * blks * sizeof(int8_t);

@@ -878,7 +878,7 @@ class UTWOQ_CompInt8 {
     quanA.assign(bufferA.data());
     auto psize = (size_t)m * n * k * 2;
     int blks = k / blocksize;
-    int nbits = utils::bestla_dtype_bits(qtype);
+    auto nbits = utils::bestla_dtype_bits(qtype);
     auto memsize = (size_t)(n * k * nbits / 8 + n * blks * sizeof(Scale_T)) + (m * k + m * n) * sizeof(float);
     if (isasym) {
       memsize += n * blks * sizeof(int8_t);
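A note on the int to size_t / auto widening above: once f32 models are supported again, nbits can be 32, and with a 32-bit nbits the product n * k * nbits inside the memsize expression can overflow int for realistic weight shapes. A minimal sketch of the failure mode; the shape below is illustrative, not taken from the benchmark:

#include <limits.h>
#include <stdio.h>

int main(void) {
  /* Illustrative shape: a vocab-sized 32000 x 4096 projection. */
  long long n = 32000, k = 4096, nbits = 32; /* f32 weights: 32 bits/element */
  long long bits = n * k * nbits;            /* 4,194,304,000 */
  printf("n*k*nbits = %lld, INT_MAX = %d\n", bits, INT_MAX);
  /* With `int nbits` the product is evaluated in 32-bit int and overflows
     (undefined behavior); with a 64-bit nbits the arithmetic happens in
     64 bits and the byte count n*k*nbits/8 stays correct. */
  return 0;
}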
neural_speed/core/data_types.h (1 change: 0 additions & 1 deletion)

@@ -51,7 +51,6 @@ enum ne_type {
   NE_TYPE_I16,
   NE_TYPE_I32,
   NE_TYPE_BTLA,
-  NE_TYPE_BTLA_SYCL,
   NE_TYPE_COUNT,
 };
 
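Because NE_TYPE_COUNT is a trailing sentinel, deleting NE_TYPE_BTLA_SYCL lowers it from 21 to 20, which is why the three static_asserts in ne_layers.c below change in lockstep. A minimal illustration of the sentinel-plus-static_assert convention, with demo names rather than the real enum:

#include <assert.h>

/* The trailing _COUNT enumerator always equals the number of real entries,
   so tables indexed by the enum can be size-checked at compile time. */
enum demo_type { DEMO_F32, DEMO_F16, DEMO_BTLA, DEMO_COUNT };

static const char* DEMO_NAME[DEMO_COUNT] = {
    [DEMO_F32] = "f32",
    [DEMO_F16] = "f16",
    [DEMO_BTLA] = "btla",
};

/* Removing an enumerator shifts DEMO_COUNT down by one; this assert forces
   whoever edits the enum to revisit every table that depends on it. */
static_assert(DEMO_COUNT == 3, "DEMO_NAME is outdated");

int main(void) { return DEMO_NAME[DEMO_BTLA][0] == 'b' ? 0 : 1; }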
neural_speed/core/layers/ne_bestla.cpp (8 changes: 8 additions & 0 deletions)

@@ -223,15 +223,23 @@ bool bestla_support(struct ne_tensor* node, int n_threads, size_t* workspace, si
     support = true;
   }
   switch (node->op) {
     case NE_OP_MUL_MAT_ID:
     case NE_OP_MUL_MAT_BIAS:
     case NE_OP_MUL_MAT: {
       struct ne_tensor* wei = node->src0;
+      if (node->op == NE_OP_MUL_MAT_ID) {
+        wei = node->opt[0];
+      }
       if (node->src0->type == NE_TYPE_BTLA) {
+        if (node->src0->backend == NE_BACKEND_CPU) {
         ws_h = bestla_f32f32_get_workspace_size(node->src1->ne[1], wei->ne[1], node->src1->ne[0], wei->data);
+        }
         support = true;
       }
     } break;
+    case NE_OP_ROPE:
+      if (node->type == NE_TYPE_BTLA) support = true;
+      break;
     case NE_OP_MUL:
     case NE_OP_ADD: {
       if (ne_is_contiguous(node->src1) && ne_is_contiguous(node->src0) &&
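bestla_support is a dry-run query: besides answering whether BesTLA can execute the node, it reports through the workspace out-parameter how much scratch the node will need, e.g. from bestla_f32f32_get_workspace_size above. A hedged sketch of the planning pass such a query enables; apart from that idea, every name here is a hypothetical simplification:

#include <stdbool.h>
#include <stddef.h>

/* Hypothetical, simplified node: just enough to show the dry-run pass. */
struct demo_node {
  bool supported;       /* what bestla_support would return */
  size_t scratch_bytes; /* what it would write into *workspace */
};

/* Walk the graph once before execution and keep the largest scratch
   requirement, so one shared buffer can serve every node in turn. */
static size_t plan_workspace(const struct demo_node* nodes, int count) {
  size_t max_ws = 0;
  for (int i = 0; i < count; ++i) {
    if (!nodes[i].supported) continue; /* node falls back to the default path */
    if (nodes[i].scratch_bytes > max_ws) max_ws = nodes[i].scratch_bytes;
  }
  return max_ws;
}

int main(void) {
  struct demo_node graph[] = {{true, 1u << 20}, {false, 0}, {true, 4u << 20}};
  return plan_workspace(graph, 3) == (4u << 20) ? 0 : 1;
}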
neural_speed/core/ne_layers.c (17 changes: 7 additions & 10 deletions)

@@ -344,7 +344,7 @@ static const int NE_BLCK_SIZE[NE_TYPE_COUNT] = {
     [NE_TYPE_Q6_K] = QK_K, [NE_TYPE_Q8_K] = QK_K, [NE_TYPE_I8] = 1, [NE_TYPE_I16] = 1,
     [NE_TYPE_I32] = 1,
 };
-static_assert(NE_TYPE_COUNT == 21, "NE_BLCK_SIZE is outdated");
+static_assert(NE_TYPE_COUNT == 20, "NE_BLCK_SIZE is outdated");
 
 static const size_t NE_TYPE_SIZE[NE_TYPE_COUNT] = {
     [NE_TYPE_F32] = sizeof(float), [NE_TYPE_F16] = sizeof(ne_fp16_t), [NE_TYPE_Q4_0] = sizeof(block_q4_0),

@@ -353,22 +353,22 @@ static const size_t NE_TYPE_SIZE[NE_TYPE_COUNT] = {
     [NE_TYPE_Q8_K] = sizeof(block_q8_K), [NE_TYPE_I8] = sizeof(int8_t), [NE_TYPE_I16] = sizeof(int16_t),
     [NE_TYPE_I32] = sizeof(int32_t),
 };
-static_assert(NE_TYPE_COUNT == 21, "NE_TYPE_SIZE is outdated");
+static_assert(NE_TYPE_COUNT == 20, "NE_TYPE_SIZE is outdated");
 
 static const char* NE_TYPE_NAME[NE_TYPE_COUNT] = {
     [NE_TYPE_F32] = "f32", [NE_TYPE_F16] = "f16", [NE_TYPE_Q4_0] = "q4_0", [NE_TYPE_Q4_1] = "q4_1",
     [NE_TYPE_Q5_0] = "q5_0", [NE_TYPE_Q5_1] = "q5_1", [NE_TYPE_Q8_0] = "q8_0", [NE_TYPE_Q8_1] = "q8_1",
     [NE_TYPE_Q6_K] = "q6_k", [NE_TYPE_Q8_K] = "q8_k", [NE_TYPE_I8] = "i8", [NE_TYPE_I16] = "i16",
     [NE_TYPE_I32] = "i32",
 };
-static_assert(NE_TYPE_COUNT == 21, "NE_TYPE_NAME is outdated");
+static_assert(NE_TYPE_COUNT == 20, "NE_TYPE_NAME is outdated");
 
 static bool NE_IS_QUANTIZED[NE_TYPE_COUNT] = {
     [NE_TYPE_F32] = false, [NE_TYPE_F16] = false, [NE_TYPE_Q4_0] = true, [NE_TYPE_Q4_1] = true, [NE_TYPE_Q5_0] = true,
     [NE_TYPE_Q5_1] = true, [NE_TYPE_Q8_0] = true, [NE_TYPE_Q8_1] = true, [NE_TYPE_Q6_K] = true, [NE_TYPE_Q6_K] = true,
     [NE_TYPE_I8] = false, [NE_TYPE_I16] = false, [NE_TYPE_I32] = false, [NE_TYPE_BTLA] = true,
 };
-static_assert(NE_TYPE_COUNT == 21, "NE_IS_QUANTIZED is outdated");
+static_assert(NE_TYPE_COUNT == 20, "NE_IS_QUANTIZED is outdated");
 
 static const char* NE_OP_LABEL[NE_OP_COUNT] = {"NONE",
 
@@ -11691,7 +11691,6 @@ bool ne_support(struct ne_tensor* node, int n_threads, size_t* workspace, size_t
       node->n_tasks = n_threads;
       support = true;
     } break;
-    case NE_OP_MUL_MAT_BIAS:
     case NE_OP_MUL_MAT_ID:
     case NE_OP_CONV_1D:
     case NE_OP_MUL_MAT: {
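NE_OP_MUL_MAT_BIAS drops out of the generic ne_support switch here; within this commit its only remaining checker is the bestla_support switch in ne_bestla.cpp above, which claims it when the weights are BesTLA-packed. A hypothetical, much-simplified mirror of that two-level routing, with demo names throughout:

#include <stdbool.h>

enum demo_op { OP_MUL_MAT, OP_MUL_MAT_BIAS, OP_SOFT_MAX };

/* The BesTLA checker claims ops it can run on packed weights. */
static bool demo_bestla_support(enum demo_op op, bool btla_weight) {
  return (op == OP_MUL_MAT || op == OP_MUL_MAT_BIAS) && btla_weight;
}

/* The generic checker keeps everything else. */
static bool demo_ne_support(enum demo_op op) {
  switch (op) {
    case OP_MUL_MAT:   /* still has f32 and quantized fallbacks */
    case OP_SOFT_MAX:
      return true;
    default:
      return false;    /* OP_MUL_MAT_BIAS is no longer listed here */
  }
}

int main(void) {
  return (demo_bestla_support(OP_MUL_MAT_BIAS, true) ||
          demo_ne_support(OP_MUL_MAT_BIAS)) ? 0 : 1;
}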
@@ -11713,6 +11712,8 @@ bool ne_support(struct ne_tensor* node, int n_threads, size_t* workspace, size_t
       } else if (ne_is_quantized(wei->type) && node->src1->type == NE_TYPE_F32) {
         const enum ne_type type_q = quantize_fns[wei->type].vec_dot_type;
         ws_h = NE_TYPE_SIZE[type_q] * ne_nelements(node->src1) / NE_BLCK_SIZE[type_q];
+      } else if (wei->type == NE_TYPE_F32 && node->src1->type == NE_TYPE_F32) {
+        ws_h = 0;
       } else {
         NE_ASSERT(false);
       }
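This branch is the core of the "add back f32" fix: an f32 weight times f32 activation multiplies directly and needs no scratch, while a quantized weight needs room to requantize src1 into its vec_dot_type. A rough worked example of the two branches; the q8_0 block layout below is the usual llama.cpp-style convention, assumed rather than taken from this diff:

#include <stdint.h>
#include <stdio.h>

/* Assumed llama.cpp-style q8_0 block: one fp16 scale plus 32 int8 values. */
#define QK8_0 32
struct block_q8_0 {
  uint16_t d;        /* scale, stored as fp16 bits */
  int8_t qs[QK8_0];
};

int main(void) {
  long long src1_elems = 4096LL * 4096; /* illustrative activation count */
  /* Quantized-weight branch: scratch for src1 requantized into q8_0 blocks. */
  long long ws_quant = (long long)sizeof(struct block_q8_0) * src1_elems / QK8_0;
  /* f32-weight branch restored by this commit: no requantization, no scratch. */
  long long ws_f32 = 0;
  printf("quantized: %lld bytes, f32: %lld bytes\n", ws_quant, ws_f32);
  return 0;
}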
@@ -11739,11 +11740,7 @@
     case NE_OP_DIAG_MASK_INF:
     case NE_OP_PADDING_MASK_INF:
     case NE_OP_ROPE:
-      // only first token use parallel
-      if (node->type == NE_TYPE_BTLA)
-        node->n_tasks = 1;
-      else
-        node->n_tasks = n_threads;
+      node->n_tasks = n_threads;
       support = true;
       break;
     case NE_OP_SOFT_MAX: {
