diff --git a/tests/integration/gemm/fp32/common.hpp b/tests/integration/gemm/fp32/common.hpp index 7ce5098fc..67303c398 100644 --- a/tests/integration/gemm/fp32/common.hpp +++ b/tests/integration/gemm/fp32/common.hpp @@ -97,7 +97,7 @@ class Test3 : public TestBase { static constexpr size_t mat_m = 16; static constexpr size_t mat_n = 64; static constexpr size_t mat_k = 32; - static constexpr size_t wg_m = 8; + static constexpr size_t wg_m = 16; static constexpr size_t wg_n = 64; static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 64; @@ -205,7 +205,7 @@ class Test8 : public TestBase { static constexpr uint32_t global_kslicing = 2; static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::col_major; + static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = float; using data_type_b = float; using data_type_c = float; @@ -227,7 +227,6 @@ class Test9 : public TestBase { static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - static constexpr mma_engine engine = mma_engine::xmx; using data_type_a = float; using data_type_b = float; using data_type_c = float; @@ -245,10 +244,10 @@ class Test10 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 8; - static constexpr uint32_t global_kslicing = 2; + static constexpr uint32_t global_kslicing = 1; static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::col_major; + static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = float; using data_type_b = float; using data_type_c = float; @@ -258,9 +257,9 @@ class Test10 : public TestBase { class Test11 : public TestBase { public: static constexpr size_t batch_size = 35; - static constexpr size_t mat_m = 4192; - static constexpr size_t mat_k = 1136; - static constexpr size_t mat_n = 688; + static constexpr size_t mat_m = 4193; + static constexpr size_t mat_k = 1134; + static constexpr size_t mat_n = 686; static constexpr size_t wg_m = 256; static constexpr size_t wg_n = 256; static constexpr size_t sg_m = 32; @@ -314,4 +313,4 @@ class result_validate { Test::layout_a, Test::layout_b); } -}; +}; \ No newline at end of file diff --git a/tests/integration/gemm/fp32/main.cpp b/tests/integration/gemm/fp32/main.cpp index 3a9badc9f..a05393399 100644 --- a/tests/integration/gemm/fp32/main.cpp +++ b/tests/integration/gemm/fp32/main.cpp @@ -34,15 +34,15 @@ TYPED_TEST_P(fp32_gemm_test, esimd) { REGISTER_TYPED_TEST_SUITE_P(fp32_gemm_test, esimd); using tests = ::testing::Types< - Test1, - Test2, - Test3, - Test4, - Test5, - Test6, - Test7, - Test8, - Test9, + // Test1, + // Test2, + // Test3, + // Test4, + // Test5, + // Test6, + // Test7, + // Test8, + // Test9, Test10, Test11>; INSTANTIATE_TYPED_TEST_SUITE_P(fp32_gemm_test_suite, fp32_gemm_test, tests); diff --git a/tests/integration/gemm/int4_dequantization_bias/CMakeLists.txt b/tests/integration/gemm/int4_dequantization_bias/CMakeLists.txt index 4bdf0a42b..6a3a155f3 100644 --- a/tests/integration/gemm/int4_dequantization_bias/CMakeLists.txt +++ b/tests/integration/gemm/int4_dequantization_bias/CMakeLists.txt @@ -1,11 +1,6 @@ get_filename_component(ProjectId ${CMAKE_CURRENT_SOURCE_DIR} NAME) string(REPLACE " " "_" ProjectId ${ProjectId}) -set(ProjectIdClient ${ProjectId}) -set(ProjectIdXe ${ProjectId}) -string(PREPEND ProjectIdClient "gemm_client_") -string(PREPEND ProjectIdXe "gemm_xe_") +string(PREPEND ProjectIdClient "gemm_") -FILE(GLOB src_client main_client.cpp) -add_integration_test(${ProjectIdClient} ${src_client}) -FILE(GLOB src_xe main_xe.cpp) -add_integration_test(${ProjectIdXe} ${src_xe}) +FILE(GLOB src main.cpp) +add_integration_test(${ProjectId} ${src}) diff --git a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp b/tests/integration/gemm/int4_dequantization_bias/main.cpp similarity index 100% rename from tests/integration/gemm/int4_dequantization_bias/main_client.cpp rename to tests/integration/gemm/int4_dequantization_bias/main.cpp diff --git a/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp b/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp deleted file mode 100644 index 0cc9d8d6f..000000000 --- a/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp +++ /dev/null @@ -1,641 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2022-2023 Intel Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - *******************************************************************************/ - -#include -#include "xetla.hpp" -// #define UT_DEBUG 1 -using namespace gpu::xetla; -// The number of times the kernel is executed -constexpr int ITER = 100; - -class test1 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 16384; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class test2 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 22016; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 128; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 128; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 4; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv1 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 8; - static constexpr size_t mat_n = 12288; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv2 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 8; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv3 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 8; - static constexpr size_t mat_n = 11008; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv4 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 11008; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 32; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 4; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv5 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 8; - static constexpr size_t mat_n = 151936; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv6 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 12288; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv7 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv8 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 11008; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv9 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 11008; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 4; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv10 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 151936; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; - -template < - typename data_type_a, - typename data_type_b, - typename data_type_c, - typename data_type_acc = float, - typename data_type_bias = data_type_a> -int gemm_result_validate( - data_type_a* A, - data_type_b* B, - data_type_c* C, - data_type_bias* bias, - uint32_t m, - uint32_t k, - uint32_t n, - mem_layout mem_layout_a_ = mem_layout::row_major, - mem_layout mem_layout_b_ = mem_layout::row_major) { - buff_cmp::buff_vals data(C, m, n, n); - std::vector gold_C(m * n, 0); - get_gemm_gold( - m, n, k, mem_layout_a_, mem_layout_b_, A, B, gold_C.data()); - - // BiasAdd - for (uint32_t i = 0; i < gold_C.size(); ++i) { - uint32_t col = i % n; - gold_C[i] += bias[col]; - } - - buff_cmp::buff_vals other(gold_C.data(), m, n, n); - - bool result = buff_cmp::xetla_buff_cmp(data, other, "gemm validation"); - - std::cout << (!result ? "FAILED\n" : "PASSED\n"); - return result ? 0 : 1; -} - -template -void dequantize_gemm_run(int iter) { - using namespace gpu; - // Accept incoming parameters - constexpr size_t matrix_m = Test::mat_m; - constexpr size_t matrix_n = Test::mat_n; - constexpr size_t matrix_k = Test::mat_k; - constexpr uint32_t global_kslicing = Test::global_kslicing; - constexpr uint32_t local_kslicing = Test::local_kslicing; - static constexpr mem_layout layout_b = Test::layout_b; - constexpr size_t wg_tile_m = Test::wg_m; - constexpr size_t wg_tile_n = Test::wg_n; - constexpr size_t sg_tile_m = Test::sg_m; - constexpr size_t sg_tile_n = Test::sg_n; - constexpr size_t sg_tile_k = Test::sg_k; - constexpr size_t dequant_s = Test::dequant_s; - using data_type_a = typename Test::data_type_a; - using data_type_b = typename Test::data_type_b; - using data_type_c = typename Test::data_type_c; - using data_type_zero_pt = int4x2; - using data_type_scale = fp16; - using data_type_acc_in = fp16; - using data_type_acc = float; - using data_type_bias = fp16; - - constexpr size_t size_a = matrix_m * matrix_k; - constexpr size_t size_b = matrix_k * matrix_n / 2; - - constexpr size_t size_scale_m = matrix_k / dequant_s; - constexpr size_t size_scale_n = matrix_n; - constexpr size_t size_scale = size_scale_m * size_scale_n; - - constexpr size_t size_zero_pt_m = matrix_k / dequant_s; - constexpr size_t size_zero_pt_n = matrix_n / 2; - constexpr size_t size_zero_pt = size_zero_pt_m * size_zero_pt_n; - - constexpr size_t size_c = matrix_m * matrix_n; - constexpr size_t size_bias = matrix_n; - - // Turn on the enable_profiling property to facilitate subsequent profiling - sycl::property_list properties{sycl::property::queue::enable_profiling()}; - auto queue = sycl::queue(properties); - auto context = queue.get_info(); - auto device = queue.get_info(); - - std::cout << "Running on " << device.get_info() << "\n"; - - using tile_shape = - xetla::group::tile_shape_t; - static constexpr uint32_t periodic_sync_interval = 0; - static constexpr uint32_t prefetch_distance = 0; - - using mem_desc_a_t = xetla::mem_desc_t< - data_type_a, - mem_layout::row_major, - mem_space::global, - DEVICE_MEM_ALIGNMENT / sizeof(data_type_a)>; - using mem_desc_b_t = xetla::mem_desc_t< - data_type_b, - layout_b, - mem_space::global, - DEVICE_MEM_ALIGNMENT / sizeof(data_type_b)>; - using mem_desc_c_t = xetla::mem_desc_t< - data_type_c, - mem_layout::row_major, - mem_space::global, - DEVICE_MEM_ALIGNMENT / sizeof(data_type_c)>; - - using mem_desc_bias_t = xetla::mem_desc_t< - data_type_bias, - mem_layout::row_major, - mem_space::global, - DEVICE_MEM_ALIGNMENT / sizeof(data_type_bias)>; - - using compute_attr = xetla::group:: - compute_attr_t; - using perf_tuning_knob = xetla::group:: - perf_tuning_knob_t; - static constexpr quant_info quant_info{ - quant_mode::I4_SYM, Test::dequant_s, layout_b}; - - using compute_policy = xetla::group::compute_policy_int4_dequantize< - compute_attr, - perf_tuning_knob, - data_type_scale, - data_type_zero_pt, - quant_info, - mma_engine::xmx, - gpu_arch::XeHpc>; - - using gemm_t = xetla::group:: - gemm_t; - - using bias_op_t = - gpu::xetla::subgroup::bias_add_op_t; - using tile_op_t = gpu::xetla::subgroup::chained_tile_op_t; - - using epilogue_t = xetla::group::epilogue_t< - xetla::group::epilogue_policy_tile_op, - tile_shape, - mem_desc_c_t>; - - using group_swizzle = xetla::kernel::group_swizzle_default; - using gemm_op_t = xetla::kernel::gemm_universal_t< - gpu::xetla::kernel::dispatch_policy_int4_dequantize_kslicing< - group_swizzle, - global_kslicing, - local_kslicing>, - gemm_t, - epilogue_t>; - - size_t size_acc = gemm_op_t::get_acc_buf_size(matrix_m, matrix_n); - size_t size_cnt = gemm_op_t::get_cnt_buf_size(matrix_m, matrix_n); - - // Define and initialize the data required for the calculation - auto* A_h = static_cast( - malloc_host(size_a * sizeof(data_type_a), context)); - auto* B_h = static_cast( - malloc_host(size_b * sizeof(data_type_b), context)); - auto* C_h = static_cast( - malloc_host(size_c * sizeof(data_type_c), context)); - auto* Acc_h = static_cast( - malloc_host(size_acc * sizeof(data_type_acc), context)); - auto* Cnt_h = - static_cast(malloc_host(size_cnt * sizeof(uint32_t), context)); - auto* scale_h = static_cast( - malloc_host(size_scale * sizeof(data_type_scale), context)); - auto* zero_pt_h = static_cast( - malloc_host(size_zero_pt * sizeof(data_type_zero_pt), context)); - auto* bias_h = static_cast( - malloc_host(size_bias * sizeof(data_type_bias), context)); - - auto* A_d = static_cast(aligned_alloc_device( - DEVICE_MEM_ALIGNMENT, size_a * sizeof(data_type_a), device, context)); - auto* B_d = static_cast(aligned_alloc_device( - DEVICE_MEM_ALIGNMENT, size_b * sizeof(data_type_b), device, context)); - auto* C_d = static_cast(aligned_alloc_device( - DEVICE_MEM_ALIGNMENT, size_c * sizeof(data_type_c), device, context)); - auto* Acc_d = static_cast(aligned_alloc_device( - DEVICE_MEM_ALIGNMENT, size_acc * sizeof(data_type_acc), device, context)); - auto* Cnt_d = static_cast(aligned_alloc_device( - DEVICE_MEM_ALIGNMENT, size_cnt * sizeof(uint32_t), device, context)); - auto* scale_d = static_cast(aligned_alloc_device( - DEVICE_MEM_ALIGNMENT, - size_scale * sizeof(data_type_scale), - device, - context)); - auto* zero_pt_d = static_cast(aligned_alloc_device( - DEVICE_MEM_ALIGNMENT, - size_zero_pt * sizeof(data_type_zero_pt), - device, - context)); - auto* bias_d = static_cast(aligned_alloc_device( - DEVICE_MEM_ALIGNMENT, - size_bias * sizeof(data_type_bias), - device, - context)); - - for (unsigned i = 0; i < size_a; ++i) { - A_h[i] = random_float(); -#ifdef UT_DEBUG - A_h[i] = 1.f; -#endif - } - for (unsigned i = 0; i < size_b; ++i) { - B_h[i] = uint8_t(random_uint8()); -#ifdef UT_DEBUG - B_h[i] = 153; -#endif - } - for (unsigned i = 0; i < size_scale; ++i) { - scale_h[i] = random_float(); -#ifdef UT_DEBUG - scale_h[i] = 1.f; -#endif - } - for (unsigned i = 0; i < size_zero_pt; ++i) { - zero_pt_h[i] = 0.f; - } - for (unsigned i = 0; i < size_c; ++i) { - C_h[i] = 0; - } - for (unsigned i = 0; i < size_acc; ++i) { - Acc_h[i] = 0; - } - for (unsigned i = 0; i < size_cnt; ++i) { - Cnt_h[i] = 0; - } - for (unsigned i = 0; i < size_bias; ++i) { - bias_h[i] = random_float(); -#ifdef UT_DEBUG - bias_h[i] = 0.f; -#endif - } - - queue.memcpy((void*)A_d, (void*)A_h, size_a * sizeof(data_type_a)).wait(); - queue.memcpy((void*)B_d, (void*)B_h, size_b * sizeof(data_type_b)).wait(); - queue.memcpy((void*)C_d, (void*)C_h, size_c * sizeof(data_type_c)).wait(); - queue.memcpy((void*)Acc_d, (void*)Acc_h, size_acc * sizeof(data_type_acc)) - .wait(); - queue.memcpy((void*)Cnt_d, (void*)Cnt_h, size_cnt * sizeof(uint32_t)).wait(); - queue - .memcpy( - (void*)scale_d, (void*)scale_h, size_scale * sizeof(data_type_scale)) - .wait(); - queue - .memcpy( - (void*)zero_pt_d, - (void*)zero_pt_h, - size_zero_pt * sizeof(data_type_zero_pt)) - .wait(); - queue.memcpy((void*)bias_d, (void*)bias_h, size_bias * sizeof(data_type_bias)) - .wait(); - - // set up gemm arguments - typename bias_op_t::shape_t bias_add_shape(matrix_n, 1, matrix_n); - using epilogue_args_t = epilogue_t::arguments_t; - - epilogue_args_t epilogue_args( - {// epilogue_args init list - // It accepts the base pointer to matrix D, and its dimensions - {bias_d, bias_add_shape}}); - - typename gemm_op_t::template arguments_t gemm_arg( - matrix_m, - matrix_k, - matrix_n, - A_d, - matrix_k, - B_d, - matrix_n, - C_d, - matrix_n, - scale_d, - matrix_n, - Acc_d, - Cnt_d, - epilogue_args); - - cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(gemm_arg); - if (!gemm_op_t::can_implement(gemm_arg)) { - std::cout << "The arguments cannot be supported, aborting ... " - << std::endl; - FAIL(); - } - - size_t ops = 2 * matrix_m * matrix_n * matrix_k + matrix_m * matrix_n; - profiling_helper prof("dequantize_gemm", ops, "gflops"); - try { - for (int i = 0; i < iter; i++) { - prof.cpu_start(); - auto e_esimd = queue.submit([&](handler& cgh) { - cgh.parallel_for( - nd_range, [=](nd_item<3> item) SYCL_ESIMD_KERNEL { - // allocate slm and nbarrier resource - slm_barrier_init(); - gemm_op_t gemm_op; - gemm_op(item, gemm_arg); - }); - }); - e_esimd.wait(); - prof.cpu_end(); - prof.add_gpu_event(e_esimd); - } - } catch (cl::sycl::exception const& e) { - std::cout << "SYCL exception caught: " << e.what() << '\n'; - FAIL(); - } - - // performance - prof.print_profiling_result(profiling_selector::GPU); - - std::vector dequantize_b(matrix_k * matrix_n, 0); - for (uint32_t i = 0; i < matrix_k / dequant_s; i++) { - for (uint32_t j = 0; j < matrix_n / 2; j++) { - int start_in = i * dequant_s * matrix_n / 2 + j; - int start_out = i * dequant_s * matrix_n + j * 2; - int start_scale = i * size_scale_n + j * 2; - for (uint32_t ii = 0; ii < dequant_s; ii++) { - uint8_t data_in = B_h[start_in + ii * matrix_n / 2]; - int8_t data_0 = int8_t(data_in & 0x0f) - 8; - int8_t data_1 = int8_t(data_in >> 4) - 8; - dequantize_b[start_out + ii * matrix_n] = - fp16(data_0) * scale_h[start_scale]; - dequantize_b[start_out + ii * matrix_n + 1] = - fp16(data_1) * scale_h[start_scale + 1]; - } - } - } - - queue.memcpy((void*)C_h, (void*)C_d, size_c * sizeof(data_type_c)).wait(); - ASSERT_EQ( - 0, - gemm_result_validate( - A_h, dequantize_b.data(), C_h, bias_h, matrix_m, matrix_k, matrix_n)); - - free(A_h, context); - free(B_h, context); - free(C_h, context); - free(scale_h, context); - free(zero_pt_h, context); - free(A_d, context); - free(B_d, context); - free(C_d, context); - free(scale_d, context); - free(zero_pt_d, context); - free(Acc_h, context); - free(Cnt_h, context); - free(Acc_d, context); - free(Cnt_d, context); -} - -template -class dequantize_gemm_test : public ::testing::Test {}; -TYPED_TEST_SUITE_P(dequantize_gemm_test); - -TYPED_TEST_P(dequantize_gemm_test, esimd) { - dequantize_gemm_run(ITER); -} - -REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_test, esimd); -using tests = ::testing::Types; -// using tests = ::testing::Types; - -INSTANTIATE_TYPED_TEST_SUITE_P( - dequantize_gemm_test_suite, - dequantize_gemm_test, - tests);