diff --git a/CMakeLists.txt b/CMakeLists.txt index 765bbab88..1c07ef0fc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,6 +8,8 @@ if (NOT CMAKE_BUILD_TYPE) set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") endif() if(UNIX) + set(CMAKE_C_COMPILER icx) + set(CMAKE_CXX_COMPILER icpx) else() # Windows # Force CMake to use icx-cl rather than the default C++ compiler/linker # (needed on Windows only) @@ -24,7 +26,7 @@ include(CTest) enable_testing() if(UNIX) -list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/tools/cmake") + list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/tools/cmake") endif() find_package(MKL CONFIG REQUIRED) message(STATUS "MKL_VERSION=${MKL_VERSION}") @@ -33,7 +35,7 @@ message(STATUS "MKL_IMPORTED_TARGETS=${MKL_IMPORTED_TARGETS}") # debug option message(STATUS "'DEBUG' is set to " ${DEBUG}) if (${DEBUG}) -add_compile_options(-debug=minimal -Rno-debug-disables-optimization -DDEBUG=${DEBUG}) + add_compile_options(-debug=minimal -Rno-debug-disables-optimization -DDEBUG=${DEBUG}) endif () # log message print @@ -43,20 +45,41 @@ if (${LOG} STREQUAL "on") add_definitions(-DLOG_PRINT) endif () +# For large registers mode, enable 256 registers for kernels +set(XETLA_OFFLINE_OPTIONS "-doubleGRF") +set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -vc-disable-indvars-opt") +set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -vc-codegen") +# Enable bank conflict reduction. +set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -enableBCR") +# Optimization to reduce the tokens used for DPAS instruction. +set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -DPASTokenReduction") + # AOT device -set(AOT_DEVICE "" CACHE STRING "Set device list for AOT build") +set(USE_AOT_DEVLIST "" CACHE STRING "Set device list for AOT build") +if (USE_AOT_DEVLIST) + add_compile_options(-fsycl-targets=spir64_gen) + add_link_options(-fsycl-targets=spir64_gen) + # For registers usage verbose at AOT + set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -printregusage") + set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -Xs "-options '${XETLA_OFFLINE_OPTIONS}' -device '${USE_AOT_DEVLIST}'") +else() + set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -Xs "${XETLA_OFFLINE_OPTIONS}") +endif() + +add_compile_options(-fsycl -fsycl-device-code-split=per_kernel) +add_compile_options(-Wall -Wextra -Werror) + +include(ProcessorCount) +ProcessorCount(nproc) +add_link_options(-fsycl -fsycl-device-code-split=per_kernel -fsycl-max-parallel-link-jobs=${nproc}) +add_link_options(${XETLA_KERNEL_FLAGS}) -add_compile_options(-fsycl) -add_link_options(-fsycl) if(UNIX) - if (AOT_DEVICE) - add_compile_options(-fsycl-targets=spir64_gen) - add_link_options(-fsycl-targets=spir64_gen -Xs "-device ${AOT_DEVICE}") # MTL - endif() - add_compile_options(-fp-model=precise -Wall -Wextra -Werror) + add_compile_options(-fp-model=precise) add_link_options(-lmkl_intel_lp64 -lmkl_sequential -lmkl_core -lpthread -lm) link_libraries(-lgtest -lgtest_main) else() # Windows + add_compile_options(/fp:precise) add_compile_options(/EHsc) if (CMAKE_BUILD_TYPE STREQUAL "Debug") add_compile_options(/MDd) diff --git a/examples/11_stream_k_gemm/CMakeLists.txt b/examples/11_stream_k_gemm/CMakeLists.txt index 1e09d5bef..88b3527a4 100644 --- a/examples/11_stream_k_gemm/CMakeLists.txt +++ b/examples/11_stream_k_gemm/CMakeLists.txt @@ -1,25 +1,6 @@ set(TARGET stream_k_gemm) -set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -fsycl) -set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} 
-fsycl-targets=spir64_gen) - -# disable loop invariance optimization, this is for performance -set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -vc-disable-indvars-opt") -# For large registers mode, enable 256 registers for kernels -set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -doubleGRF") -# For registers usage verbose at AOT -set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -printregusage") -# Enable bank conflict reduction. -set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -enableBCR") -# Optimization to reduce the tokens used for DPAS instruction. -set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -DPASTokenReduction") - -set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -Xs) -set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} "-device pvc -options '${XETLA_OFFLINE_OPTIONS} ' ") - #build test add_executable(${TARGET} stream_k_gemm.cpp) -target_link_options(${TARGET} PRIVATE ${XETLA_KERNEL_FLAGS}) # Disable vector combine, to remove redundant loads and stores -#target_compile_options(${TARGET} PRIVATE -mllvm -disable-vector-combine -fsycl -fsycl-targets=spir64_gen) - +# target_compile_options(${TARGET} PRIVATE -mllvm -disable-vector-combine -fsycl -fsycl-targets=spir64_gen) diff --git a/include/common/core/arch_config.hpp b/include/common/core/arch_config.hpp index 1b302b91a..8c7c56463 100644 --- a/include/common/core/arch_config.hpp +++ b/include/common/core/arch_config.hpp @@ -89,7 +89,7 @@ struct load_store_attr_t template inline constexpr bool arch_has_2d_load_store = - load_store_attr_t::has_hw_block_2d; + load_store_attr_t::has_hw_block_2d; template struct load_store_attr_t { @@ -149,9 +149,19 @@ struct register_nums_t { }; template -struct register_bytes_t { +struct register_bytes_t; +template <> +struct register_bytes_t { static constexpr uint32_t reg_in_bytes = 64; }; +template <> +struct register_bytes_t { + static constexpr uint32_t reg_in_bytes = 32; +}; +template <> +struct register_bytes_t { + static constexpr uint32_t reg_in_bytes = 32; +}; template struct register_attr_t { @@ -188,41 +198,47 @@ struct mma_attr_t>> { template struct arch_attr_t {}; -template -struct client_arch_attr_base_t { +template <> +struct arch_attr_t { template - using load_store_attr = load_store_attr_t; + using load_store_attr = load_store_attr_t; - template - using register_attr = register_attr_t; + template + using register_attr = register_attr_t; - using dpas_attr = dpas_attr_t; + using dpas_attr = dpas_attr_t; static constexpr uint32_t max_wg_num = 64; - static constexpr uint32_t local_mem_size = 64 * 1024; + static constexpr uint32_t local_mem_size = 128 * 1024; }; template <> -struct arch_attr_t { +struct arch_attr_t { template - using load_store_attr = load_store_attr_t; + using load_store_attr = load_store_attr_t; template - using register_attr = register_attr_t; + using register_attr = register_attr_t; - using dpas_attr = dpas_attr_t; + using dpas_attr = dpas_attr_t; static constexpr uint32_t max_wg_num = 64; - static constexpr uint32_t local_mem_size = 128 * 1024; + static constexpr uint32_t local_mem_size = 64 * 1024; }; template <> -struct arch_attr_t - : public client_arch_attr_base_t {}; +struct arch_attr_t { + template + using load_store_attr = load_store_attr_t; -template <> -struct arch_attr_t - : public client_arch_attr_base_t {}; + template + using register_attr = register_attr_t; + + using dpas_attr = dpas_attr_t; + + static constexpr uint32_t max_wg_num = 64; + static constexpr uint32_t local_mem_size = 64 * 1024; +}; /// @} 
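
The arch_config.hpp hunk above replaces the shared client_arch_attr_base_t with explicit per-architecture specializations: register_bytes_t now reports a 64-byte register for one target and 32 bytes for the two client targets, and the 128 KB / 64 KB local_mem_size values move into the matching arch_attr_t specializations (the exact arch names are hard to read here because the angle-bracketed template arguments were lost in this copy of the patch). A minimal standalone sketch of the same trait-specialization pattern, with names of my own and assuming XeHpc is the 64-byte part, would look like this:

    #include <cstdint>
    #include <cstdio>

    // Self-contained analogue of the per-arch traits in arch_config.hpp.
    enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 };

    // Primary template left undefined: an unsupported arch fails at compile time,
    // which mirrors replacing the generic base with explicit specializations.
    template <gpu_arch arch>
    struct register_bytes_t;

    template <>
    struct register_bytes_t<gpu_arch::XeHpc> {
      static constexpr uint32_t reg_in_bytes = 64; // assumed: 64-byte GRF
    };
    template <>
    struct register_bytes_t<gpu_arch::XeHpg> {
      static constexpr uint32_t reg_in_bytes = 32;
    };
    template <>
    struct register_bytes_t<gpu_arch::XeLpg> {
      static constexpr uint32_t reg_in_bytes = 32;
    };

    // A derived compile-time quantity: how many elements of T fit in one register.
    template <typename T, gpu_arch arch>
    inline constexpr uint32_t elems_per_reg =
        register_bytes_t<arch>::reg_in_bytes / sizeof(T);

    static_assert(elems_per_reg<float, gpu_arch::XeHpc> == 16);
    static_assert(elems_per_reg<float, gpu_arch::XeHpg> == 8);

    int main() {
      std::printf("XeHpc: %u floats per register\n",
                  elems_per_reg<float, gpu_arch::XeHpc>);
      std::printf("XeHpg: %u floats per register\n",
                  elems_per_reg<float, gpu_arch::XeHpg>);
    }

Compile-time sizes such as compute_policy's block_bytes_x_b (further down in this patch) are read off exactly this kind of trait, which is why an undefined primary template is preferable to a silent default.
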
xetla_core_arch_config diff --git a/include/common/core/base_types.hpp b/include/common/core/base_types.hpp index 0b52e8a34..33ed26b74 100644 --- a/include/common/core/base_types.hpp +++ b/include/common/core/base_types.hpp @@ -55,6 +55,28 @@ using fp16 = sycl::half; /// using tf32 = sycl::ext::intel::experimental::esimd::tfloat32; +/// @brief mx_fp4(E2M1) data packed as 8bits data type. +struct mx_fp4 { + uint8_t data; + operator uint8_t() const { + return data; + } + mx_fp4() = default; + mx_fp4(uint8_t val) { + data = val; + } +}; + +template +struct get_packed_num { + static constexpr uint32_t value = 1; +}; + +template <> +struct get_packed_num { + static constexpr uint32_t value = 2; +}; + template struct is_host_callable : std::false_type {}; template @@ -66,7 +88,8 @@ struct is_host_callable> template struct is_internal_type { static constexpr bool value = std::is_same, bf16>::value || - std::is_same, tf32>::value; + std::is_same, tf32>::value || + std::is_same, mx_fp4>::value; }; template inline constexpr bool is_internal_type_v = is_internal_type::value; @@ -108,6 +131,12 @@ struct native_type { using type = T; }; +/// @brief Set uint8_t as the native data type of mx_fp4. +template <> +struct native_type { + using type = uint8_t; +}; + /// @brief Return the native data type of T template using native_type_t = typename native_type::type; diff --git a/include/common/core/common.hpp b/include/common/core/common.hpp index 96cbdd14d..e51365797 100644 --- a/include/common/core/common.hpp +++ b/include/common/core/common.hpp @@ -82,15 +82,44 @@ enum class msg_type : uint8_t { // prefetch_1d = 5 }; -/// L1 or L2 cache hint kinds. +/// L1, L2 or L3 cache hints. enum class cache_hint : uint8_t { none = 0, + /// load/store/atomic: do not cache data to cache; uncached = 1, + + // load: cache data to cache; cached = 2, + + /// store: write data into cache level and mark the cache line as "dirty". + /// Upon eviction, the "dirty" data will be written into the furthest + /// subsequent cache; write_back = 3, + + /// store: immediately write data to the subsequent furthest cache, marking + /// the cache line in the current cache as "not dirty"; write_through = 4, + + /// load: cache data to cache using the evict-first policy to minimize cache + /// pollution caused by temporary streaming data that may only be accessed + /// once or twice; + /// store/atomic: same as write-through, but use the evict-first policy + /// to limit cache pollution by streaming; streaming = 5, - read_invalidate = 6 + + /// load: asserts that the cache line containing the data will not be read + /// again until it’s overwritten, therefore the load operation can invalidate + /// the cache line and discard "dirty" data. If the assertion is violated + /// (the cache line is read again) then behavior is undefined. + read_invalidate = 6, + + // TODO: Implement the verification of this enum in check_cache_hint(). + /// load, L2 cache only, next gen GPU after Xe required: asserts that + /// the L2 cache line containing the data will not be written until all + /// invocations of the shader or kernel execution are finished. + /// If the assertion is violated (the cache line is written), the behavior + /// is undefined. 
+ const_cached = 7 }; /// Data size or format to read or store diff --git a/include/common/core/explicit_conv.hpp b/include/common/core/explicit_conv.hpp index 47f6ca845..0c61f12bc 100644 --- a/include/common/core/explicit_conv.hpp +++ b/include/common/core/explicit_conv.hpp @@ -55,6 +55,9 @@ __XETLA_API typename std::enable_if_t< std::is_same::value && std::is_same::value, xetla_vector> xetla_cvt(xetla_vector src) { + // xetla_vector a = src.template bit_cast_view(); + // xetla_vector c = a >> 16; + // return c.xetla_format(); xetla_vector dst = src; return dst; } @@ -68,6 +71,10 @@ __XETLA_API typename std::enable_if_t< std::is_same::value && std::is_same::value, xetla_vector> xetla_cvt(xetla_vector src) { + // xetla_vector a = src.template bit_cast_view(); + // xetla_vector b = a; + // auto c = b << 16; + // return c.xetla_format(); xetla_vector dst = src; return dst; } @@ -139,6 +146,97 @@ xetla_cvt(xetla_vector src, float scaling_value) { return dst; } +/// @brief xetla explicit data conversion, fp16->mx_fp4. +/// @tparam T_src is the float16 data type. +/// @tparam T_dst is the mx_fp4(E2M1) data type. +/// @tparam N is the element number in xetla_vector. +template +__XETLA_API typename std::enable_if_t< + std::is_same::value && std::is_same::value, + xetla_vector::value>> +xetla_cvt( + xetla_vector src, + xetla_vector rand_vec = 0x100) { + /*********prepare, 4 instructions******/ + xetla_vector src_abs; + src_abs.xetla_format() = src.xetla_format() & 0x7fff; + xetla_vector sign = + (src.xetla_format() & 0x8000) >> 12; + // only compare 9bits mantissa + rand_vec = rand_vec & 0x1ff; + + xetla_vector src_abs_carried; + // if src_abs is 0.3, then it only has 30% possibility to carry (rand_vec + // >=0.7) inf If it is 0_11110_1xxxxxxxxx, and carry, then it will become + // 0_11111_0, and finally round to 6 If it is 0_11111_0000000000, no carry, + // then it will become 0_11111_0, and finally round to 6 nan If it is + // 0_11111_1xxxxxxxxx, and carry, then it will become 1_00000_0, and finally + // round to 0 If it is 0_11111_1xxxxxxxxx, no carry, then it will become + // 0_11111_1, and finally round to 6 If it is 0_11111_0xxxxxxxxx, and carry, + // then it will become 0_11111_1, and finally round to 6 If it is + // 0_11111_0xxxxxxxxx, no carry, then it will become 0_11111_0, and finally + // round to 4 subnormal If it is 0_00000_1xxxxxxxxx, and carry, then it will + // become 0_00001_0, and finally round to 0 + + // clean the low 9bits to make sure inf will not become nan + /*********rounding, 1 instruction******/ + src_abs_carried.xetla_format() = + src_abs.xetla_format() + rand_vec.xetla_format(); + + // dst = 0_00_0 + // if src_abs_carried >= 0.5(0_01110_0) => dst = 0_00_1 + // if src_abs_carried >= 1 (0_01111_0) => dst = 0_01_0 + // if src_abs_carried >= 1.5(0_01111_1) => dst = 0_01_1 + // if src_abs_carried >= 2 (0_10000_0) => dst = 0_10_0 + // if src_abs_carried >= 3 (0_10000_1) => dst = 0_10_1 + // if src_abs_carried >= 4 (0_10001_0) => dst = 0_11_0 + // if src_abs_carried >= 6 (0_10001_1) => dst = 0_11_1 + /*********common path, 5 instructions******/ + xetla_vector dst = + ((src_abs_carried.xetla_format() >> 9) & 3) | + ((src_abs_carried.xetla_format() >> 12) & 4); + + /*********handle conner case, 6 instructions******/ + // if src_abs_carried is nan, there flags should all be false, only go with + // common path still srnd for subnormal case + xetla_mask zero_flag = src_abs_carried < 0.5; + xetla_mask subnormal_flag = src_abs_carried < 1; + xetla_mask saturate_flag = src_abs >= 
6; + dst.xetla_merge(0x7, saturate_flag); + // subnormal_flag should prior to zero_flag + dst.xetla_merge(0x1, subnormal_flag); + dst.xetla_merge(0x0, zero_flag); + + /*********add sign bit, 1 instructions******/ + dst |= sign; + + /*********pack data, 7 instructions******/ + // T_dst is uint8_t, get_packed_num::value is 2 + xetla_vector::value> out; + auto out_u16 = out.xetla_format(); + out_u16 = dst.xetla_select(0); + out_u16 |= dst.xetla_select(1) << 4; + out_u16 |= dst.xetla_select(2) << 8; + out_u16 |= dst.xetla_select(3) << 12; + + return out; +} + +/// @brief xetla explicit data conversion, fp32->mx_fp4. +/// @tparam T_src is the float32 data type. +/// @tparam T_dst is the mx_fp4(E2M1) data type. +/// @tparam N is the element number in xetla_vector. +template +__XETLA_API typename std::enable_if_t< + std::is_same::value && std::is_same::value, + xetla_vector::value>> +xetla_cvt( + xetla_vector src, + xetla_vector rand_vec = 0x100) { + xetla_vector src_f16 = src; + return xetla_cvt(src_f16, rand_vec); +} + /// @brief xetla explicit data conversion, same type. /// @tparam T_dst is the dst data type. /// @tparam T_src is the src data type. @@ -153,4 +251,4 @@ xetla_cvt(xetla_vector src) { /// @} xetla_core_conv -} // namespace gpu::xetla \ No newline at end of file +} // namespace gpu::xetla diff --git a/include/common/core/math_general.hpp b/include/common/core/math_general.hpp index 0c15c8946..54f4e1a2f 100644 --- a/include/common/core/math_general.hpp +++ b/include/common/core/math_general.hpp @@ -508,6 +508,40 @@ __XETLA_API T xetla_tanh(T src) { return (src >= 10) ? 1 : ret; } +/// @brief Calculate sigmoid value for each element of the input vector. +/// @tparam T element type of the input and return vectors. +/// @tparam SZ size of the input and returned vectors. +/// @param src the input vector. +/// @return vector of sigmoid of component-wise elements. +template +__XETLA_API xetla_vector xetla_sigmoid(xetla_vector src) { + static_assert( + (std::is_same, float>::value) || + (std::is_same, fp16>::value), + "Only support fp32 and fp16"); + xetla_mask mask = src <= -10; + xetla_vector exp = xetla_exp(-src); + xetla_vector ret_sub = 1.f / (exp + 1.f); + ret_sub.xetla_merge(0, mask); + + return ret_sub; +} + +/// @brief Calculate sigmoid of a scalar. +/// @tparam T element type of the input and return a scalar. +/// @param src the scalar value. +/// @return sigmoid value. +template +__XETLA_API T xetla_sigmoid(T src) { + static_assert( + (std::is_same, float>::value) || + (std::is_same, fp16>::value), + "Only support fp32 and fp16"); + T exp = xetla_exp(-src); + T ret = 1.f / (exp + 1.f); + return (src <= -10) ? 0 : ret; +} + /// Add two unsigned integer vectors, return the result and in-place update the /// carry. /// @tparam T element type of the src, should be uint32_t. @@ -556,6 +590,7 @@ __XETLA_API xetla_vector xetla_add_c( #else xetla_vector out = __ESIMD_ENS::addc(carry_tmp, src0, src1); #endif + carry = carry_tmp; return out; } @@ -663,6 +698,27 @@ __XETLA_API xetla_vector xetla_sat(xetla_vector src) { return __ESIMD_NS::saturate(src); } +/// Count number of bits set in the source operand per element. +/// @param src0 the source operand to count bits in. +/// @return a vector of \c uint32_t, where each element is set to bit count of +/// the corresponding element of the source operand. 
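
The fp16 to mx_fp4 conversion above rounds stochastically by adding a 9-bit random value to the fp16 bit pattern, derives a 3-bit E2M1 code from the carried bits, patches the zero / subnormal / saturation corner cases, ORs in the sign, and packs two 4-bit codes per byte (get_packed_num is 2 for mx_fp4). The threshold table in the comments is easier to check against a scalar reference; the sketch below is that ladder with the random carry fixed to zero (so magnitudes floor to the next lower representable value), using helper names of my own rather than the vectorized ESIMD path.

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // E2M1 magnitudes addressed by the 3-bit code (sign handled separately).
    static constexpr float kE2M1[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

    // Scalar sketch of the threshold ladder above, with the stochastic carry set
    // to zero. Returns a 4-bit code: bit 3 = sign, bits 2:0 = magnitude code.
    uint8_t encode_e2m1(float v) {
      uint8_t sign = std::signbit(v) ? 0x8 : 0x0;
      float a = std::fabs(v);
      uint8_t code = 0;
      if (a >= 6.0f)      code = 7;  // saturate, matches the saturate_flag path
      else if (a >= 4.0f) code = 6;
      else if (a >= 3.0f) code = 5;
      else if (a >= 2.0f) code = 4;
      else if (a >= 1.5f) code = 3;
      else if (a >= 1.0f) code = 2;
      else if (a >= 0.5f) code = 1;  // "subnormal" bucket in the comments
      // else stays 0 (zero_flag path)
      return sign | code;
    }

    float decode_e2m1(uint8_t c) {
      float m = kE2M1[c & 0x7];
      return (c & 0x8) ? -m : m;
    }

    // Two codes packed per byte, low nibble first, as in get_packed_num == 2.
    uint8_t pack2(uint8_t lo, uint8_t hi) {
      return static_cast<uint8_t>((lo & 0xF) | ((hi & 0xF) << 4));
    }

    int main() {
      const float in[] = {0.2f, -0.7f, 1.4f, -2.5f, 5.0f, 100.0f};
      for (float v : in) {
        uint8_t c = encode_e2m1(v);
        std::printf("%7.2f -> code 0x%X -> %5.2f\n", v, c, decode_e2m1(c));
      }
      std::printf("packed(0x1, 0xA) = 0x%02X\n", pack2(0x1, 0xA));
    }

Re-enabling the random carry turns the floor into stochastic rounding: a value roughly 30% of the way between two representable numbers rounds up with roughly 30% probability, which is what the "0.3 ... 30% possibility to carry" comment above is describing.
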
+template +__XETLA_API std::enable_if_t< + std::is_integral::value && (sizeof(T) <= 4), + xetla_vector> +xetla_cbit(xetla_vector src) { + return __ESIMD_NS::cbit(src); +} + +/// Scalar version of \c cbit - both input and output are scalars rather +/// than vectors. +template +__XETLA_API std:: + enable_if_t::value && (sizeof(T) <= 4), uint32_t> + xetla_cbit(T src) { + return __ESIMD_NS::cbit(src); +} + /// @} xetla_core_math } // namespace gpu::xetla diff --git a/include/common/core/memory.hpp b/include/common/core/memory.hpp index 9ed22b8af..0bc360d6e 100644 --- a/include/common/core/memory.hpp +++ b/include/common/core/memory.hpp @@ -32,8 +32,26 @@ namespace detail { /// @brief lookup table for cache hint. /// /// -constexpr __ESIMD_ENS::cache_hint get_cache_hint(gpu::xetla::cache_hint ch) { +constexpr auto get_cache_hint(gpu::xetla::cache_hint ch) { switch (ch) { +#if __INTEL_LLVM_COMPILER >= 20240100 + case gpu::xetla::cache_hint::none: + return __ESIMD_NS::cache_hint::none; + case gpu::xetla::cache_hint::uncached: + return __ESIMD_NS::cache_hint::uncached; + case gpu::xetla::cache_hint::cached: + return __ESIMD_NS::cache_hint::cached; + case gpu::xetla::cache_hint::write_back: + return __ESIMD_NS::cache_hint::write_back; + case gpu::xetla::cache_hint::write_through: + return __ESIMD_NS::cache_hint::write_through; + case gpu::xetla::cache_hint::streaming: + return __ESIMD_NS::cache_hint::streaming; + case gpu::xetla::cache_hint::read_invalidate: + return __ESIMD_NS::cache_hint::read_invalidate; + case gpu::xetla::cache_hint::const_cached: + return __ESIMD_NS::cache_hint::const_cached; +#else case gpu::xetla::cache_hint::none: return __ESIMD_ENS::cache_hint::none; case gpu::xetla::cache_hint::uncached: @@ -48,6 +66,7 @@ constexpr __ESIMD_ENS::cache_hint get_cache_hint(gpu::xetla::cache_hint ch) { return __ESIMD_ENS::cache_hint::streaming; case gpu::xetla::cache_hint::read_invalidate: return __ESIMD_ENS::cache_hint::read_invalidate; +#endif } } @@ -288,6 +307,57 @@ __XETLA_API void xetla_prefetch_global(Ty* p, uint64_t offset = 0) { gpu::xetla::detail::get_cache_hint(L2H)>((T*)p + (offset / sizeof(T))); } +/// simd block_load(const T* ptr, size_t byte_offset, +/// props={}); // (usm-bl-2) +/// This function loads a contiguous memory block from address referenced +/// by USM pointer \p ptr and the given \p byte_offset. +/// +/// There may be temporary restrictions depending on L1, L2 cache hints, +/// See details in the 'Restrictions' section below. The restrictions will be +/// relaxed in the future. +/// +/// The parameter \p props specifies the optional compile-time properties +/// of the type esimd::properties and may include esimd::cache_hint_L1, +/// esimd::cache_hint_L2, esimd::alignment. Other properties are ignored. +/// +/// Cache hints: If \p props does not specify any L1 or L2 cache hints, then +/// the cache_hint::none value is assumed by default. +/// +/// Alignment: If \p props does not specify the 'alignment' property, then +/// the default assumed alignment is 4-bytes for 4-byte or smaller elements +/// and 8-bytes for 8-byte elements. The address may be element-size aligned +/// even for byte- and word-elements, but in such case the smaller alignment +/// property must explicitly passed to this function. Extra restrictions +/// may be in place - see Restrictions/R1 below. 
+/// +/// Restrictions - cache hint imposed - temporary: +/// If L1 or L2 cache hint is passed, then: +/// R1: The pointer must be at least 4-byte aligned for elements of 4-bytes or +/// smaller and 8-byte aligned for 8-byte elements. +/// R2: The number of elements for 8-byte data: 1, 2, 3, 4, 8, 16, 32, 64; +/// for 4-byte data: 1, 2, 3, 4, 8, 16, 32, 64, +/// or 128(only if alignment is 8-bytes or more); +/// for 2-byte data: 2, 4, 6, 8, 16, 32, 64, 128, +/// or 256(only if alignment is 8-bytes or more); +/// for 1-byte data: 4, 8, 12, 16, 32, 64, 128, 256, +/// or 512(only if alignment is 8-bytes or more). +/// R3: The target device must be DG2, PVC or newer GPU. +template < + typename T, + int N, + cache_hint L1H = cache_hint::none, + cache_hint L2H = cache_hint::none, + int alignment = 16> +__XETLA_API xetla_vector xetla_load_global( + const T* ptr, + size_t byte_offset) { + __ESIMD_NS::properties props{ + __ESIMD_NS::cache_hint_L1, + __ESIMD_NS::cache_hint_L2, + __ESIMD_NS::alignment}; + return __ESIMD_NS::block_load(ptr, byte_offset, props); +} + /// @brief Stateless scattered load. /// Collects elements located at specified address and returns them /// to a single \ref xetla_vector object. @@ -335,47 +405,6 @@ __XETLA_API xetla_vector xetla_load_global( N>((T*)p, offsets, pred); } -/// @brief Stateless block load (transposed gather with 1 channel). -/// Collects elements located at specified address and returns them -/// to a single \ref xetla_vector object. -/// -/// Supported platforms: DG2, PVC -/// -/// VISA instruction: lsc_load.ugm -/// -/// @tparam Ty is element type. -/// @tparam NElts is the number of elements to load per address (i.e. -/// vector_size per SIMD channel). -/// @tparam DS is the data size. -/// @tparam L1H is L1 cache hint. -/// @tparam L2H is L2 cache hint. -/// @param p [in] is the base pointer. -/// @param offset [in] is the zero-based offset in bytes. -/// @return is a xetla_vector of type T and size NElts. -/// -template < - typename Ty, - uint8_t NElts = 1, - data_size DS = data_size::default_size, - cache_hint L1H = cache_hint::none, - cache_hint L2H = cache_hint::none> -__XETLA_API xetla_vector xetla_load_global( - Ty* p, - uint64_t offset = 0) { - using T = native_type_t; - DEBUG_INVOKE( - dbg_level::core, - core::general_1d::template check_restriction( - offset, (uint64_t)p)); - - return __ESIMD_ENS::lsc_block_load< - T, - NElts, - gpu::xetla::detail::get_data_size(DS), - gpu::xetla::detail::get_cache_hint(L1H), - gpu::xetla::detail::get_cache_hint(L2H)>((T*)p + (offset / sizeof(T))); -} - /// @brief Stateless scattered store. /// Writes elements to specific address. /// @@ -418,41 +447,55 @@ __XETLA_API void xetla_store_global( N>((T*)p, offsets, vals, pred); } -/// @brief Stateless block store (transposed scatter with 1 channel). -/// Writes elements to specific address. -/// -/// Supported platforms: DG2, PVC -/// -/// VISA instruction: lsc_store.ugm -/// -/// @tparam Ty is element type. -/// @tparam NElts is the number of elements to store per address (i.e. -/// vector_size per SIMD channel). -/// @tparam DS is the data size. -/// @tparam L1H is L1 cache hint. -/// @tparam L2H is L2 cache hint. -/// @param p [in] is the base pointer. -/// @param offset [in] is the zero-based offset in bytes. -/// @param vals [in] is values to store. 
-/// +/// void block_store(T* ptr, size_t byte_offset, // (usm-bs-2) +/// simd vals, props={}); +/// This function stores a contiguous memory block to USM pointer \p ptr and +/// byte-offset \p byte_offset with data specified by \p vals. +/// +/// There may be temporary restrictions depending on L1, L2 cache hints, +/// See details in the 'Restrictions' section below. The restrictions will be +/// relaxed in the future. +/// +/// The parameter \p props specifies the optional compile-time properties +/// of the type esimd::properties and may include esimd::cache_hint_L1, +/// esimd::cache_hint_L2, esimd::alignment. Other properties are ignored. +/// +/// Cache hints: If \p props does not specify any L1 or L2 cache hints, then +/// the cache_hint::none value is assumed by default. +/// +/// Alignment: If \p props does not specify the 'alignment' property, then +/// the default assumed alignment is 16 bytes if \p props does not specify any +/// L1 or L2 cache hints, and the minimally required element-size +/// alignment otherwise. Note that additional/temporary restrictions may apply +/// (see Restrictions below). +/// +/// Restrictions - cache hint imposed - temporary: +/// If L1 or L2 cache hint is passed, then: +/// R1: The pointer plus byte offset must be at least 4-byte aligned for +/// elements of 4-bytes or smaller and 8-byte aligned for 8-byte elements. +/// R2: The number of elements for 8-byte data: 1, 2, 3, 4, 8, 16, 32, 64; +/// for 4-byte data: 1, 2, 3, 4, 8, 16, 32, 64, +/// or 128(only if alignment is 8-bytes or more); +/// for 2-byte data: 2, 4, 6, 8, 16, 32, 64, 128, +/// or 256(only if alignment is 8-bytes or more); +/// for 1-byte data: 4, 8, 12, 16, 32, 64, 128, 256, +/// or 512(only if alignment is 8-bytes or more). +/// R3: The target device must be DG2, PVC or newer GPU. template < - typename Ty, - uint8_t NElts = 1, - data_size DS = data_size::default_size, + typename T, + int N, cache_hint L1H = cache_hint::none, - cache_hint L2H = cache_hint::none> + cache_hint L2H = cache_hint::none, + int alignment = 16> __XETLA_API void xetla_store_global( - Ty* p, - uint64_t offset, - xetla_vector vals) { - using T = native_type_t; - __ESIMD_ENS::lsc_block_store< - T, - NElts, - gpu::xetla::detail::get_data_size(DS), - gpu::xetla::detail::get_cache_hint(L1H), - gpu::xetla::detail::get_cache_hint(L2H)>( - (T*)p + (offset / sizeof(T)), vals); + T* ptr, + size_t byte_offset, + xetla_vector vals) { + __ESIMD_NS::properties props{ + __ESIMD_NS::cache_hint_L1, + __ESIMD_NS::cache_hint_L2, + __ESIMD_NS::alignment}; + __ESIMD_NS::block_store(ptr, byte_offset, vals, props); } /// @brief Stateless scattered atomic (0 src). 
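
The memory.hpp changes above swap the lsc_block_load / lsc_block_store based 1-D wrappers for ones built on __ESIMD_NS::block_load / block_store with compile-time properties (cache_hint_L1, cache_hint_L2, alignment), and get_cache_hint now returns the __ESIMD_NS enum when __INTEL_LLVM_COMPILER >= 20240100. This is also why later hunks drop the data_size::default_size template argument at every call site. Below is a small, self-contained host-side sketch of the "properties as template parameters" idea; the enum, properties struct and block_load here are stand-ins of my own, not the ESIMD API.

    #include <array>
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    enum class cache_hint : uint8_t { none, uncached, cached, write_back,
                                      write_through, streaming, read_invalidate,
                                      const_cached };

    // Cache hints and alignment travel as compile-time properties instead of a
    // data_size argument.
    template <cache_hint L1, cache_hint L2, size_t Alignment>
    struct load_properties {
      static constexpr cache_hint l1 = L1;
      static constexpr cache_hint l2 = L2;
      static constexpr size_t alignment = Alignment;
    };

    template <typename T, int N,
              typename Props = load_properties<cache_hint::none, cache_hint::none, 16>>
    std::array<T, N> block_load(const T* ptr, size_t byte_offset) {
      static_assert(Props::alignment >= alignof(T), "alignment property too small");
      // A real implementation would pick an LSC message from Props::l1 / Props::l2;
      // here we simply copy.
      std::array<T, N> out{};
      std::memcpy(out.data(),
                  reinterpret_cast<const char*>(ptr) + byte_offset,
                  N * sizeof(T));
      return out;
    }

    int main() {
      float buf[32];
      for (int i = 0; i < 32; ++i) buf[i] = float(i);
      auto v = block_load<float, 8,
                          load_properties<cache_hint::cached, cache_hint::cached, 16>>(buf, 16);
      std::printf("v[0]=%g v[7]=%g\n", v[0], v[7]); // prints 4 and 11
    }

Because the hint and alignment choices are types in this shape, an unsupported combination can be rejected with a static_assert in the sketch rather than at finalization time.
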
diff --git a/include/experimental/group/dropout_mask_gen.hpp b/include/experimental/group/dropout_mask_gen.hpp index 9713f26ad..fce6d2366 100644 --- a/include/experimental/group/dropout_mask_gen.hpp +++ b/include/experimental/group/dropout_mask_gen.hpp @@ -117,12 +117,9 @@ struct mask_gen_t { uint32_t sg_idx, uint32_t sg_idy, uint32_t linear_idx) { - xetla_vector rand_offset_ptr_v = xetla_load_global< - uint64_t, - 1, - data_size::default_size, - cache_hint::cached, - cache_hint::cached>(args->rand_offset_ptr, 0); + xetla_vector rand_offset_ptr_v = + xetla_load_global( + args->rand_offset_ptr, 0); uint32_t threshold = uint32_t(args->dropout_prob * float(4294967296)); mask_out_tile_t mask_out; int start_m = wg_idy * wg_tile_m + sg_idy * sg_tile_m; diff --git a/include/experimental/group/fused_op/layer_norm_fused_op_fwd_xe.hpp b/include/experimental/group/fused_op/layer_norm_fused_op_fwd_xe.hpp index 1e2dd7289..3357dac5d 100644 --- a/include/experimental/group/fused_op/layer_norm_fused_op_fwd_xe.hpp +++ b/include/experimental/group/fused_op/layer_norm_fused_op_fwd_xe.hpp @@ -501,12 +501,9 @@ struct ln_fwd_fused_op_t< uint32_t sg_idy, uint32_t start_m) { int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n; - xetla_vector rand_offset_ptr_v = xetla_load_global< - uint64_t, - 1, - data_size::default_size, - cache_hint::cached, - cache_hint::cached>(args->rand_offset_ptr, 0); + xetla_vector rand_offset_ptr_v = + xetla_load_global( + args->rand_offset_ptr, 0); mat_ld = args->mat_ld; mask_ld = args->mask_ld; matrix_n = args->matrix_n; @@ -668,12 +665,9 @@ struct ln_fwd_fused_op_t< uint32_t sg_idy, uint32_t start_m) { int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n; - xetla_vector rand_offset_ptr_v = xetla_load_global< - uint64_t, - 1, - data_size::default_size, - cache_hint::cached, - cache_hint::cached>(args->rand_offset_ptr, 0); + xetla_vector rand_offset_ptr_v = + xetla_load_global( + args->rand_offset_ptr, 0); mask_ld = args->mask_ld; matrix_m = args->matrix_m; matrix_n = args->matrix_n; diff --git a/include/experimental/group/gemm/compute_policy.hpp b/include/experimental/group/gemm/compute_policy.hpp index 79793ebf8..8d7ffe33d 100644 --- a/include/experimental/group/gemm/compute_policy.hpp +++ b/include/experimental/group/gemm/compute_policy.hpp @@ -79,12 +79,6 @@ struct compute_policy_int4_dequantize< static constexpr bool is_int4_matB_policy = true; - static constexpr uint32_t block_size_y_a = 16; - using mma_attr = mma_attr_t; - static constexpr uint32_t block_bytes_x_a = mma_attr::mma_k_in_bytes; - static constexpr uint32_t block_size_x_b = mma_attr::mma_n_in_elem; - static constexpr uint32_t block_bytes_y_b = block_bytes_x_a; - static constexpr uint32_t dequant_s = dequant_s_; static_assert( (dequant_s % (32 / sizeof(dtype_mma_b))) == 0, @@ -92,6 +86,23 @@ struct compute_policy_int4_dequantize< using dtype_scale = dtype_scale_; using dtype_zero_pt = dtype_zero_pt_; static constexpr quant_mode quant_type = quant_type_; + + static constexpr uint32_t block_size_y_a = 16; + using mma_attr = mma_attr_t; + static constexpr uint32_t block_bytes_x_a = + (mma_engine == mma_engine::xmx) ? mma_attr::mma_k_in_bytes : 32; + static constexpr uint32_t block_size_x_a = + block_bytes_x_a / sizeof(dtype_mma_a); + static constexpr uint32_t block_size_x_b = + (mma_engine == mma_engine::xmx) ? mma_attr::mma_n_in_elem : 32; + static constexpr uint32_t block_bytes_y_b = + (mma_engine == mma_engine::xmx) ? 
mma_attr::mma_k_in_bytes : 32; + static constexpr uint32_t block_size_y_b = + block_bytes_y_b / sizeof(dtype_mma_b); + + static_assert( + block_bytes_x_a == block_bytes_y_b, + "mat_a x need to match with mat_b y"); }; } // namespace gpu::xetla::group diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 35f224334..528911fcf 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -118,7 +118,7 @@ class gemm_t< static_assert(!is_col_major_b, "only support MatB row-major for now"); static_assert( (!is_local_a) && (!is_local_b), - "only support from global memory for now"); + "only support from global memory for now"); static constexpr uint32_t stages = compute_policy::stages; static constexpr uint32_t sync_freq = compute_policy::sync_freq; @@ -130,16 +130,23 @@ class gemm_t< static constexpr uint32_t tile_size_y_b = k_stride; static constexpr uint32_t tile_size_x_c = sg_tile_n; static constexpr uint32_t tile_size_y_c = sg_tile_m; + static constexpr uint32_t block_size_x_a = - compute_policy::block_bytes_x_a / sizeof(dtype_mma_a); + (compute_policy::block_size_x_a > tile_size_x_a) + ? tile_size_x_a + : compute_policy::block_size_x_a; static constexpr uint32_t block_size_y_a = (compute_policy::block_size_y_a > tile_size_y_a) ? tile_size_y_a : compute_policy::block_size_y_a; - - static constexpr uint32_t block_size_x_b = compute_policy::block_size_x_b; + static constexpr uint32_t block_size_x_b = + (compute_policy::block_size_x_b > tile_size_x_b) + ? tile_size_x_b + : compute_policy::block_size_x_b; static constexpr uint32_t block_size_y_b = - compute_policy::block_bytes_y_b / sizeof(dtype_mma_b); + (compute_policy::block_size_y_b > tile_size_y_b) + ? tile_size_y_b + : compute_policy::block_size_y_b; /******** set tile **********/ static constexpr bool is_vnni_tiled_a = @@ -189,7 +196,8 @@ class gemm_t< tile_size_y_b, block_size_x_b, block_size_y_b, - reg_layout::tiled>; + compute_policy::mma_engine == mma_engine::xmx ? 
reg_layout::vnni_tiled + : reg_layout::tiled>; using matB_acc_t = subgroup::tile_t; public: diff --git a/include/experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp b/include/experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp index 2847faf0e..0395a77fd 100644 --- a/include/experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp +++ b/include/experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp @@ -135,7 +135,6 @@ struct col_major_shuf_t< auto gidx = xetla_load_global< uint32_t, block_size_x, - data_size::default_size, cache_hint::cached, cache_hint::cached>( args.gidx_ptr, gidx_payload.base_offset + block_x * block_size_x); diff --git a/include/experimental/kernel/data_transformer/data_transformer_xe.hpp b/include/experimental/kernel/data_transformer/data_transformer_xe.hpp index ce11a41e7..4cbd1498a 100644 --- a/include/experimental/kernel/data_transformer/data_transformer_xe.hpp +++ b/include/experimental/kernel/data_transformer/data_transformer_xe.hpp @@ -261,7 +261,6 @@ struct xetla_data_transformer< xetla_vector local_scale = xetla_load_global< dtype_compute, 1, - data_size::default_size, cache_hint::cached, cache_hint::cached>(args->scale, offset); diff --git a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp index 4b8c595cb..6c98df456 100644 --- a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp +++ b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp @@ -294,7 +294,6 @@ class gemm_universal_t< return *this; } }; - template <> struct arguments_t { /// @brief Is the size of the m dimension of the matrix multiplication (m x diff --git a/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp b/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp index 2ee2c269a..f4c2fff61 100644 --- a/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp +++ b/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp @@ -333,13 +333,11 @@ struct layer_norm_bwd_t< xetla_vector mu_v = xetla_load_global< dtype_acc, 1, - data_size::default_size, cache_hint::cached, cache_hint::cached>(args->mu_ptr, row * sizeof(dtype_acc)); xetla_vector rs_v = xetla_load_global< dtype_acc, 1, - data_size::default_size, cache_hint::cached, cache_hint::cached>(args->rs_ptr, row * sizeof(dtype_acc)); dtype_acc mu = mu_v[0]; diff --git a/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp b/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp index 9ddbfdfff..08ad1eaf3 100644 --- a/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp +++ b/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp @@ -320,7 +320,7 @@ struct layer_norm_fwd_t< itr_count += 1; nbarrier.wait(); - xetla_vector mu_m2_vec = + xetla_vector mu_m2_vec = xetla_load_local(slm_load_base); xetla_vector mu_vec = mu_m2_vec.xetla_select(0); @@ -338,7 +338,6 @@ struct layer_norm_fwd_t< xetla_store_global< dtype_acc, 1, - data_size::default_size, cache_hint::write_back, cache_hint::write_back>( args->mu_ptr, @@ -347,7 +346,6 @@ struct layer_norm_fwd_t< xetla_store_global< dtype_acc, 1, - data_size::default_size, cache_hint::write_back, cache_hint::write_back>( args->rs_ptr, diff --git a/include/group/epilogue/epilogue_policy.hpp b/include/group/epilogue/epilogue_policy.hpp index d8b43500f..2ee4147d0 100644 --- a/include/group/epilogue/epilogue_policy.hpp +++ b/include/group/epilogue/epilogue_policy.hpp @@ -69,6 +69,18 @@ template struct epilogue_policy_unaligned { 
   static constexpr gpu_arch arch_tag = arch_tag_;
 };
+
+/// @brief Epilogue policy for storing with stride into NHWC space from NPQC
+/// descriptor.
+/// @tparam arch_tag_ Is the HW architecture.
+/// @tparam stride_h_ Is the stride in H dimension.
+/// @tparam stride_w_ Is the stride in W dimension.
+template <gpu_arch arch_tag_, uint32_t stride_h_, uint32_t stride_w_>
+struct epilogue_policy_strided {
+  static constexpr gpu_arch arch_tag = arch_tag_;
+  static constexpr uint32_t stride_w = stride_w_;
+  static constexpr uint32_t stride_h = stride_h_;
+};
 /// @} xetla_epilogue
 } // namespace gpu::xetla::group
diff --git a/include/group/epilogue/impl/default_xe.hpp b/include/group/epilogue/impl/default_xe.hpp
index 8a822443e..ab149396a 100644
--- a/include/group/epilogue/impl/default_xe.hpp
+++ b/include/group/epilogue/impl/default_xe.hpp
@@ -105,6 +105,125 @@ class epilogue_t<
   }
 };
 
+/// @brief Is the epilogue functor specialized for epilogue_policy_default and
+/// Xe architecture.
+template <typename tile_shape_, typename mem_desc_c_t_, gpu_arch arch_tag_>
+class epilogue_t<
+    epilogue_policy_default<arch_tag_>,
+    tile_shape_,
+    mem_desc_c_t_,
+    std::enable_if_t<(
+        (arch_tag_ <= gpu_arch::XeHpc) && (mem_desc_c_t_::dim == 4))>> {
+ public:
+  using epilogue_policy = epilogue_policy_default<arch_tag_>;
+  using tile_shape = tile_shape_;
+  using mem_desc_c_t = mem_desc_c_t_;
+  static constexpr gpu_arch arch_tag = arch_tag_;
+  static constexpr uint32_t barrier_count = 0;
+  static constexpr uint32_t slm_size = 0;
+  /// @brief Epilogue arguments.
+  struct arguments_t {};
+
+ private:
+  using work_group_t = typename tile_shape::work_group_t;
+  static constexpr uint32_t wg_tile_n = tile_shape::wg_tile_size_n;
+  static constexpr uint32_t wg_tile_p = tile_shape::wg_tile_size_p;
+  static constexpr uint32_t wg_tile_q = tile_shape::wg_tile_size_q;
+  static constexpr uint32_t wg_tile_k = tile_shape::wg_tile_size_k;
+
+  static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_n;
+  static constexpr uint32_t sg_tile_p = tile_shape::sg_tile_size_p;
+  static constexpr uint32_t sg_tile_q = tile_shape::sg_tile_size_q;
+  static constexpr uint32_t sg_tile_k = tile_shape::sg_tile_size_k;
+
+  static constexpr uint32_t wg_size_n = tile_shape::wg_size_n;
+  static constexpr uint32_t wg_size_p = tile_shape::wg_size_p;
+  static constexpr uint32_t wg_size_q = tile_shape::wg_size_q;
+  static constexpr uint32_t wg_size_k = tile_shape::wg_size_k;
+  using dtype_c = typename mem_desc_c_t::dtype;
+  static constexpr mem_layout mem_layout_c = mem_desc_c_t::layout;
+  static constexpr mem_space mem_space_c = mem_desc_c_t::space;
+  static constexpr msg_type msg_type_c =
+      (mem_space_c == mem_space::global ? msg_type::block_2d
+                                        : msg_type::scatter);
+  /// @brief Updates tile base descriptor based on the tid.
+  __XETLA_API static void update_sg_tile_tdesc(
+      work_group_t& g,
+      mem_desc_c_t& mem_desc_c) {
+    int32_t sg_idk = g.get_id() % wg_size_k;
+    int32_t sg_idq = (g.get_id() / wg_size_k) % wg_size_q;
+    int32_t sg_idp = (g.get_id() / (wg_size_k * wg_size_q)) % wg_size_p;
+    int32_t sg_idn =
+        (g.get_id() / (wg_size_k * wg_size_q * wg_size_p)) % wg_size_n;
+
+    int32_t tile_offset_n = sg_idn * sg_tile_n;
+    int32_t tile_offset_p = sg_idp * sg_tile_p;
+    int32_t tile_offset_q = sg_idq * sg_tile_q;
+    int32_t tile_offset_k = sg_idk * sg_tile_k;
+    mem_desc_c.update_coord(
+        tile_offset_k, tile_offset_q, tile_offset_p, tile_offset_n);
+  }
+
+ public:
+  static constexpr bool is_2d_block_c = (msg_type_c == msg_type::block_2d);
+
+  /// @brief Default epilogue.
+  /// 1) Convert dtype_acc to dtype_c 2) Overwrite to memory.
+  /// @tparam matAcc_t Is the type of the input tile.
+ /// @param g Is the workgroup of the current tile. + /// @param matAcc Is the input tile. + /// @param mem_desc_c Is the memory description of matC, including base, shape + /// and coordinate. + /// @param args Is the additional arguments for epilogue. + /// @param slm_base Is the slm base address. + /// @param nbarrier_base Is the named barrier base. + template < + typename matAcc_t, + int _sg_tile_n = sg_tile_n, + int _sg_tile_p = sg_tile_p> + __XETLA_API KERNEL_FUNC void operator()( + [[maybe_unused]] work_group_t& g, + [[maybe_unused]] matAcc_t matAcc[_sg_tile_n][_sg_tile_p], + [[maybe_unused]] mem_desc_c_t mem_desc_c, + [[maybe_unused]] arguments_t args = {}, + [[maybe_unused]] uint32_t slm_base = 0, + [[maybe_unused]] uint32_t nbarrier_base = 0) { + using mat_tile_desc = typename matAcc_t::tile_desc; + using matC_t = subgroup::tile_t; + using matC_payload_t = subgroup::mem_payload_t< + mem_desc_t, + mat_tile_desc, + msg_type_c, + arch_tag>; + + update_sg_tile_tdesc(g, mem_desc_c); + +#pragma unroll + for (uint32_t n = 0; n < _sg_tile_n; n++) { +#pragma unroll + for (uint32_t p = 0; p < _sg_tile_p; p++) { + matC_t matC; + matC_payload_t matC_payload; + + matC_payload.init(mem_desc_c.get_tdesc()); + + subgroup::elemwise_cvt(matC, matAcc[n][p]); + + int32_t offset_n = mem_desc_c.get_base_offset_from_w(n); + int32_t offset_p = mem_desc_c.get_base_offset_from_z(p); + int mask_n = mem_desc_c.get_mask_from_w(n); + int mask_p = mem_desc_c.get_mask_from_z(p); + int32_t base_offset = (offset_n + offset_p) * sizeof(dtype_c); + + matC_payload.update_tdesc_base_address_masked( + base_offset, mask_n & mask_p); + + subgroup::tile_store( + matC, matC_payload); + } + } + } +}; /// @} xetla_epilogue } // namespace gpu::xetla::group diff --git a/include/group/epilogue/impl/stream_k_op_xe.hpp b/include/group/epilogue/impl/stream_k_op_xe.hpp index bb0b21c0f..ea60bee2f 100644 --- a/include/group/epilogue/impl/stream_k_op_xe.hpp +++ b/include/group/epilogue/impl/stream_k_op_xe.hpp @@ -33,9 +33,10 @@ template < typename tile_shape_, typename epilogue_t_, typename mem_desc_d_t_, - typename mem_desc_atomic_sync_t_> + typename mem_desc_atomic_sync_t_, + gpu_arch arch_tag_ = gpu_arch::XeHpc> struct epilogue_stream_k_t { - static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; + static constexpr gpu_arch arch_tag = arch_tag_; using epilogue_t = epilogue_t_; using mem_desc_d_t = mem_desc_d_t_; using mem_desc_c_t = typename epilogue_t::mem_desc_c_t; @@ -65,8 +66,8 @@ struct epilogue_stream_k_t { // Use special residual op for finishing SK groups to read from scratchspace // buffer and reduce in GRF; They also store zeros in scratchspace buffer - using residual_op_t = - subgroup::elemwise_reduce_op_stream_k_t; + using residual_op_t = subgroup:: + elemwise_reduce_op_stream_k_t; using residual_op_args_t = typename residual_op_t::arguments_t; static constexpr mem_layout mem_layout_d = mem_desc_d_t::layout; @@ -171,8 +172,8 @@ struct epilogue_stream_k_t { 16, cache_hint::uncached, cache_hint::write_back, - atomic_op::iadd>( - (uint64_t)flag_pointer, flag_offsets, signal_val, pred); + atomic_op::iadd, + arch_tag>((uint64_t)flag_pointer, flag_offsets, signal_val, pred); } } else { diff --git a/include/group/gemm/compute_policy.hpp b/include/group/gemm/compute_policy.hpp index fe02cb758..0a0cd1c91 100644 --- a/include/group/gemm/compute_policy.hpp +++ b/include/group/gemm/compute_policy.hpp @@ -123,9 +123,8 @@ struct compute_policy_default_fpu< static constexpr uint32_t block_bytes_x_a = 32; static constexpr uint32_t 
block_size_x_a = block_bytes_x_a / sizeof(dtype_mma_a); - static constexpr uint32_t block_bytes_x_b = - arch_tag_ == gpu_arch::XeLpg ? 32 : 64; + arch_attr_t::template register_attr<>::reg_in_bytes; static constexpr uint32_t block_size_x_b = block_bytes_x_b / sizeof(dtype_mma_b); static constexpr uint32_t block_size_y_b = block_size_x_a; diff --git a/include/group/gemm/impl/default_fpu_xe.hpp b/include/group/gemm/impl/default_fpu_xe.hpp index 34956a328..add8e6790 100644 --- a/include/group/gemm/impl/default_fpu_xe.hpp +++ b/include/group/gemm/impl/default_fpu_xe.hpp @@ -380,9 +380,6 @@ class gemm_t< SW_BARRIER(); subgroup::tile_load( matA, matA_payload); - if constexpr (!is_col_major_a) - reorder_matA(matA); - subgroup::tile_load( matB, matB_payload); matA_payload.template update_tdesc(matA_t::tile_size_x); @@ -421,20 +418,6 @@ class gemm_t< } private: - inline void reorder_matA(matA_t& matA) { - constexpr uint32_t num_block_x = tile_size_x_a / block_size_x_a; - constexpr uint32_t num_block_y = tile_size_y_a / block_size_y_a; - for (uint32_t i = 0; i < num_block_y * num_block_x; i++) { - auto dst_blk = matA.reg.xetla_select( - i * matA_t::block_elems); - xetla_vector trans_blk; - for (uint32_t j = 0; j < block_size_y_a; j++) { - trans_blk.xetla_select(j) = - dst_blk.xetla_select(j * block_size_x_a); - } - dst_blk = trans_blk; - } - } /// @brief Updates tile base descriptor based on the tid. __XETLA_API static void update_sg_tile_tdesc( arguments_t& args, diff --git a/include/group/gemm/impl/default_xmx_xe.hpp b/include/group/gemm/impl/default_xmx_xe.hpp index ee3bf3c2b..75a0ef79c 100644 --- a/include/group/gemm/impl/default_xmx_xe.hpp +++ b/include/group/gemm/impl/default_xmx_xe.hpp @@ -126,8 +126,6 @@ class gemm_t< /******** set tile **********/ static constexpr reg_layout reg_layout_a = reg_layout::tiled; - - public: using matA_tile_desc_t = subgroup::tile_desc_t< tile_size_x_a, tile_size_y_a, @@ -168,6 +166,7 @@ class gemm_t< wg_size_y, arch_tag>; + public: using matAcc_tile_desc_t = subgroup::tile_desc_t< tile_size_x_c, tile_size_y_c, diff --git a/include/group/gemm/impl/unaligned_xmx_xe.hpp b/include/group/gemm/impl/unaligned_xmx_xe.hpp index c43899d65..19e4c89b3 100644 --- a/include/group/gemm/impl/unaligned_xmx_xe.hpp +++ b/include/group/gemm/impl/unaligned_xmx_xe.hpp @@ -129,8 +129,6 @@ class gemm_t< /******** set tile **********/ static constexpr reg_layout reg_layout_a = reg_layout::tiled; - - public: using matA_tile_desc_t = subgroup::tile_desc_t< tile_size_x_a, tile_size_y_a, @@ -216,6 +214,7 @@ class gemm_t< wg_size_y, arch_tag>; + public: using matAcc_tile_desc_t = subgroup::tile_desc_t< tile_size_x_c, tile_size_y_c, diff --git a/include/group/global_reduction.hpp b/include/group/global_reduction.hpp index 84d91875b..0fe83d6a8 100644 --- a/include/group/global_reduction.hpp +++ b/include/group/global_reduction.hpp @@ -137,7 +137,6 @@ class global_reduce_t< xetla_store_global< dtype_cnt, 1, - data_size::default_size, cache_hint::uncached, cache_hint::write_back>((dtype_cnt*)address, 0, zeros); } diff --git a/include/kernel/gemm/impl/default_xe.hpp b/include/kernel/gemm/impl/default_xe.hpp index 3a2c13ce4..cb6c5270b 100644 --- a/include/kernel/gemm/impl/default_xe.hpp +++ b/include/kernel/gemm/impl/default_xe.hpp @@ -253,44 +253,38 @@ class gemm_universal_t< bool implementable = true; if (gemm_t::msg_type_a != msg_type::unaligned_2d) { if (gemm_t::msg_type_a == msg_type::block_2d) { - implementable &= - kernel::block_2d::check_tensor( - (uint64_t)(args.matA_base.base), - 
gemm_t::is_col_major_a ? args.matrix_m : args.matrix_k, - gemm_t::is_col_major_a ? args.matrix_k : args.matrix_m, - args.matA_ld); + implementable &= kernel::block_2d::check_tensor( + (uint64_t)(args.matA_base.base), + gemm_t::is_col_major_a ? args.matrix_m : args.matrix_k, + gemm_t::is_col_major_a ? args.matrix_k : args.matrix_m, + args.matA_ld); } else { - implementable &= - kernel::general_1d::check_alignment( - args.matA_base.base, args.matA_ld); + implementable &= kernel::general_1d::check_alignment( + args.matA_base.base, args.matA_ld); } } if (gemm_t::msg_type_b != msg_type::unaligned_2d) { if (gemm_t::msg_type_b == msg_type::block_2d) { - implementable &= - kernel::block_2d::check_tensor( - (uint64_t)(args.matB_base.base), - gemm_t::is_col_major_b ? args.matrix_k : args.matrix_n, - gemm_t::is_col_major_b ? args.matrix_n : args.matrix_k, - args.matB_ld); + implementable &= kernel::block_2d::check_tensor( + (uint64_t)(args.matB_base.base), + gemm_t::is_col_major_b ? args.matrix_k : args.matrix_n, + gemm_t::is_col_major_b ? args.matrix_n : args.matrix_k, + args.matB_ld); } else { - implementable &= - kernel::general_1d::check_alignment( - args.matB_base.base, args.matB_ld); + implementable &= kernel::general_1d::check_alignment( + args.matB_base.base, args.matB_ld); } } if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { if (epilogue_t::msg_type_c == msg_type::block_2d) { - implementable &= - kernel::block_2d::check_tensor( - (uint64_t)(args.matC_base.base), - args.matrix_n, - args.matrix_m, - args.matC_ld); + implementable &= kernel::block_2d::check_tensor( + (uint64_t)(args.matC_base.base), + args.matrix_n, + args.matrix_m, + args.matC_ld); } else { - implementable &= - kernel::general_1d::check_alignment( - args.matC_base.base, args.matC_ld); + implementable &= kernel::general_1d::check_alignment( + args.matC_base.base, args.matC_ld); } } @@ -310,7 +304,7 @@ class gemm_universal_t< sycl::nd_item<3>& item, const arguments_t& args, uint32_t slm_base = 0, - uint32_t nbarrier_base = 0) { + uint32_t nbarrier_base = 0) const { // set up workgroup level coordinates and boundaries group_swizzle_t group_swizzle; int start_m = group_swizzle.template get_tile_idx<1>(item) * wg_tile_m; diff --git a/include/kernel/gemm/impl/kslicing_xe.hpp b/include/kernel/gemm/impl/kslicing_xe.hpp index 59b8f621a..7b74226e5 100644 --- a/include/kernel/gemm/impl/kslicing_xe.hpp +++ b/include/kernel/gemm/impl/kslicing_xe.hpp @@ -365,44 +365,38 @@ class gemm_universal_t< bool implementable = true; if (gemm_t::msg_type_a != msg_type::unaligned_2d) { if (gemm_t::msg_type_a == msg_type::block_2d) { - implementable &= - kernel::block_2d::check_tensor( - (uint64_t)(args.matA_base.base), - gemm_t::is_col_major_a ? args.matrix_m : args.matrix_k, - gemm_t::is_col_major_a ? args.matrix_k : args.matrix_m, - args.matA_ld); + implementable &= kernel::block_2d::check_tensor( + (uint64_t)(args.matA_base.base), + gemm_t::is_col_major_a ? args.matrix_m : args.matrix_k, + gemm_t::is_col_major_a ? args.matrix_k : args.matrix_m, + args.matA_ld); } else { - implementable &= - kernel::general_1d::check_alignment( - args.matA_base.base, args.matA_ld); + implementable &= kernel::general_1d::check_alignment( + args.matA_base.base, args.matA_ld); } } if (gemm_t::msg_type_b != msg_type::unaligned_2d) { if (gemm_t::msg_type_b == msg_type::block_2d) { - implementable &= - kernel::block_2d::check_tensor( - (uint64_t)(args.matB_base.base), - gemm_t::is_col_major_b ? args.matrix_k : args.matrix_n, - gemm_t::is_col_major_b ? 
args.matrix_n : args.matrix_k, - args.matB_ld); + implementable &= kernel::block_2d::check_tensor( + (uint64_t)(args.matB_base.base), + gemm_t::is_col_major_b ? args.matrix_k : args.matrix_n, + gemm_t::is_col_major_b ? args.matrix_n : args.matrix_k, + args.matB_ld); } else { - implementable &= - kernel::general_1d::check_alignment( - args.matB_base.base, args.matB_ld); + implementable &= kernel::general_1d::check_alignment( + args.matB_base.base, args.matB_ld); } } if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { if (epilogue_t::msg_type_c == msg_type::block_2d) { - implementable &= - kernel::block_2d::check_tensor( - (uint64_t)(args.matC_base.base), - args.matrix_n, - args.matrix_m, - args.matC_ld); + implementable &= kernel::block_2d::check_tensor( + (uint64_t)(args.matC_base.base), + args.matrix_n, + args.matrix_m, + args.matC_ld); } else { - implementable &= - kernel::general_1d::check_alignment( - args.matC_base.base, args.matC_ld); + implementable &= kernel::general_1d::check_alignment( + args.matC_base.base, args.matC_ld); } } @@ -423,7 +417,7 @@ class gemm_universal_t< sycl::nd_item<3>& item, const arguments_t& args, uint32_t slm_base = 0, - uint32_t nbarrier_base = 0) { + uint32_t nbarrier_base = 0) const { // set up workgroup level coordinates and boundaries work_group_t g(item.get_local_linear_id() % work_group_size); uint32_t wg_id = item.get_local_linear_id() / work_group_size; @@ -563,4 +557,6 @@ class gemm_universal_t< } }; +/// @} xetla_gemm_universal + } // namespace gpu::xetla::kernel diff --git a/include/kernel/gemm/impl/stream_k_xe.hpp b/include/kernel/gemm/impl/stream_k_xe.hpp index 0ad5483fd..e281e53ae 100644 --- a/include/kernel/gemm/impl/stream_k_xe.hpp +++ b/include/kernel/gemm/impl/stream_k_xe.hpp @@ -31,16 +31,17 @@ namespace gpu::xetla::kernel { /// /// @tparam gemm_t_ Is the gemm functor to compose a GEMM_UNIVERSAL. /// @tparam epilogue_t_ Is the epilogue functor to compose a GEMM_UNIVERSAL. -template +template class gemm_universal_t< - dispatch_policy_stream_k, + dispatch_policy_stream_k, gemm_t_, epilogue_t_> { + static constexpr gpu_arch arch_tag = arch_tag_; using gemm_t = gemm_t_; using epilogue_t = epilogue_t_; using gemm_args_t = typename gemm_t::arguments_t; using epilogue_args_t = typename epilogue_t::arguments_t; - using dispatch_stream_k = dispatch_policy_stream_k; + using dispatch_stream_k = dispatch_policy_stream_k; // Scratchspace to accumulate partials using mem_desc_d_t = @@ -63,8 +64,10 @@ class gemm_universal_t< using work_group_t = typename gemm_t::work_group_t; - static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; static_assert(arch_tag == gemm_t::arch_tag, "arch_tag should be the same"); + static_assert( + arch_tag == epilogue_t::arch_tag, + "arch_tag should be the same"); using mem_desc_a_t = typename gemm_t::mem_desc_a_t; using mem_desc_b_t = typename gemm_t::mem_desc_b_t; @@ -82,7 +85,8 @@ class gemm_universal_t< tile_shape, epilogue_t, mem_desc_d_t, - mem_desc_atomic_sync_t>; + mem_desc_atomic_sync_t, + arch_tag>; public: /// @brief GEMM arguments. 
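
The stream_k_xe.hpp changes above stop hardcoding gpu_arch::XeHpc: dispatch_policy_stream_k and epilogue_stream_k_t now carry arch_tag as a template parameter, and static_asserts require the gemm and epilogue functors to agree with it. A standalone sketch of that consistency check, with hypothetical type names standing in for the real functors:

    #include <cstdint>

    enum class gpu_arch : uint8_t { XeLpg, XeHpg, XeHpc };

    // Hypothetical stand-ins for the composed functors; each carries its arch tag.
    template <gpu_arch arch_tag_>
    struct gemm_functor {
      static constexpr gpu_arch arch_tag = arch_tag_;
    };

    template <gpu_arch arch_tag_>
    struct epilogue_functor {
      static constexpr gpu_arch arch_tag = arch_tag_;
    };

    template <gpu_arch arch_tag_>
    struct dispatch_policy_stream_k {
      static constexpr gpu_arch arch_tag = arch_tag_;
    };

    // The kernel composes all three and refuses mismatched instantiations at
    // compile time, instead of silently assuming XeHpc as before.
    template <typename dispatch_t, typename gemm_t, typename epilogue_t>
    struct stream_k_kernel {
      static constexpr gpu_arch arch_tag = dispatch_t::arch_tag;
      static_assert(arch_tag == gemm_t::arch_tag, "arch_tag should be the same");
      static_assert(arch_tag == epilogue_t::arch_tag, "arch_tag should be the same");
    };

    // OK: everything agrees on XeHpc.
    using ok_kernel = stream_k_kernel<
        dispatch_policy_stream_k<gpu_arch::XeHpc>,
        gemm_functor<gpu_arch::XeHpc>,
        epilogue_functor<gpu_arch::XeHpc>>;

    int main() {
      ok_kernel k; // mixing XeHpg and XeHpc here would fail to compile
      (void)k;
    }
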
@@ -435,7 +439,7 @@ class gemm_universal_t< sycl::nd_item<3>& item, const arguments_t& args, uint32_t slm_base = 0, - uint32_t nbarrier_base = 0) { + uint32_t nbarrier_base = 0) const { const dispatch_stream_k& workgroup_mapping = args.stream_k_args; int group_idx = item.get_group(2); @@ -625,6 +629,7 @@ class gemm_universal_t< } } }; + /// @} xetla_gemm_universal } // namespace gpu::xetla::kernel diff --git a/include/subgroup/tile/common.hpp b/include/subgroup/tile/common.hpp index 2fe09261c..9385c700f 100644 --- a/include/subgroup/tile/common.hpp +++ b/include/subgroup/tile/common.hpp @@ -106,6 +106,20 @@ __XETLA_API typename std::enable_if_t process_1d_tail( [[maybe_unused]] payload_t& payload, [[maybe_unused]] uint32_t offset) {} +template < + uint32_t remained_len, + uint32_t base_len, + process_flag flag, + cache_hint L1, + cache_hint L2, + typename payload_t, + typename tile_t> +__XETLA_API typename std::enable_if_t process_1d_tail( + [[maybe_unused]] tile_t& tile, + [[maybe_unused]] payload_t& payload, + [[maybe_unused]] uint32_t offset, + [[maybe_unused]] uint32_t address_offset) {} + template < uint32_t remained_len, uint32_t base_len, @@ -124,14 +138,11 @@ process_1d_tail(tile_t& tile, payload_t& payload, uint32_t offset) { auto reg_sub = tile.reg.xetla_select(offset); if constexpr (flag == process_flag::load) { - reg_sub.xetla_format() = xetla_load_global< - mem_dtype, - base_len, - data_size::default_size, - L1, - L2>(payload.base_ptr, payload.base_offset + address_offset); + reg_sub.xetla_format() = + xetla_load_global( + payload.base_ptr, payload.base_offset + address_offset); } else { - xetla_store_global( + xetla_store_global( payload.base_ptr, payload.base_offset + address_offset, reg_sub.xetla_format()); @@ -154,19 +165,23 @@ template < typename tile_t> __XETLA_API typename std::enable_if_t< base_len != 0 && payload_t::memory_space == mem_space::local> -process_1d_tail(tile_t& tile, payload_t& payload, uint32_t offset) { +process_1d_tail( + tile_t& tile, + payload_t& payload, + uint32_t offset, + uint32_t address_offset) { using mem_dtype = typename payload_t::mem_dtype; if constexpr (remained_len >= base_len) { auto reg_sub = tile.reg.xetla_select(offset); - uint32_t address_offset = offset * sizeof(typename tile_t::dtype); if constexpr (flag == process_flag::load) { reg_sub.xetla_format() = xetla_load_local( - payload.address + address_offset); + payload.base_address + payload.address + address_offset); } else { xetla_store_local( - payload.address + address_offset, reg_sub.xetla_format()); + payload.base_address + payload.address + address_offset, + reg_sub.xetla_format()); } process_1d_tail< remained_len - base_len, @@ -175,7 +190,13 @@ process_1d_tail(tile_t& tile, payload_t& payload, uint32_t offset) { L1, L2, payload_t, - tile_t>(tile, payload, offset + base_len * payload_t::scale_factor); + tile_t>( + tile, + payload, + offset + base_len * payload_t::scale_factor, + address_offset + + base_len * payload_t::scale_factor * + sizeof(typename tile_t::dtype)); } else { process_1d_tail< remained_len, @@ -184,12 +205,13 @@ process_1d_tail(tile_t& tile, payload_t& payload, uint32_t offset) { L1, L2, payload_t, - tile_t>(tile, payload, offset); + tile_t>(tile, payload, offset, address_offset); } } // This will end up with base_len equal to 8 because we had made tile_size_x // divisible by 8/16/32, depends on dtype +// this is for prefetch only and use different func arg compare with load/store template < uint32_t remained_len, uint32_t base_len, @@ -275,6 +297,17 @@ 
struct is_same_layout { (T_src::tile_size_x == T_dst::tile_size_x); }; +template +struct is_1d_src { + static constexpr bool value = (T_src::tile_elems == T_dst::tile_elems) && + (T_src::block_size_y == 1) && (T_src::tile_size_y == 1); +}; + +template +struct is_same_elements { + static constexpr bool value = (T_src::tile_elems == T_dst::tile_elems); +}; + template struct is_floating_to_integer { static constexpr bool value = @@ -298,12 +331,8 @@ struct msg_type_query { : msg_type::scatter); }; -template < - typename tile_desc_, - mem_space memory_space, - mem_layout memory_layout = mem_layout::row_major> -constexpr msg_type msg_type_v = - msg_type_query::value; +template +constexpr msg_type msg_type_v = msg_type_query::value; template < typename dtype, diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index 70ddd6d74..216a57d96 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -414,12 +414,9 @@ tile_load(tile_t& tile, payload_t& payload) { auto reg_sub = tile.reg.xetla_select(offset_x); uint32_t address_offset = offset_x * sizeof(dtype); - reg_sub.xetla_format() = xetla_load_global< - load_dtype, - max_load_vec_len, - data_size::default_size, - L1, - L2>(payload.base_ptr, payload.base_offset + address_offset); + reg_sub.xetla_format() = + xetla_load_global( + payload.base_ptr, payload.base_offset + address_offset); } } @@ -441,7 +438,7 @@ tile_load(tile_t& tile, payload_t& payload) { /// @tparam payload_t Is the mem_payload_t struct describing the memory /// information. Payload indicates the source of load operation. /// @tparam L1 Is the cache hint for L1 cache. -/// @tparam L3 Is the cache hint for L3 cache. +/// @tparam L2 Is the cache hint for L2 cache. /// @param tile Is the tile object with type tile_t, holds the return data of /// the loads. /// @param payload Is the payload object with type payload_t. Contains all the @@ -449,7 +446,7 @@ tile_load(tile_t& tile, payload_t& payload) { /// @return No return, update in place. template < cache_hint L1 = cache_hint::cached, - cache_hint L3 = cache_hint::cached, + cache_hint L2 = cache_hint::cached, typename tile_t, typename payload_t> __XETLA_API typename std::enable_if_t< @@ -476,7 +473,7 @@ tile_load(tile_t& tile, payload_t& payload) { for (uint32_t sub_block_y = 0; sub_block_y < tile_desc::block_size_y; sub_block_y += num_channel) { xetla_vector reg_tmp = 0; - uint32_t address_offset = payload_t::trans + uint32_t address_offset = payload_t::mem_transpose ? offset_x * payload.pitch_in_bytes + (offset_y + 0) * sizeof(dtype) : offset_x * sizeof(dtype) + (offset_y + 0) * payload.pitch_in_bytes; @@ -499,7 +496,7 @@ tile_load(tile_t& tile, payload_t& payload) { payload_t::simd_exec_size, data_size::default_size, L1, - L3, + L2, payload_t::num_channel>( payload.base_ptr, payload.channel_offset + payload.base_offset + address_offset, @@ -550,7 +547,7 @@ tile_load(tile_t& tile, payload_t& payload) { /// @tparam payload_t Is the mem_payload_t struct describing the memory /// information. Payload indicates the source of load operation. /// @tparam L1 Is the cache hint for L1 cache. -/// @tparam L3 Is the cache hint for L3 cache. +/// @tparam L2 Is the cache hint for L2 cache. /// @param tile Is the tile object with type tile_t, holds the return data of /// the loads. /// @param payload Is the payload object with type payload_t. Contains all the @@ -558,7 +555,7 @@ tile_load(tile_t& tile, payload_t& payload) { /// @return No return, update in place. 
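
In common.hpp above, process_1d_tail gains an explicit address_offset argument (and the local-memory path adds payload.base_address), while the 1-D load paths first cover full max_load_vec_len chunks and then hand the leftover to the tail helper with half that length as the starting chunk size. The sketch below shows that "big chunks, then progressively smaller power-of-two tails" shape on plain host arrays; it is an analogy with names of my own, not the exact recursion in process_1d_tail.

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Copy `remained` elements starting at (offset, address_offset) by trying the
    // largest power-of-two chunk first and halving until nothing remains.
    template <size_t remained, size_t base, typename T>
    void copy_tail(T* dst, const uint8_t* src, size_t offset, size_t address_offset) {
      if constexpr (base == 0 || remained == 0) {
        (void)dst; (void)src; (void)offset; (void)address_offset; // nothing left
      } else if constexpr (remained >= base) {
        std::memcpy(dst + offset, src + address_offset, base * sizeof(T));
        copy_tail<remained - base, base>(dst, src,
                                         offset + base,
                                         address_offset + base * sizeof(T));
      } else {
        // chunk too big: halve it and retry
        copy_tail<remained, base / 2>(dst, src, offset, address_offset);
      }
    }

    // Main loop plus tail: full `max_vec` chunks first, then the halving tail.
    template <size_t total, size_t max_vec, typename T>
    void copy_1d(T* dst, const uint8_t* src_bytes) {
      constexpr size_t steps = total / max_vec;
      for (size_t j = 0; j < steps; ++j)
        std::memcpy(dst + j * max_vec, src_bytes + j * max_vec * sizeof(T),
                    max_vec * sizeof(T));
      copy_tail<total % max_vec, (max_vec >> 1)>(dst, src_bytes,
                                                 steps * max_vec,
                                                 steps * max_vec * sizeof(T));
    }

    int main() {
      float src[45], dst[45] = {};
      for (int i = 0; i < 45; ++i) src[i] = float(i);
      // 45 elements = 2 full chunks of 16, then a tail of 13 = 8 + 4 + 1.
      copy_1d<45, 16>(dst, reinterpret_cast<const uint8_t*>(src));
      std::printf("dst[44] = %g\n", dst[44]); // prints 44
    }
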
template < cache_hint L1 = cache_hint::cached, - cache_hint L3 = cache_hint::cached, + cache_hint L2 = cache_hint::cached, typename tile_t, typename payload_t> __XETLA_API typename std::enable_if_t< @@ -568,9 +565,7 @@ __XETLA_API typename std::enable_if_t< tile_load(tile_t& tile, payload_t& payload) { using dtype = typename payload_t::dtype; using tile_desc = typename payload_t::tile_desc; - using load_dtype = typename payload_t::mem_dtype; - constexpr uint32_t load_elems = payload_t::simd_exec_size; - constexpr uint32_t pack_factor = payload_t::pack_factor; + constexpr uint32_t load_elems = tile_desc::block_size_x; #pragma unroll for (uint32_t i = 0; i < tile_desc::num_block_y; i++) { @@ -583,23 +578,16 @@ tile_load(tile_t& tile, payload_t& payload) { #pragma unroll for (uint32_t sub_block_y = 0; sub_block_y < tile_desc::block_size_y; sub_block_y += 1) { - xetla_vector reg_tmp = 0; - uint32_t address_offset = payload_t::trans + uint32_t address_offset = payload_t::mem_transpose ? offset_x * payload.pitch_in_bytes + (offset_y + sub_block_y) * sizeof(dtype) : offset_x * sizeof(dtype) + (offset_y + sub_block_y) * payload.pitch_in_bytes; - reg_tmp = xetla_load_global< - load_dtype, - payload_t::simd_exec_size, - data_size::default_size, - L1, - L3>(payload.base_ptr, payload.base_offset + address_offset); - reg_sub - .xetla_select( - sub_block_y * tile_desc::block_size_x) - .xetla_format() = reg_tmp; + reg_sub.xetla_select( + sub_block_y * tile_desc::block_size_x) = + xetla_load_global( + (dtype*)payload.base_ptr, payload.base_offset + address_offset); } } } @@ -612,7 +600,6 @@ tile_load(tile_t& tile, payload_t& payload) { SW_BARRIER(); vnni_convert(tile); } - } /// @brief This function loads data from unaligned-2D memory surface. @@ -623,7 +610,7 @@ tile_load(tile_t& tile, payload_t& payload) { /// @tparam payload_t Is the mem_payload_t struct describing the memory /// information. Payload indicates the source of load operation. /// @tparam L1 Is the cache hint for L1 cache. -/// @tparam L3 Is the cache hint for L3 cache. +/// @tparam L2 Is the cache hint for L2 cache. /// @param tile Is the tile object with type tile_t, holds the return data of /// the loads. /// @param payload Is the payload object with type payload_t. Contains all the @@ -631,7 +618,7 @@ tile_load(tile_t& tile, payload_t& payload) { /// @return No return, update in place. 
template < cache_hint L1 = cache_hint::cached, - cache_hint L3 = cache_hint::cached, + cache_hint L2 = cache_hint::cached, typename tile_t, typename payload_t, typename oob_check_tag = global_atomic_oob_check_on_tag> @@ -682,7 +669,7 @@ tile_load( 1, data_size::default_size, L1, - L3, + L2, load_elems>( payload.base_ptr, payload.channel_offset + payload.base_offset + address_offset, @@ -731,7 +718,7 @@ tile_load( 1, data_size::default_size, L1, - L3, + L2, load_elems>( payload.base_ptr, payload.channel_offset + payload.base_offset + address_offset, @@ -828,6 +815,10 @@ tile_load(tile_t& tile, payload_t& payload) { } } } + if constexpr (payload_t::reg_transpose) { + SW_BARRIER(); + tile_transpose(tile); + } if constexpr (mem_transform) { SW_BARRIER(); vnni_convert(tile); @@ -861,33 +852,42 @@ tile_load(tile_t& tile, payload_t& payload) { using load_dtype = typename payload_t::mem_dtype; constexpr uint32_t scale_factor = payload_t::scale_factor; - static constexpr uint32_t load_len = tile_desc::tile_size_x / scale_factor; - static constexpr gpu_arch arch_tag = payload_t::arch_tag; + constexpr uint32_t load_len = tile_desc::tile_size_x / scale_factor; + constexpr gpu_arch arch_tag = payload_t::arch_tag; using load_store_attr = load_store_attr_t; - static constexpr uint32_t max_load_vec_len = - load_store_attr::max_load_vec_len; + constexpr uint32_t max_load_vec_len = load_store_attr::max_load_vec_len; - static constexpr uint32_t load_iter_steps = load_len / max_load_vec_len; - - if constexpr (load_len >= max_load_vec_len) { + constexpr uint32_t load_iter_steps = load_len / max_load_vec_len; #pragma unroll - for (uint32_t j = 0; j < load_iter_steps; j++) { - uint32_t offset_x = j * max_load_vec_len * scale_factor; - auto reg_sub = - tile.reg.xetla_select(offset_x); - uint32_t address_offset = offset_x * sizeof(dtype); - reg_sub.xetla_format() = xetla_load_local< - load_dtype, - max_load_vec_len, - data_size::default_size>(payload.address + address_offset); + for (uint32_t i = 0; i < tile_desc::tile_size_y; i++) { + uint32_t offset_y = i * tile_desc::tile_size_x; + uint32_t address_offset_y = i * payload.pitch_in_bytes; + if constexpr (load_len >= max_load_vec_len) { +#pragma unroll + for (uint32_t j = 0; j < load_iter_steps; j++) { + uint32_t offset_x = j * max_load_vec_len * scale_factor; + auto reg_sub = + tile.reg.xetla_select( + offset_x + offset_y); + uint32_t address_offset = address_offset_y + offset_x * sizeof(dtype); + reg_sub.xetla_format() = xetla_load_local< + load_dtype, + max_load_vec_len, + data_size::default_size>( + payload.base_address + payload.address + address_offset); + } } + uint32_t tail_offset = + offset_y + load_iter_steps * max_load_vec_len * scale_factor; + uint32_t tail_address_offset = address_offset_y + + load_iter_steps * max_load_vec_len * scale_factor * sizeof(dtype); + detail::process_1d_tail< + load_len % max_load_vec_len, + (max_load_vec_len >> 1), + detail::process_flag::load, + L1, + L2>(tile, payload, tail_offset, tail_address_offset); } - detail::process_1d_tail< - load_len % max_load_vec_len, - (max_load_vec_len >> 1), - detail::process_flag::load, - L1, - L2>(tile, payload, load_iter_steps * max_load_vec_len * scale_factor); } } // namespace gpu::xetla::subgroup diff --git a/include/subgroup/tile/impl/mma_xe.hpp b/include/subgroup/tile/impl/mma_xe.hpp index 7292a013f..371040397 100644 --- a/include/subgroup/tile/impl/mma_xe.hpp +++ b/include/subgroup/tile/impl/mma_xe.hpp @@ -67,7 +67,8 @@ struct tile_mma_t< static constexpr uint32_t tile_size_n = 
matDst_t::tile_size_x; static constexpr uint32_t tile_elems = tile_size_m * tile_size_n; static constexpr uint32_t block_size_n = matDst_t::block_size_x; - static constexpr uint32_t block_size_k = a_block_size_x; + static constexpr uint32_t block_size_k = + a_block_size_x; // cannot use b_block_size_y static constexpr uint32_t block_size_m = matDst_t::block_size_y; static constexpr uint32_t block_elems = block_size_m * block_size_n; @@ -87,8 +88,8 @@ struct tile_mma_t< block_size_n == b_block_size_x, "matAcc block n should match with matB block n"); static_assert( - a_block_size_x == b_block_size_y, - "matA block w should match with matB block h"); + b_block_size_y % a_block_size_x == 0, + "matA block k should match with matB block k"); static_assert( (tile_size_k % block_size_k) == 0, "matAcc tile_size_k should be a multiple of block_size_k"); @@ -101,6 +102,8 @@ struct tile_mma_t< static constexpr int32_t num_block_n = matDst_t::num_block_x; static constexpr int32_t num_block_m = matDst_t::num_block_y; static constexpr int32_t num_block_k = tile_size_k / block_size_k; + static constexpr int32_t num_block_mma_b = b_block_size_y / block_size_k; + static constexpr uint32_t b_block_mma_elems = b_block_elems / num_block_mma_b; using mma_attr = mma_attr_t; static constexpr int32_t mma_m = mma_attr::mma_m_in_elem; @@ -136,22 +139,20 @@ struct tile_mma_t< (i * num_block_k) * a_block_elems); auto a_sub_blk = a_block.xetla_select(mma_i * a_mma_elems); - auto b_sub_blk = + auto b_blk = b.reg.xetla_select(j * b_block_elems); + auto b_sub_blk = b_blk.xetla_select(0); dst_sub_blk = xetla_mma< gpu::xetla::detail::mma_argument_type(), gpu::xetla::detail::mma_argument_type(), mma_k, mma_m, dtype_src, - uint32_t, - uint32_t, + dtype_b, + dtype_a, c_mma_elems, - b_block_elems / (sizeof(uint32_t) / sizeof(dtype_b)), - a_mma_elems / (sizeof(uint32_t) / sizeof(dtype_a))>( - src_sub_blk, - b_sub_blk.xetla_format(), - a_sub_blk.xetla_format()); + b_block_mma_elems, + a_mma_elems>(src_sub_blk, b_sub_blk, a_sub_blk); } #pragma unroll @@ -160,22 +161,23 @@ struct tile_mma_t< (i * num_block_k + k) * a_block_elems); auto a_sub_blk = a_block.xetla_select(mma_i * a_mma_elems); - auto b_sub_blk = b.reg.xetla_select( - (j + k * num_block_n) * b_block_elems); + int inter_k_b = k / num_block_mma_b; + int inner_k_b = k % num_block_mma_b; + auto b_blk = b.reg.xetla_select( + (j + inter_k_b * num_block_n) * b_block_elems); + auto b_sub_blk = b_blk.xetla_select( + inner_k_b * b_block_mma_elems); dst_sub_blk = xetla_mma< gpu::xetla::detail::mma_argument_type(), gpu::xetla::detail::mma_argument_type(), mma_k, mma_m, dtype_src, - uint32_t, - uint32_t, + dtype_b, + dtype_a, c_mma_elems, - b_block_elems / (sizeof(uint32_t) / sizeof(dtype_b)), - a_mma_elems / (sizeof(uint32_t) / sizeof(dtype_a))>( - dst_sub_blk, - b_sub_blk.xetla_format(), - a_sub_blk.xetla_format()); + b_block_mma_elems, + a_mma_elems>(dst_sub_blk, b_sub_blk, a_sub_blk); } } } @@ -203,22 +205,21 @@ struct tile_mma_t< a.reg.xetla_select(a_tail_elems_start); auto a_sub_blk = a_block.xetla_select(mma_i * a_mma_elems); - auto b_sub_blk = + auto b_blk = b.reg.xetla_select(j * b_block_elems); + auto b_sub_blk = b_blk.xetla_select(0); + dst_sub_blk = xetla_mma< gpu::xetla::detail::mma_argument_type(), gpu::xetla::detail::mma_argument_type(), mma_k, mma_m, dtype_src, - uint32_t, - uint32_t, + dtype_b, + dtype_a, c_mma_elems, - b_block_elems / (sizeof(uint32_t) / sizeof(dtype_b)), - a_mma_elems / (sizeof(uint32_t) / sizeof(dtype_a))>( - src_sub_blk, - 
b_sub_blk.xetla_format(), - a_sub_blk.xetla_format()); + b_block_mma_elems, + a_mma_elems>(src_sub_blk, b_sub_blk, a_sub_blk); } #pragma unroll for (uint32_t k = 1; k < num_block_k; k++) { @@ -226,22 +227,24 @@ struct tile_mma_t< a_tail_elems_start + k * a_tail_block_elems); auto a_sub_blk = a_block.xetla_select(mma_i * a_mma_elems); - auto b_sub_blk = b.reg.xetla_select( - (j + k * num_block_n) * b_block_elems); + int inter_k_b = k / num_block_mma_b; + int inner_k_b = k % num_block_mma_b; + auto b_blk = b.reg.xetla_select( + (j + inter_k_b * num_block_n) * b_block_elems); + auto b_sub_blk = b_blk.xetla_select( + inner_k_b * b_block_mma_elems); + dst_sub_blk = xetla_mma< gpu::xetla::detail::mma_argument_type(), gpu::xetla::detail::mma_argument_type(), mma_k, mma_m, dtype_src, - uint32_t, - uint32_t, + dtype_b, + dtype_a, c_mma_elems, - b_block_elems / (sizeof(uint32_t) / sizeof(dtype_b)), - a_mma_elems / (sizeof(uint32_t) / sizeof(dtype_a))>( - dst_sub_blk, - b_sub_blk.xetla_format(), - a_sub_blk.xetla_format()); + b_block_mma_elems, + a_mma_elems>(dst_sub_blk, b_sub_blk, a_sub_blk); } } } diff --git a/include/subgroup/tile/impl/op_function.hpp b/include/subgroup/tile/impl/op_function.hpp index c8c053090..85e83b45b 100644 --- a/include/subgroup/tile/impl/op_function.hpp +++ b/include/subgroup/tile/impl/op_function.hpp @@ -36,7 +36,7 @@ template __XETLA_API typename std::enable_if_t< (T_src::register_layout != reg_layout::linear) && (T_dst::register_layout != reg_layout::linear) && - is_same_layout::value && + (is_same_elements::value) && (!is_floating_to_integer::value)> elemwise_cvt(T_dst& dst, T_src& src) { constexpr uint32_t block_size_x = T_dst::block_size_x; @@ -44,7 +44,7 @@ elemwise_cvt(T_dst& dst, T_src& src) { using dtype_src = typename T_src::dtype; using dtype_dst = typename T_dst::dtype; if constexpr (std::is_same::value) { - dst.reg = src.reg; + dst.reg = xetla_cvt(src.reg); } else { #pragma unroll for (uint32_t i = 0; i < tile_elems; i += block_size_x) { @@ -55,6 +55,25 @@ elemwise_cvt(T_dst& dst, T_src& src) { } } +template +__XETLA_API typename std::enable_if_t< + std::is_same::value> +elemwise_cvt(T_dst& dst, T_src& src) { + using dtype_src = typename T_src::dtype; + using dtype_dst = typename T_dst::dtype; + constexpr uint32_t tile_elems = T_src::tile_elems; + constexpr uint32_t unroll_src_size = 64; + constexpr uint32_t unroll_dst_size = + unroll_src_size / get_packed_num::value; +#pragma unroll + for (uint32_t i = 0; i < tile_elems; i += unroll_src_size) { + dst.reg.xetla_select( + i / get_packed_num::value) = + xetla_cvt( + src.reg.xetla_select(i)); + } +} + /// @brief Is the element wise data conversion from floating point to integral, /// the src and dst tile should have the same layout. /// @tparam T_dst Is the destination tile data type. 
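Reviewer note: the new packed `elemwise_cvt` overload above walks the source in chunks of 64 elements and writes `64 / get_packed_num<dtype_src, dtype_dst>::value` destination elements per chunk, so the destination index is the source index divided by the packing factor. A host-side analogue of that index mapping for a two-per-byte packing is sketched below; `cvt_to_int4x2_host` and the nibble order are illustrative assumptions, not the library's definition.

#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<uint8_t> cvt_to_int4x2_host(const std::vector<int8_t>& src) {
  constexpr std::size_t packed_num = 2; // two 4-bit lanes per byte
  std::vector<uint8_t> dst(src.size() / packed_num);
  for (std::size_t i = 0; i < src.size(); i += packed_num) {
    uint8_t lo = static_cast<uint8_t>(src[i]) & 0x0F;
    uint8_t hi = static_cast<uint8_t>(src[i + 1]) & 0x0F;
    // destination index advances by i / packed_num, as in the overload above
    dst[i / packed_num] = static_cast<uint8_t>(lo | (hi << 4));
  }
  return dst;
}
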
@@ -148,7 +167,7 @@ __XETLA_API auto reg_dst_2d = reg_dst.xetla_format, move_rows, move_cols>(); #pragma unroll - for (uint32_t vnni_i = 0; vnni_i < vnni_stride; vnni_i++) { + for (int vnni_i = 0; vnni_i < vnni_stride; vnni_i++) { reg_dst_2d.xetla_select( 0, vnni_i) = reg_2d.xetla_select( @@ -166,7 +185,7 @@ __XETLA_API constexpr int32_t remain_move_cols = block_size_x * vnni_stride; constexpr int32_t remain_move_rows = remain_size_y / vnni_stride; #pragma unroll - for (uint32_t j = 0; j < num_block_x; j++) { + for (int j = 0; j < num_block_x; j++) { auto reg = (mat_Acc.reg) .xetla_select( remain_elems_start + j * remain_block_elems); @@ -178,7 +197,7 @@ __XETLA_API native_type_t, remain_move_rows, remain_move_cols>(); - for (uint32_t vnni_i = 0; vnni_i < vnni_stride; vnni_i++) { + for (int vnni_i = 0; vnni_i < vnni_stride; vnni_i++) { reg_dst_2d.xetla_select( 0, vnni_i) = reg_2d.xetla_select( @@ -195,7 +214,9 @@ __XETLA_API /// @param mat_Acc Is the reference of the tile object. /// @return No return, update the data in-place. template -__XETLA_API typename std::enable_if_t +__XETLA_API typename std::enable_if_t< + T::register_layout == reg_layout::tiled || + T::register_layout == reg_layout::vnni_tiled> vnni_reverse(T& mat_Acc) { constexpr uint32_t tile_size_y = T::tile_size_y; constexpr uint32_t tile_size_x = T::tile_size_x; @@ -228,7 +249,7 @@ vnni_reverse(T& mat_Acc) { reg_dst .xetla_format, block_size_y, block_size_x>(); #pragma unroll - for (uint32_t vnni_i = 0; vnni_i < vnni_stride; vnni_i++) { + for (int vnni_i = 0; vnni_i < vnni_stride; vnni_i++) { reg_dst_2d.xetla_select( vnni_i, 0) = reg_2d.xetla_select( @@ -246,7 +267,7 @@ vnni_reverse(T& mat_Acc) { constexpr int32_t remain_move_cols = block_size_x * vnni_stride; constexpr int32_t remain_move_rows = remain_size_y / vnni_stride; #pragma unroll - for (uint32_t j = 0; j < num_block_x; j++) { + for (int j = 0; j < num_block_x; j++) { auto reg = (mat_Acc.reg) .xetla_select( remain_elems_start + j * remain_block_elems); @@ -260,7 +281,7 @@ vnni_reverse(T& mat_Acc) { native_type_t, remain_size_y, block_size_x>(); - for (uint32_t vnni_i = 0; vnni_i < vnni_stride; vnni_i++) { + for (int vnni_i = 0; vnni_i < vnni_stride; vnni_i++) { reg_dst_2d.xetla_select( vnni_i, 0) = reg_2d.xetla_select( @@ -468,10 +489,15 @@ vnni_transform(T_dst& dst, T_src& src) { /// @return No return, update the data in-place. template __XETLA_API void tile_transpose(T& mat_Acc) { - constexpr uint32_t tile_size_y = T::tile_size_y; - constexpr uint32_t tile_size_x = T::tile_size_x; - constexpr uint32_t block_size_y = T::block_size_y; - constexpr uint32_t block_size_x = T::block_size_x; + constexpr uint32_t tile_size_y = + T ::reg_transpose ? T::tile_size_y : T::tile_size_x; + constexpr uint32_t tile_size_x = + T ::reg_transpose ? T::tile_size_x : T::tile_size_y; + constexpr uint32_t block_size_y = + T ::reg_transpose ? T::block_size_y : T::block_size_x; + constexpr uint32_t block_size_x = + T ::reg_transpose ? 
T::block_size_x : T::block_size_y; + constexpr uint32_t block_elems = block_size_y * block_size_x; constexpr int32_t num_block_x = tile_size_x / block_size_x; constexpr int32_t num_block_y = tile_size_y / block_size_y; diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index aff34e936..c895614e0 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -124,6 +124,14 @@ struct mem_payload_t< prepare_tdesc(base_tdesc); } + __XETLA_API void init(xetla_tdescriptor base_tdesc) { + int32_t offset = gpu::xetla::detail::xetla_get_tensor_offset_x(base_tdesc) / + int32_t(scale_factor); + gpu::xetla::detail::xetla_set_tensor_offset_x( + base_tdesc.xetla_format(), offset); + prepare_tdesc(base_tdesc); + } + __XETLA_API void init( dtype* p, uint32_t surface_width, @@ -170,6 +178,149 @@ struct mem_payload_t< } } + __XETLA_API void update_tdesc_base_address(int offset) { + auto payloads_2d = payloads.xetla_format(); +#pragma unroll + for (int i = 0; i < num_block; i++) { + xetla_update_tdesc_base_address(payloads_2d.row(i), offset); + } + } + + __XETLA_API void set_tdesc_width(uint32_t size) { + auto payloads_2d = payloads.xetla_format(); +#pragma unroll + for (int i = 0; i < num_block; i++) { + xetla_set_tdesc_width(payloads_2d.row(i), size); + } + } + + __XETLA_API void set_tdesc_pitch(uint32_t size) { + auto payloads_2d = payloads.xetla_format(); +#pragma unroll + for (int i = 0; i < num_block; i++) { + xetla_set_tdesc_pitch(payloads_2d.row(i), size); + } + } + + __XETLA_API void set_tdesc_height(uint32_t size) { + auto payloads_2d = payloads.xetla_format(); +#pragma unroll + for (int i = 0; i < num_block; i++) { + xetla_set_tdesc_height(payloads_2d.row(i), size); + } + } + + __XETLA_API void update_tdesc_base_address_masked( + int offset, + uint16_t mask = 1) { + auto payloads_2d = payloads.xetla_format(); +#pragma unroll + for (int i = 0; i < num_block; i++) { + xetla_update_tdesc_base_address(payloads_2d.row(i), offset); + } + +#pragma unroll + for (int i = 0; i < num_block; i++) { + xetla_tdesc_mask_op(payloads_2d.row(i), mask); + } + } + __XETLA_API void set_tdesc_base_address_masked( + uint64_t offset, + uint16_t mask = 1) { + auto payloads_2d = payloads.xetla_format(); +#pragma unroll + for (int i = 0; i < num_block; i++) { + gpu::xetla::detail::xetla_set_tensor_base_address( + payloads_2d.row(i), offset); + } + +#pragma unroll + for (int i = 0; i < num_block; i++) { + xetla_tdesc_mask_op(payloads_2d.row(i), mask); + } + } + + __XETLA_API void set_tdesc_base_address(uint64_t addr) { + auto payloads_2d = payloads.xetla_format(); +#pragma unroll + for (int i = 0; i < num_block; i++) { + gpu::xetla::detail::xetla_set_tensor_base_address( + payloads_2d.row(i), addr); + } + } + + __XETLA_API void set_offset(int32_t offset_x, int32_t offset_y) { + auto payloads_2d = payloads.xetla_format(); + constexpr uint32_t arr_len = 1; + int32_t base_offset_y = offset_y; +#pragma unroll + for (int i = 0; i < num_block_y; i++) { + auto tdesc_row_2d = + payloads_2d.xetla_select(i * num_block_x, 0); + + int32_t base_offset_x = offset_x; +#pragma unroll + for (int j = 0; j < num_block_x; j++) { + // To mimic dw transpose for word/byte data type with transpose and pack + constexpr uint8_t block_width = mem_transpose + ? (block_size_y / scale_factor) + : (block_size_x / scale_factor); + constexpr uint8_t block_height = + mem_transpose ? 
block_size_x : block_size_y; + constexpr uint32_t block_widthx_widthy_arrlen = (block_width - 1) | + ((block_height - 1) << 8) | ((arr_len - 1) << 16); + gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen( + tdesc_row_2d.row(j), block_widthx_widthy_arrlen); + + int32_t offset_width = mem_transpose + ? (base_offset_y / int32_t(scale_factor)) + : (base_offset_x / int32_t(scale_factor)); + int32_t offset_height = mem_transpose ? base_offset_x : base_offset_y; + + gpu::xetla::detail::xetla_set_tensor_offset_x( + tdesc_row_2d.row(j), offset_width); + gpu::xetla::detail::xetla_set_tensor_offset_y( + tdesc_row_2d.row(j), offset_height); + + base_offset_x += block_size_x * arr_len; + } + base_offset_y += block_size_y; + } + // process the tail + if constexpr (remained_size_y > 0) { + auto tdesc_row_2d = payloads_2d.xetla_select( + num_block_y * num_block_x, 0); + // this is exactly copy paste from above. so maybe worth createing some + // function + int32_t base_offset_x = offset_x; +#pragma unroll + for (int j = 0; j < num_block_x; j++) { + // To mimic dw transpose for word/byte data type with transpose and pack + constexpr uint8_t block_width = mem_transpose + ? (block_size_y / scale_factor) + : (block_size_x / scale_factor); + constexpr uint8_t block_height = + mem_transpose ? block_size_x : block_size_y; + constexpr uint32_t block_widthx_widthy_arrlen = (block_width - 1) | + ((block_height - 1) << 8) | ((arr_len - 1) << 16); + gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen( + tdesc_row_2d.row(j), block_widthx_widthy_arrlen); + + int32_t offset_width = mem_transpose + ? (base_offset_y / int32_t(scale_factor)) + : (base_offset_x / int32_t(scale_factor)); + int32_t offset_height = mem_transpose ? base_offset_x : base_offset_y; + + gpu::xetla::detail::xetla_set_tensor_offset_x( + tdesc_row_2d.row(j), offset_width); + gpu::xetla::detail::xetla_set_tensor_offset_y( + tdesc_row_2d.row(j), offset_height); + + base_offset_x += block_size_x; + } + } + } + private: __XETLA_API void prepare_tdesc(xetla_tdescriptor base_tdesc) { auto payloads_2d = payloads.xetla_format(); @@ -385,7 +536,9 @@ struct mem_payload_t< static constexpr msg_type message_type = msg_type::atomic_add; static constexpr uint32_t alignment_in_bytes = mem_desc_t::alignment_in_bytes; static constexpr gpu_arch arch_tag = arch_tag_; - static_assert(sizeof(dtype) >= 4, "for atomic add, we only support DW or QW"); + static_assert( + sizeof(dtype) >= 2, + "for atomic add, we only support W, DW or QW"); private: static constexpr uint32_t tile_size_x = tile_desc::tile_size_x; @@ -411,11 +564,14 @@ struct mem_payload_t< : 16; static constexpr uint32_t num_channel = (simd_channel >= block_size_x) ? 
block_size_x : simd_channel; - static constexpr uint32_t num_channel_x = block_size_x; // 16 - static constexpr uint32_t num_channel_y = num_channel / num_channel_x; // 1 + static constexpr uint32_t num_channel_x = block_size_x; + static constexpr uint32_t num_channel_y = num_channel / num_channel_x; static constexpr uint32_t store_elems = num_channel_y * block_size_x; - xetla_vector channel_offset; + // may need to set it to be configurable later + using Toffset = uint32_t; + + xetla_vector channel_offset; xetla_vector step_x; xetla_vector step_y; uint32_t pitch_in_bytes; @@ -590,70 +746,77 @@ struct mem_payload_t< dtype>::type>::type; static constexpr uint32_t scale_factor = sizeof(mem_dtype) / sizeof(dtype); + uint32_t base_address; uint32_t address; uint32_t pitch_in_bytes; - inline mem_payload_t(mem_desc_t& mem_tdesc) { + __XETLA_API void init(mem_desc_t& mem_tdesc) { + base_address = mem_tdesc.base.base; pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype); uint32_t offset_x = mem_tdesc.coord.x; uint32_t offset_y = mem_tdesc.coord.y; - address = mem_tdesc.base.base + offset_y * pitch_in_bytes + - offset_x * sizeof(dtype); + address = offset_y * pitch_in_bytes + offset_x * sizeof(dtype); } - inline mem_payload_t( + + __XETLA_API void init( uint32_t base, - [[maybe_unused]] int surface_width, - [[maybe_unused]] int surface_height, - int surface_pitch, - int surface_offset_x, - int surface_offset_y) { + [[maybe_unused]] uint32_t surface_width, + [[maybe_unused]] uint32_t surface_height, + uint32_t surface_pitch, + int32_t surface_offset_x, + int32_t surface_offset_y) { + base_address = base; uint32_t offset_x = surface_offset_x; uint32_t offset_y = surface_offset_y; pitch_in_bytes = surface_pitch * sizeof(dtype); - address = base + offset_y * pitch_in_bytes + offset_x * sizeof(dtype); + address = offset_y * pitch_in_bytes + offset_x * sizeof(dtype); } - __XETLA_API void init(mem_desc_t& mem_tdesc) { - pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype); - uint32_t offset_x = mem_tdesc.coord.x; - uint32_t offset_y = mem_tdesc.coord.y; - address = mem_tdesc.base.base + offset_y * pitch_in_bytes + - offset_x * sizeof(dtype); + inline mem_payload_t(mem_desc_t& mem_tdesc) { + init(mem_tdesc); } - - __XETLA_API void init( + inline mem_payload_t( uint32_t base, - [[maybe_unused]] int surface_width, - [[maybe_unused]] int surface_height, - int surface_pitch, - int surface_offset_x, - int surface_offset_y) { - uint32_t offset_x = surface_offset_x; - uint32_t offset_y = surface_offset_y; - pitch_in_bytes = surface_pitch * sizeof(dtype); - address = base + offset_y * pitch_in_bytes + offset_x * sizeof(dtype); + [[maybe_unused]] uint32_t surface_width, + [[maybe_unused]] uint32_t surface_height, + uint32_t surface_pitch, + int32_t surface_offset_x, + int32_t surface_offset_y) { + init( + base, + surface_width, + surface_height, + surface_pitch, + surface_offset_x, + surface_offset_y); } inline mem_payload_t(const this_payload_t& rhs) { + this->base_address = rhs.base_address; this->address = rhs.address; this->pitch_in_bytes = rhs.pitch_in_bytes; } inline mem_payload_t() = default; inline this_payload_t& operator=(const this_payload_t& rhs) { + this->base_address = rhs.base_address; this->address = rhs.address; this->pitch_in_bytes = rhs.pitch_in_bytes; return *this; } template - __XETLA_API void update_tdesc(int offset) { + __XETLA_API void update_tdesc(int32_t offset) { if constexpr (update_dir == tdesc_update_dir::x_dir) { address += offset * sizeof(dtype); } else { address += offset * 
pitch_in_bytes; } } + + __XETLA_API void set_base_address(uint32_t base) { + this->base_address = base; + } }; /// @brief Is to describe the global memory surface for unaligned-2d load/store @@ -685,7 +848,7 @@ struct mem_payload_t< static constexpr mem_layout memory_layout = mem_layout_; static constexpr msg_type message_type = msg_type::unaligned_2d; static constexpr uint32_t alignment_in_bytes = mem_desc_t::alignment_in_bytes; - static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; + static constexpr gpu_arch arch_tag = arch_tag_; private: static constexpr uint32_t tile_size_x = tile_desc::tile_size_x; @@ -693,11 +856,8 @@ struct mem_payload_t< static constexpr uint32_t block_size_x = tile_desc::block_size_x; static constexpr uint32_t block_size_y = tile_desc::block_size_y; - using this_payload_t = mem_payload_t< - mem_desc_t, - tile_desc, - msg_type::unaligned_2d, - gpu_arch::XeHpc>; + using this_payload_t = + mem_payload_t; public: static constexpr bool mem_transpose = memory_layout == mem_layout::col_major; @@ -886,7 +1046,6 @@ struct mem_payload_t< } } }; - /// @brief Is to describe the global memory surface for unaligned-2d load/store /// for each block in one tile, a payload message is prepared here. /// in tile_load case, memory transpose, register transpose, memory transform @@ -908,7 +1067,8 @@ struct mem_payload_t< msg_type::block_2d, arch_tag_, std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpg)>> { - using dtype = dtype_; + using dtype = + std::conditional_t, uint8_t, dtype_>; using mem_desc_t = mem_desc_t; using tile_desc = tile_desc_; @@ -960,10 +1120,20 @@ struct mem_payload_t< // for pvc, we can use simd16 or simd32 static constexpr uint32_t min_store_bytes = 16 * sizeof(dtype); static constexpr uint32_t max_store_bytes = 32 * sizeof(dtype); - static constexpr uint32_t num_channel = block_size_y; + static constexpr uint32_t simd_channel = + ((tile_bytes % max_store_bytes) == 0 && + (block_bytes % max_store_bytes) == 0) + ? 32 + : 16; + static constexpr uint32_t num_channel = mem_transpose + ? (simd_channel >= block_size_x) ? block_size_x : simd_channel + : (simd_channel >= block_size_y) ? block_size_y + : simd_channel; static constexpr uint32_t simd_exec_size = - block_size_x >= pack_factor ? block_size_x / pack_factor : 1; + (mem_transpose ? block_size_y : block_size_x) >= pack_factor + ? (mem_transpose ? block_size_y : block_size_x) / pack_factor + : 1; xetla_vector channel_offset; xetla_vector step_x; @@ -1020,6 +1190,7 @@ struct mem_payload_t< pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype); base_x = mem_tdesc.coord.x; base_y = mem_tdesc.coord.y; + width_in_elems = mem_tdesc.shape.x; height_in_elems = mem_tdesc.shape.y; base_offset = mem_transpose @@ -1135,6 +1306,8 @@ struct mem_payload_t< public: static constexpr reg_layout register_layout = tile_desc::register_layout; + static constexpr bool reg_transpose = + register_layout == reg_layout::transpose_tiled; static constexpr bool mem_transform = (sizeof(dtype) < 4) && register_layout == reg_layout::vnni_tiled; @@ -1271,6 +1444,8 @@ struct mem_payload_t< /// @tparam dtype Is the data type /// @tparam tile_desc_ Is the tile descriptor /// @tparam mem_layout_ Is the memory layout +/// @note this is used for Atrans SLM path, so, we can add more limitation for +/// best performance. 
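Reviewer note: the SLM payload rework above splits the old absolute `address` into `base_address` plus a relative `address`, so `update_tdesc()` only moves the tile offset and the payload can later be re-pointed at another SLM buffer through `set_base_address()`; the load/store call sites then add `payload.base_address + payload.address + address_offset`. A plain C++ mock of that split is sketched below; `slm_payload_mock` and its method names are illustrative, not library types.

#include <cstdint>

struct slm_payload_mock {
  uint32_t base_address = 0;  // start of the SLM buffer
  uint32_t address = 0;       // byte offset of the tile inside the buffer
  uint32_t pitch_in_bytes = 0;

  void init(uint32_t base, uint32_t pitch_bytes, uint32_t x, uint32_t y, uint32_t elem_bytes) {
    base_address = base;
    pitch_in_bytes = pitch_bytes;
    address = y * pitch_bytes + x * elem_bytes;
  }
  void update_x(int32_t offset, uint32_t elem_bytes) { address += offset * elem_bytes; }
  void update_y(int32_t offset) { address += offset * pitch_in_bytes; }
  void set_base_address(uint32_t base) { base_address = base; }  // e.g. ping-pong buffers
  uint32_t linear_offset() const { return base_address + address; }
};
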
template < typename dtype_, uint32_t tile_size_x_, @@ -1321,139 +1496,100 @@ struct mem_payload_t< using store_dtype = uint32_t; static constexpr uint32_t vnni_scale_factor = sizeof(store_dtype) / sizeof(dtype); - static constexpr uint32_t is_simd16_vec = - (block_size_x == 16) && ((tile_size_y & (tile_size_y - 1)) == 0); - static constexpr uint32_t num_vector_size = is_simd16_vec - ? detail::gcd::value - : 1; - - static constexpr uint32_t min_store_bytes = 16 * sizeof(store_dtype); - static constexpr uint32_t max_store_bytes = 32 * sizeof(store_dtype); - static constexpr uint32_t num_channel = is_simd16_vec - ? 16 - : (((tile_bytes % max_store_bytes) == 0 && - (block_bytes % max_store_bytes) == 0) - ? 32 - : 16); - static constexpr uint32_t num_channel_x = block_size_x; - static constexpr uint32_t num_channel_y = - is_simd16_vec ? 1 : num_channel / num_channel_x; + static_assert( + block_size_x % 16 == 0, + "block size x at least need to be 16 channel aligned"); + static constexpr uint32_t num_channel = block_size_x; + static constexpr uint32_t max_vector_size = (block_size_x == 16) ? 8 : 4; + static constexpr uint32_t num_vector_size = + detail::gcd::value; static constexpr uint32_t store_elems = - num_channel_y * num_vector_size * vnni_scale_factor * block_size_x; - xetla_vector address; + num_channel * num_vector_size * vnni_scale_factor; + + uint32_t base_address; uint32_t pitch_in_bytes; - uint32_t cyclic_count; - uint32_t wg_width_in_bytes; - uint32_t wg_height_in_elems; + xetla_vector channel_address; - // Be aware of the risks: Rule of three (copy constructor, copy assignment, - // destructor) Please check if you need to add self-define destructor - // ~mem_payload_t(){} - inline mem_payload_t(mem_desc_t mem_tdesc) { - xetla_tdescriptor base_tdesc = mem_tdesc.get_tdesc(); - cyclic_count = 0; - pitch_in_bytes = base_tdesc[4]; - wg_width_in_bytes = base_tdesc[2]; - wg_height_in_elems = base_tdesc[3]; - uint32_t offset_x = base_tdesc[5]; - uint32_t offset_y = base_tdesc[6]; - uint32_t start_address = base_tdesc[0]; - start_address += offset_y * pitch_in_bytes + offset_x * sizeof(dtype); + __XETLA_API void init(mem_desc_t& mem_tdesc) { + base_address = mem_tdesc.base.base; + pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype); + uint32_t offset_x = mem_tdesc.coord.x; // because this is row-major + uint32_t offset_y = mem_tdesc.coord.y; + uint32_t offset_address = + offset_y * pitch_in_bytes + offset_x * sizeof(dtype); xetla_vector channel_index = xetla_vector_gen(0, 1); - address = start_address + (channel_index % num_channel_x) * pitch_in_bytes + - (channel_index / num_channel_x) * sizeof(store_dtype); + channel_address = offset_address + channel_index * pitch_in_bytes; } - - inline mem_payload_t( + __XETLA_API void init( uint32_t base, - int surface_width, - int surface_height, - int surface_pitch, - int surface_offset_x, - int surface_offset_y) { + [[maybe_unused]] uint32_t surface_width, + [[maybe_unused]] uint32_t surface_height, + uint32_t surface_pitch, + int32_t surface_offset_x, + int32_t surface_offset_y) { + base_address = base; pitch_in_bytes = surface_pitch * sizeof(dtype); - wg_width_in_bytes = surface_width * sizeof(dtype); - wg_height_in_elems = surface_height; uint32_t offset_x = surface_offset_x; uint32_t offset_y = surface_offset_y; - uint32_t start_address = base; - start_address += offset_y * pitch_in_bytes + offset_x * sizeof(dtype); + uint32_t offset_address = + offset_y * pitch_in_bytes + offset_x * sizeof(dtype); xetla_vector channel_index = 
xetla_vector_gen(0, 1); - address = start_address + - ((channel_index % num_channel_x) * pitch_in_bytes + - (channel_index / num_channel_x) * sizeof(store_dtype)); - cyclic_count = 0; + channel_address = offset_address + channel_index * pitch_in_bytes; } - __XETLA_API void init( + inline mem_payload_t( uint32_t base, - int surface_width, - int surface_height, - int surface_pitch, - int surface_offset_x, - int surface_offset_y) { - pitch_in_bytes = surface_pitch * sizeof(dtype); - wg_width_in_bytes = surface_width * sizeof(dtype); - wg_height_in_elems = surface_height; - uint32_t offset_x = surface_offset_x; - uint32_t offset_y = surface_offset_y; - uint32_t start_address = base; - start_address += offset_y * pitch_in_bytes + offset_x * sizeof(dtype); - xetla_vector channel_index = - xetla_vector_gen(0, 1); - address = start_address + - ((channel_index % num_channel_x) * pitch_in_bytes + - (channel_index / num_channel_x) * sizeof(store_dtype)); - cyclic_count = 0; + uint32_t surface_width, + uint32_t surface_height, + uint32_t surface_pitch, + int32_t surface_offset_x, + int32_t surface_offset_y) { + init( + base, + surface_width, + surface_height, + surface_pitch, + surface_offset_x, + surface_offset_y); } - - __XETLA_API void init(mem_desc_t mem_tdesc) { - xetla_tdescriptor base_tdesc = mem_tdesc.get_tdesc(); - cyclic_count = 0; - pitch_in_bytes = base_tdesc[4]; - wg_width_in_bytes = base_tdesc[2]; - wg_height_in_elems = base_tdesc[3]; - uint32_t offset_x = base_tdesc[5]; - uint32_t offset_y = base_tdesc[6]; - uint32_t start_address = base_tdesc[0]; - start_address += offset_y * pitch_in_bytes + offset_x * sizeof(dtype); - xetla_vector channel_index = - xetla_vector_gen(0, 1); - address = start_address + (channel_index % num_channel_x) * pitch_in_bytes + - (channel_index / num_channel_x) * sizeof(store_dtype); + // Be aware of the risks: Rule of three (copy constructor, copy assignment, + // destructor) Please check if you need to add self-define destructor + // ~mem_payload_t(){} + inline mem_payload_t(mem_desc_t& mem_tdesc) { + init(mem_tdesc); } inline mem_payload_t(const this_payload_t& rhs) { - this->address = rhs.address; + this->base_address = rhs.base_address; this->pitch_in_bytes = rhs.pitch_in_bytes; - this->cyclic_count = 0; - this->wg_width_in_bytes = rhs.wg_width_in_bytes; - this->wg_height_in_elems = rhs.wg_height_in_elems; + this->channel_address = rhs.channel_address; } inline mem_payload_t() = default; inline this_payload_t& operator=(const this_payload_t& rhs) { - this->address = rhs.address; + this->base_address = rhs.base_address; this->pitch_in_bytes = rhs.pitch_in_bytes; - this->cyclic_count = 0; - this->wg_width_in_bytes = rhs.wg_width_in_bytes; - this->wg_height_in_elems = rhs.wg_height_in_elems; + this->channel_address = rhs.channel_address; return *this; } template __XETLA_API void update_tdesc(int offset) { if constexpr (update_dir == tdesc_update_dir::x_dir) { - address += offset * sizeof(dtype); + channel_address += offset * sizeof(dtype); } else { - address += offset * pitch_in_bytes; + channel_address += offset * pitch_in_bytes; } } -}; -/// @brief Is to describe the global memory surface to prefetch data to cache + __XETLA_API void set_base_address(uint32_t base) { + this->base_address = base; + } +}; +/// @brief Is to describe the global memory +/// surface to prefetch data to cache /// data in global memory will be prefetched into 2d tile /// @tparam tile_desc_ Is the tile descriptor /// @tparam dtype Is the data type @@ -1494,6 +1630,7 @@ struct 
prefetch_payload_t< reg_layout_>; static constexpr mem_space memory_space = mem_space::global; static constexpr mem_layout memory_layout = mem_layout_; + static constexpr msg_type message_type = msg_type::block_2d; static constexpr uint32_t alignment_in_bytes = mem_desc_t::alignment_in_bytes; static constexpr gpu_arch arch_tag = arch_tag_; @@ -1502,24 +1639,11 @@ struct prefetch_payload_t< static constexpr uint32_t tile_size_y = tile_desc::tile_size_y; static constexpr uint32_t block_size_x = tile_desc::block_size_x; static constexpr uint32_t block_size_y = tile_desc::block_size_y; - static constexpr msg_type message_type = msg_type::block_2d; static constexpr uint32_t tile_bytes = tile_size_x * tile_size_y * sizeof(dtype); static constexpr uint32_t block_bytes = block_size_x * block_size_y * sizeof(dtype); - using prefetch_dtype = typename std::conditional< - (alignment_in_bytes % (sizeof(uint64_t)) == 0), - uint64_t, - typename std::conditional< - (alignment_in_bytes % (sizeof(uint32_t)) == 0), - uint32_t, - dtype>::type>::type; - static constexpr uint32_t scale_factor = - sizeof(prefetch_dtype) / sizeof(dtype); - static constexpr uint32_t simd_exec_size = - block_size_x / scale_factor <= 0 ? 1 : block_size_x / scale_factor; - private: using this_payload_t = prefetch_payload_t; @@ -1530,7 +1654,23 @@ struct prefetch_payload_t< static constexpr bool reg_transpose = register_layout == reg_layout::transpose_tiled; static constexpr bool trans = mem_transpose ^ reg_transpose; - static constexpr uint32_t num_channel = block_size_y; + + using prefetch_dtype = typename std::conditional< + (alignment_in_bytes % (sizeof(uint64_t)) == 0), + uint64_t, + typename std::conditional< + (alignment_in_bytes % (sizeof(uint32_t)) == 0), + uint32_t, + dtype>::type>::type; + static constexpr uint32_t pack_factor = + sizeof(prefetch_dtype) / sizeof(dtype); + + static constexpr uint32_t simd_exec_size = + (mem_transpose ? block_size_y : block_size_x) >= pack_factor + ? (mem_transpose ? block_size_y : block_size_x) / pack_factor + : 1; + static constexpr uint32_t num_channel = + mem_transpose ? block_size_x : block_size_y; static constexpr uint32_t mem_tile_size_w = mem_transpose ? tile_size_y : tile_size_x; @@ -1652,9 +1792,9 @@ struct prefetch_payload_t< xetla_vector_gen(0, 1); channel_offset = channel_index * pitch_in_bytes; } - // Be aware of the risks: Rule of three (copy constructor, copy assignment, - // destructor) Please check if you need to add self-define destructor - // ~prefetch_payload_t(){} + // Be aware of the risks: Rule of three (copy constructor, copy + // assignment, destructor) Please check if you need to add self-define + // destructor ~prefetch_payload_t(){} template __XETLA_API void update_tdesc(int offset) { @@ -1720,9 +1860,8 @@ struct prefetch_payload_t< is_col_major ? tile_size_y : tile_size_x; static constexpr uint32_t mem_tile_size_h = is_col_major ? 
tile_size_x : tile_size_y; - - using load_store_attr = - typename arch_attr_t::template load_store_attr; + using load_store_attr = typename arch_attr_t< + arch_tag>::template load_store_attr; static constexpr uint32_t special_prefetch_width = load_store_attr::special_prefetch_width_in_bytes / sizeof(dtype); static constexpr uint32_t normal_prefetch_width = @@ -1800,6 +1939,16 @@ struct prefetch_payload_t< surface_offset_y + coop_id_y * tile_size_h); prepare_tdesc(base_tdesc); } + + inline void init(xetla_tdescriptor base_tdesc, uint32_t coop_id = 0) { + uint32_t coop_id_x = coop_id % num_coop_sg_w; + uint32_t coop_id_y = coop_id / num_coop_sg_w; + xetla_update_tdesc_offsetx( + base_tdesc.xetla_format(), coop_id_x * tile_size_w); + xetla_update_tdesc_offsety( + base_tdesc.xetla_format(), coop_id_y * tile_size_h); + prepare_tdesc(base_tdesc); + } // Be aware of the risks: Rule of three (copy constructor, copy assignment, // destructor) Please check if you need to add self-define destructor // ~prefetch_payload_t(){} @@ -1819,6 +1968,92 @@ struct prefetch_payload_t< } } } + __XETLA_API void set_tdesc_width(uint32_t size) { + auto tdesc_2d = tdesc_prefetch.xetla_format(); +#pragma unroll + for (int i = 0; i < num_tdesc; i++) { + xetla_set_tdesc_width(tdesc_2d.row(i), size); + } + } + + __XETLA_API void set_tdesc_pitch(uint32_t size) { + auto tdesc_2d = tdesc_prefetch.xetla_format(); +#pragma unroll + for (int i = 0; i < num_tdesc; i++) { + xetla_set_tdesc_pitch(tdesc_2d.row(i), size); + } + } + + __XETLA_API void set_tdesc_height(uint32_t size) { + auto tdesc_2d = tdesc_prefetch.xetla_format(); +#pragma unroll + for (int i = 0; i < num_tdesc; i++) { + xetla_set_tdesc_height(tdesc_2d.row(i), size); + } + } + + __XETLA_API void update_tdesc_base_address(int offset) { + auto tdesc_2d = tdesc_prefetch.xetla_format(); +#pragma unroll + for (int i = 0; i < num_tdesc; i++) { + xetla_update_tdesc_base_address(tdesc_2d.row(i), offset); + } + } + + __XETLA_API void set_tdesc_base_address(uint64_t addr) { + auto tdesc_2d = tdesc_prefetch.xetla_format(); +#pragma unroll + for (int i = 0; i < num_tdesc; i++) { + gpu::xetla::detail::xetla_set_tensor_base_address(tdesc_2d.row(i), addr); + } + } + + __XETLA_API void update_tdesc_base_address_masked( + int offset, + uint16_t mask = 1) { + auto tdesc_2d = tdesc_prefetch.xetla_format(); +#pragma unroll + for (int i = 0; i < num_tdesc; i++) { + xetla_update_tdesc_base_address(tdesc_2d.row(i), offset); + } + +#pragma unroll + for (int i = 0; i < num_tdesc; i++) { + xetla_tdesc_mask_op(tdesc_2d.row(i), mask); + } + } + + __XETLA_API void set_offset( + int32_t offset_x, + int32_t offset_y, + uint32_t coop_id = 0) { + uint32_t coop_id_x = coop_id % num_coop_sg_w; + uint32_t coop_id_y = coop_id / num_coop_sg_w; + + auto tdesc_2d = tdesc_prefetch.xetla_format(); + int32_t base_offset_y = offset_y + + (is_col_major ? coop_id_x * tile_size_w : coop_id_y * tile_size_h); +#pragma unroll + for (int i = 0; i < num_block_h; i++) { + auto tdesc_row_2d = + tdesc_2d.xetla_select(i * num_block_w, 0); + + int32_t base_offset_x = offset_x + + (is_col_major ? coop_id_y * tile_size_h : coop_id_x * tile_size_w); +#pragma unroll + for (int j = 0; j < num_block_w; j++) { + int32_t offset_width = is_col_major ? base_offset_y : base_offset_x; + int32_t offset_height = is_col_major ? 
base_offset_x : base_offset_y; + gpu::xetla::detail::xetla_set_tensor_offset_x( + tdesc_row_2d.row(j), offset_width); + gpu::xetla::detail::xetla_set_tensor_offset_y( + tdesc_row_2d.row(j), offset_height); + + base_offset_x += block_size_w; + } + base_offset_y += block_size_h; + } + } private: __XETLA_API void prepare_tdesc(xetla_tdescriptor base_tdesc) { @@ -1893,7 +2128,6 @@ struct prefetch_payload_t< using tile_desc = tile_desc_t; static constexpr mem_space memory_space = mem_space::global; static constexpr mem_layout memory_layout = mem_layout_; - static constexpr msg_type message_type = msg_type::block_2d; static constexpr gpu_arch arch_tag = arch_tag_; private: diff --git a/include/subgroup/tile/impl/prefetch_xe.hpp b/include/subgroup/tile/impl/prefetch_xe.hpp index efec6b698..0f9e2d3bb 100644 --- a/include/subgroup/tile/impl/prefetch_xe.hpp +++ b/include/subgroup/tile/impl/prefetch_xe.hpp @@ -107,17 +107,11 @@ tile_prefetch(payload_t& payload) { #pragma unroll for (uint32_t sub_block_y = 0; sub_block_y < tile_desc::block_size_y; sub_block_y += num_channel) { - uint32_t address_offset = payload_t::trans + uint32_t address_offset = payload_t::mem_transpose ? offset_x * payload.pitch_in_bytes + (offset_y + sub_block_y) * sizeof(dtype) : offset_x * sizeof(dtype) + (offset_y + sub_block_y) * payload.pitch_in_bytes; - xetla_mask pred_y = - payload.base_y + offset_y + sub_block_y + num_channel > - payload.height_in_elems - ? (xetla_vector_gen(0, 1) < - (payload.height_in_elems % num_channel)) - : 1; xetla_prefetch_global< prefetch_dtype, @@ -128,7 +122,7 @@ tile_prefetch(payload_t& payload) { payload_t::num_channel>( payload.base_ptr, payload.channel_offset + payload.base_offset + address_offset, - pred_y); + 1); } } } diff --git a/include/subgroup/tile/impl/store_xe.hpp b/include/subgroup/tile/impl/store_xe.hpp index a84469d6e..56196da6d 100644 --- a/include/subgroup/tile/impl/store_xe.hpp +++ b/include/subgroup/tile/impl/store_xe.hpp @@ -28,6 +28,7 @@ namespace gpu::xetla::subgroup { namespace detail { template struct check_store_type { + static constexpr bool is_lsc_scatter = true; static constexpr bool is_global_block_2d = (payload_t::memory_space == mem_space::global && (payload_t::message_type == msg_type::block_2d)); @@ -116,10 +117,16 @@ tile_store(tile_t& tile, payload_t& payload) { block_size_y > max_store_block_height ? max_store_block_height : block_size_y; // to make sure full CL store - static constexpr uint32_t st_block_x = - (((tile_size_x % elems_per_CL) == 0) ? elems_per_CL : block_size_x); + static constexpr uint32_t st_block_x = ((tile_size_x % elems_per_CL) == 0) + ? elems_per_CL + : (((elems_per_CL % tile_size_x) == 0) ? tile_size_x : block_size_x); - static constexpr uint8_t arr_len = st_block_x / block_size_x; + static constexpr uint8_t arr_len_candidate = st_block_x / block_size_x; + static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1) || + (arr_len_candidate == 2) || (arr_len_candidate == 4); + + static constexpr uint8_t arr_len = + is_valid_arr_len_candidate ? 
arr_len_candidate : 1; auto payload_2d = payload.payloads.xetla_format(); #pragma unroll @@ -295,12 +302,7 @@ tile_store(tile_t& tile, payload_t& payload) { tile.reg.xetla_select(offset_x); uint32_t address_offset = offset_x * sizeof(dtype); - xetla_store_global< - store_dtype, - max_store_vec_len, - data_size::default_size, - L1, - L2>( + xetla_store_global( payload.base_ptr, payload.base_offset + address_offset, reg_sub.xetla_format()); @@ -347,8 +349,9 @@ tile_store( using dtype = typename payload_t::dtype; using tile_desc = typename payload_t::tile_desc; using store_dtype = typename payload_t::mem_dtype; + constexpr uint32_t num_channel_y = payload_t::num_channel_y; - constexpr uint32_t load_elems = num_channel_y * payload_t::num_channel_x; + constexpr uint32_t store_elems = num_channel_y * payload_t::num_channel_x; constexpr uint32_t scale_factor = payload_t::scale_factor; #pragma unroll @@ -360,13 +363,13 @@ tile_store( uint32_t offset_x = j * tile_desc::block_size_x; auto reg_sub = tile.reg.xetla_select( (i * tile_desc::num_block_x + j) * tile_desc::block_elems); - xetla_mask pred_x = oob_check + xetla_mask pred_x = oob_check ? payload.step_x + payload.base_x + offset_x < payload.width_in_elems : 1; #pragma unroll for (uint32_t sub_block_y = 0; sub_block_y < tile_desc::block_size_y; sub_block_y += num_channel_y) { - xetla_mask pred_y = oob_check + xetla_mask pred_y = oob_check ? payload.step_y + payload.base_y + offset_y + sub_block_y < payload.height_in_elems : 1; @@ -379,11 +382,11 @@ tile_store( data_size::default_size, L1, L3, - load_elems>( + store_elems>( payload.base_ptr, (payload.base_offset + address_offset + payload.channel_offset), reg_sub - .xetla_select( + .xetla_select( sub_block_y * tile_desc::block_size_x) .xetla_format(), (pred_x && pred_y)); @@ -402,13 +405,13 @@ tile_store( uint32_t offset_x = j * tile_desc::block_size_x; auto reg_sub = tile.reg.xetla_select( processed_elems + j * remain_block_elems); - xetla_mask pred_x = oob_check + xetla_mask pred_x = oob_check ? payload.step_x + payload.base_x + offset_x < payload.width_in_elems : 1; #pragma unroll for (uint32_t sub_block_y = 0; sub_block_y < remained_size_y; sub_block_y += num_channel_y) { - xetla_mask pred_y = oob_check + xetla_mask pred_y = oob_check ? payload.step_y + payload.base_y + offset_y + sub_block_y < payload.height_in_elems : 1; @@ -421,11 +424,11 @@ tile_store( data_size::default_size, L1, L3, - load_elems>( + store_elems>( payload.base_ptr, (payload.base_offset + address_offset + payload.channel_offset), reg_sub - .xetla_select( + .xetla_select( sub_block_y * tile_desc::block_size_x) .xetla_format(), (pred_x && pred_y)); @@ -433,7 +436,6 @@ tile_store( } } } - /// @brief Is the func storing data from register file to unaligned global /// memory surface. store a rectangular region (X,Y)..(X+W,Y+H) into memory from /// registers. 
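Reviewer note: the store hunk above only widens a 2D block store toward a full cache line when the resulting array length is one of the values the code accepts (1, 2 or 4); otherwise it keeps one block per message. A constexpr sketch of that selection follows; `pick_arr_len` and the numeric examples are illustrative, not taken from a real tile configuration.

#include <cstdint>

constexpr uint32_t pick_arr_len(
    uint32_t tile_size_x,
    uint32_t block_size_x,
    uint32_t elems_per_cl) {
  // same shape as st_block_x / arr_len_candidate in the hunk above
  uint32_t st_block_x = (tile_size_x % elems_per_cl == 0)
      ? elems_per_cl
      : ((elems_per_cl % tile_size_x == 0) ? tile_size_x : block_size_x);
  uint32_t candidate = st_block_x / block_size_x;
  bool valid = candidate == 1 || candidate == 2 || candidate == 4;
  return valid ? candidate : 1;
}

static_assert(pick_arr_len(64, 16, 32) == 2, "two blocks cover a cache line");
static_assert(pick_arr_len(48, 16, 32) == 1, "odd ratios fall back to one block");
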
@@ -455,6 +457,7 @@ template < typename payload_t> __XETLA_API typename std::enable_if_t< detail::check_store_type::is_global_block_2d && + detail::check_store_type::is_lsc_scatter && !arch_has_2d_load_store> tile_store(tile_t& tile, payload_t& payload) { using dtype = typename payload_t::dtype; @@ -462,7 +465,7 @@ tile_store(tile_t& tile, payload_t& payload) { using store_dtype = typename payload_t::mem_dtype; constexpr uint32_t num_channel = payload_t::num_channel; - constexpr uint32_t load_elems = num_channel * payload_t::simd_exec_size; + constexpr uint32_t store_elems = num_channel * payload_t::simd_exec_size; constexpr uint32_t pack_factor = payload_t::pack_factor; #pragma unroll @@ -480,12 +483,12 @@ tile_store(tile_t& tile, payload_t& payload) { uint32_t address_offset = offset_x * sizeof(dtype) + (offset_y + sub_block_y) * payload.pitch_in_bytes; - xetla_vector reg_tmp; + xetla_vector reg_tmp; if constexpr (payload_t::simd_exec_size > 1) { - xetla_vector reg_sub_before_trans = + xetla_vector reg_sub_before_trans = reg_sub - .xetla_select( + .xetla_select( sub_block_y * tile_desc::block_size_x) .xetla_format(); #pragma unroll @@ -498,7 +501,7 @@ tile_store(tile_t& tile, payload_t& payload) { } } else { reg_tmp = reg_sub - .xetla_select( + .xetla_select( sub_block_y * tile_desc::block_size_x) .xetla_format(); } @@ -544,6 +547,61 @@ tile_store(tile_t& tile, payload_t& payload) { } } +/// @brief Is the func storing data from register file to unaligned global +/// memory surface. store a rectangular region (X,Y)..(X+W,Y+H) into memory from +/// registers. +/// @tparam tile_t Is the tile_t struct contains registers +/// These registers will be the source of store operation. +/// @tparam payload_t Is the mem_payload_t struct describing the memory info +/// payload indicates the destination of store operation. +/// @tparam L1 Is the cache hint for L1 cache. +/// @tparam L3 Is the cache hint for L3 cache. +/// @param tile Is the tile object with type tile_t, contains the data to be +/// stored. +/// @param payload Is the payload object with type payload_t. Contains all the +/// information for stores. +/// @return No return, update in place. 
+template < + cache_hint L1 = cache_hint::write_back, + cache_hint L2 = cache_hint::write_back, + typename tile_t, + typename payload_t> +__XETLA_API typename std::enable_if_t< + detail::check_store_type::is_global_block_2d && + !detail::check_store_type::is_lsc_scatter && + !arch_has_2d_load_store> +tile_store(tile_t& tile, payload_t& payload) { + using dtype = typename payload_t::dtype; + using tile_desc = typename payload_t::tile_desc; + constexpr uint32_t store_elems = tile_desc::block_size_x; + +#pragma unroll + for (uint32_t i = 0; i < tile_desc::num_block_y; i++) { + uint32_t offset_y = i * tile_desc::block_size_y; +#pragma unroll + for (uint32_t j = 0; j < tile_desc::num_block_x; j++) { + uint32_t offset_x = j * tile_desc::block_size_x; + auto reg_sub = tile.reg.xetla_select( + (i * tile_desc::num_block_x + j) * tile_desc::block_elems); +#pragma unroll + for (uint32_t sub_block_y = 0; sub_block_y < tile_desc::block_size_y; + sub_block_y += 1) { + uint32_t address_offset = offset_x * sizeof(dtype) + + (offset_y + sub_block_y) * payload.pitch_in_bytes; + + xetla_vector reg_tmp; + reg_tmp = reg_sub.xetla_select( + sub_block_y * tile_desc::block_size_x); + + xetla_store_global( + (dtype*)payload.base_ptr, + (payload.base_offset + address_offset), + reg_tmp); + } + } + } +} + /// @brief Is the func storing data from register file to global memory /// enable atomic adding data into the same buffer, but support float32, /// float64, uint32_t, uint64_t and int type @@ -615,7 +673,8 @@ tile_store( L1, L2, op_kind, - payload_t::arch_tag>( + payload_t::arch_tag, + typename payload_t::Toffset>( payload.base_pointer + address_offset, payload.channel_offset, reg_sub.xetla_select( @@ -669,7 +728,8 @@ tile_store( L1, L2, op_kind, - payload_t::arch_tag>( + payload_t::arch_tag, + typename payload_t::Toffset>( (uint64_t)payload.base_pointer + address_offset, payload.channel_offset, reg_sub.xetla_select( @@ -785,7 +845,6 @@ tile_store(tile_t& tile, payload_t& payload) { constexpr uint32_t vnni_scale_factor = payload_t::vnni_scale_factor; constexpr uint32_t num_vector_size = payload_t::num_vector_size; - constexpr uint32_t num_channel_y = payload_t::num_channel_y; constexpr uint32_t store_elems = payload_t::store_elems; #pragma unroll for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y; @@ -798,11 +857,12 @@ tile_store(tile_t& tile, payload_t& payload) { (i * tile_desc::num_block_x + j) * tile_desc::block_elems); #pragma unroll for (uint32_t sub_block_y = 0; sub_block_y < tile_desc::block_size_y; - sub_block_y += num_channel_y * num_vector_size * vnni_scale_factor) { - uint32_t address_offset = offset_x * payload.pitch_in_bytes + + sub_block_y += num_vector_size * vnni_scale_factor) { + uint32_t address_offset = payload.base_address + + offset_x * payload.pitch_in_bytes + (sub_block_y + offset_y) * sizeof(dtype); xetla_store_local( - payload.address + address_offset, + payload.channel_address + address_offset, reg_sub .xetla_select( sub_block_y * tile_desc::block_size_x) @@ -824,11 +884,12 @@ tile_store(tile_t& tile, payload_t& payload) { processed_elems + j * remain_block_elems); #pragma unroll for (uint32_t sub_block_y = 0; sub_block_y < remained_size_y; - sub_block_y += num_channel_y * num_vector_size * vnni_scale_factor) { - uint32_t address_offset = offset_x * payload.pitch_in_bytes + + sub_block_y += num_vector_size * vnni_scale_factor) { + uint32_t address_offset = payload.base_address + + offset_x * payload.pitch_in_bytes + (sub_block_y + offset_y) * sizeof(dtype); 
xetla_store_local( - payload.address + address_offset, + payload.channel_address + address_offset, reg_sub .xetla_select( sub_block_y * tile_desc::block_size_x) @@ -885,7 +946,8 @@ tile_store(tile_t& tile, payload_t& payload) { #pragma unroll for (uint32_t row_i = 0; row_i < tile_desc::block_size_y; row_i++) { xetla_store_local( - payload.address + address_offset + row_i * payload.pitch_in_bytes, + payload.base_address + payload.address + address_offset + + row_i * payload.pitch_in_bytes, reg_sub_2d.row(row_i).xetla_format()); } } @@ -910,7 +972,8 @@ tile_store(tile_t& tile, payload_t& payload) { #pragma unroll for (uint32_t row_i = 0; row_i < remained_size_y; row_i++) { xetla_store_local( - payload.address + address_offset + row_i * payload.pitch_in_bytes, + payload.base_address + payload.address + address_offset + + row_i * payload.pitch_in_bytes, reg_sub_2d.row(row_i).xetla_format()); } } @@ -969,7 +1032,11 @@ tile_store(tile_t& tile, payload_t& payload) { (max_store_vec_len >> 1), detail::process_flag::store, L1, - L2>(tile, payload, store_iter_steps * max_store_vec_len * scale_factor); + L2>( + tile, + payload, + store_iter_steps * max_store_vec_len * scale_factor, + store_iter_steps * max_store_vec_len * scale_factor * sizeof(dtype)); } } // namespace gpu::xetla::subgroup diff --git a/include/subgroup/tile/impl/tile_op_functor.hpp b/include/subgroup/tile/impl/tile_op_functor.hpp index 0323ca8bf..644717df8 100644 --- a/include/subgroup/tile/impl/tile_op_functor.hpp +++ b/include/subgroup/tile/impl/tile_op_functor.hpp @@ -24,6 +24,7 @@ #include #include #include +#include #include namespace gpu::xetla::subgroup { @@ -39,10 +40,19 @@ struct none_op_t { [[maybe_unused]] const arguments_t& args, [[maybe_unused]] uint32_t slm_base = 0, [[maybe_unused]] uint32_t nbarrier_base = 0) {} + // none_op_t functor for dequant_op + template + __XETLA_API KERNEL_FUNC void operator()( + [[maybe_unused]] mat_out_t& mat_out, + [[maybe_unused]] mat_in_t& mat_in, + [[maybe_unused]] const coord_t& coord, + [[maybe_unused]] const arguments_t& args) { + mat_out = mat_in; + } }; /// @brief Is the element-wise relu op functor. -/// Get the relu input from matAcc, update the the relu output in place, +/// Get the relu input from matAcc, update the relu output in place, /// Used in epilogue::tile_op or chained_tile_op. 
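Reviewer note: the activation functors touched below (tanh, sigmoid, the new silu, gelu_fwd) all share the same shape after this change: process the tile in block-sized chunks, then handle a compile-time remainder once at the end, with silu computed as x * sigmoid(x) via `xetla_sigmoid`. A host-side C++ analogue of that chunk-plus-tail pattern is sketched here; `apply_blockwise` and `silu_ref` are illustrative helpers and the sizes are arbitrary.

#include <array>
#include <cmath>
#include <cstddef>

template <std::size_t tile_elems, std::size_t block_elems, typename F>
void apply_blockwise(std::array<float, tile_elems>& reg, F&& op) {
  constexpr std::size_t rounds = tile_elems / block_elems;
  for (std::size_t i = 0; i < rounds; ++i)
    for (std::size_t j = 0; j < block_elems; ++j)
      reg[i * block_elems + j] = op(reg[i * block_elems + j]);
  constexpr std::size_t remaining = tile_elems % block_elems;  // tail chunk
  if constexpr (remaining != 0)
    for (std::size_t j = 0; j < remaining; ++j)
      reg[rounds * block_elems + j] = op(reg[rounds * block_elems + j]);
}

// silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
inline float silu_ref(float x) { return x / (1.f + std::exp(-x)); }
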
struct relu_op_t { struct arguments_t {}; @@ -53,7 +63,7 @@ struct relu_op_t { [[maybe_unused]] const arguments_t& args, [[maybe_unused]] uint32_t slm_base = 0, [[maybe_unused]] uint32_t nbarrier_base = 0) { - xetla_mask mask = matAcc.reg <= 0; + xetla_mask mask = matAcc.reg < 0; matAcc.reg.xetla_merge(0, mask); } }; @@ -77,11 +87,11 @@ struct tanh_op_t { auto sub_vec = matAcc.reg.xetla_select(elems * i); sub_vec = xetla_tanh(sub_vec); } - constexpr int remained_elems = matAcc_t::tile_desc::tile_elems % elems; - if constexpr (remained_elems != 0) { - auto sub_vec = matAcc.reg.xetla_select( + constexpr int remaining_elems = matAcc_t::tile_desc::tile_elems % elems; + if constexpr (remaining_elems != 0) { + auto sub_vec = matAcc.reg.xetla_select( elems * (matAcc_t::tile_elems / elems)); - sub_vec = xetla_tanh(sub_vec); + sub_vec = xetla_tanh(sub_vec); } } }; @@ -99,29 +109,49 @@ struct sigmoid_op_t { [[maybe_unused]] uint32_t nbarrier_base = 0) { constexpr int elems = matAcc_t::tile_desc::block_elems; constexpr int rounds = matAcc_t::tile_desc::tile_elems / elems; - constexpr float one = 1.0f; #pragma unroll for (uint32_t i = 0; i < rounds; ++i) { auto sub_vec = matAcc.reg.xetla_select(elems * i); - xetla_mask mask = sub_vec >= 10; - xetla_vector temp_vec = - xetla_exp(sub_vec); + sub_vec = xetla_sigmoid(sub_vec); + } + constexpr int remaining_elems = matAcc_t::tile_desc::tile_elems % elems; + if constexpr (remaining_elems != 0) { + auto sub_vec = matAcc.reg.xetla_select( + elems * (matAcc_t::tile_elems / elems)); + sub_vec = + xetla_sigmoid(sub_vec); + } + } +}; + +/// @brief Is the element-wise silu op functor. +/// Get the silu input from matAcc, update the the silu output in place, +/// Used in epilogue::tile_op or chained_tile_op. +struct silu_op_t { + struct arguments_t {}; + template + __XETLA_API KERNEL_FUNC void operator()( + [[maybe_unused]] matAcc_t& matAcc, + [[maybe_unused]] const coord_t& coord, + [[maybe_unused]] const arguments_t& args, + [[maybe_unused]] uint32_t slm_base = 0, + [[maybe_unused]] uint32_t nbarrier_base = 0) { + constexpr int elems = matAcc_t::tile_desc::block_elems; + constexpr int rounds = matAcc_t::tile_desc::tile_elems / elems; +#pragma unroll + for (int i = 0; i < rounds; ++i) { + auto sub_vec = matAcc.reg.xetla_select(elems * i); xetla_vector sigmoid_value = - temp_vec / (temp_vec + one); - sigmoid_value.xetla_merge(1, mask); - sub_vec = sigmoid_value; + xetla_sigmoid(sub_vec); + sub_vec = sub_vec * sigmoid_value; } - constexpr int remained_elems = matAcc_t::tile_desc::tile_elems % elems; - if constexpr (remained_elems != 0) { - auto sub_vec = matAcc.reg.xetla_select( + constexpr int remaining_elems = matAcc_t::tile_desc::tile_elems % elems; + if constexpr (remaining_elems != 0) { + auto sub_vec = matAcc.reg.xetla_select( elems * (matAcc_t::tile_elems / elems)); - xetla_mask mask = sub_vec >= 250; - xetla_vector temp_vec = - xetla_exp(sub_vec); - xetla_vector sigmoid_value = - temp_vec / (temp_vec + one); - sigmoid_value.xetla_merge(1, mask); - sub_vec = sigmoid_value; + xetla_vector sigmoid_value = + xetla_sigmoid(sub_vec); + sub_vec = sub_vec * sigmoid_value; } } }; @@ -144,10 +174,9 @@ struct gelu_fwd_op_t { // total flag register constexpr int elems = 8 * 16; constexpr int rounds = matAcc_t::tile_elems / elems; - - if constexpr (rounds > 0) { + if constexpr (rounds != 0) { #pragma unroll - for (uint32_t i = 0; i < rounds; ++i) { + for (int i = 0; i < rounds; ++i) { auto sub_vec = matAcc.reg.xetla_select(elems * i); xetla_vector sub_vec_x = 
(sqrt_two_over_pi * sub_vec * (1.f + C0 * sub_vec * sub_vec)); @@ -157,13 +186,14 @@ struct gelu_fwd_op_t { } } - constexpr int remained_elems = matAcc_t::tile_elems % elems; - if constexpr (remained_elems != 0) { - auto sub_vec = matAcc.reg.xetla_select(elems * rounds); - xetla_vector sub_vec_x = + constexpr int remaining_elems = matAcc_t::tile_elems % elems; + if constexpr (remaining_elems != 0) { + auto sub_vec = matAcc.reg.xetla_select( + elems * (matAcc_t::tile_elems / elems)); + xetla_vector sub_vec_x = (sqrt_two_over_pi * sub_vec * (1.f + C0 * sub_vec * sub_vec)); - xetla_vector tanh_value = - xetla_tanh(sub_vec_x); + xetla_vector tanh_value = + xetla_tanh(sub_vec_x); sub_vec = 0.5f * sub_vec * (1.f + tanh_value); } } @@ -886,12 +916,17 @@ struct elemwise_reduce_op_t< template < reduce_op reduce_kind, typename dtype_in, - gpu_arch arch_tag = gpu_arch::XeHpc> + gpu_arch arch_tag = gpu_arch::XeHpc, + class enable = void> struct elemwise_reduce_op_stream_k_t {}; /// @brief Is the element-wise reduce op functor, specialized for Xe /// architecture. -template -struct elemwise_reduce_op_stream_k_t { +template +struct elemwise_reduce_op_stream_k_t< + reduce_kind_, + dtype_in_, + arch_tag, + std::enable_if_t<(arch_tag <= gpu_arch::XeHpc)>> { using dtype_in = dtype_in_; using mem_desc_in_t = mem_desc_t; @@ -909,9 +944,9 @@ struct elemwise_reduce_op_stream_k_t { }; template __XETLA_API KERNEL_FUNC void operator()( - matAcc_t& matAcc, - const coord_t& coord, - const arguments_t& args, + [[maybe_unused]] matAcc_t& matAcc, + [[maybe_unused]] const coord_t& coord, + [[maybe_unused]] const arguments_t& args, [[maybe_unused]] uint32_t slm_base = 0, [[maybe_unused]] uint32_t nbarrier_base = 0) { using dtype_acc = typename matAcc_t::dtype; @@ -933,7 +968,7 @@ struct elemwise_reduce_op_stream_k_t { mem_desc_in_t, mat_in_tile_desc_t, msg_type_v, - gpu_arch::XeHpc>; + arch_tag>; mem_desc_in_t mem_desc_in(args.base, args.shape, coord); mat_in_tile_t mat_in; mat_in_tile_t mat_zero(0); @@ -992,6 +1027,7 @@ struct elemwise_reduce_op_stream_k_t { /// @brief Is the dropout op functor. /// Load the mask from memory and get input from matAcc, /// do the scaling and zero out, update the output in place. +/// The mask has the same layout as the output. /// Used in epilogue::tile_op or chained_tile_op. /// @tparam dtype_mask Is the mask data type. /// @tparam arch_tag Is the hardware architecture tag. @@ -1052,13 +1088,15 @@ struct dropout_op_t< mask_in_tile_t mask_in; mask_in_payload_t mask_in_payload(mem_desc_mask); tile_load(mask_in, mask_in_payload); + if constexpr (tile_elems / unroll_size != 0) { #pragma unroll - for (uint32_t i = 0; i < tile_elems / unroll_size; i++) { - xetla_mask mask_flag = - mask_in.reg.xetla_select(i * unroll_size) > 0; - auto dst_reg = matAcc.reg.xetla_select(i * unroll_size); - dst_reg *= args.scale; - dst_reg.xetla_merge(0, mask_flag); + for (uint32_t i = 0; i < tile_elems / unroll_size; i++) { + xetla_mask mask_flag = + mask_in.reg.xetla_select(i * unroll_size) > 0; + auto dst_reg = matAcc.reg.xetla_select(i * unroll_size); + dst_reg *= args.scale; + dst_reg.xetla_merge(0, mask_flag); + } } if constexpr (tile_elems % unroll_size != 0) { constexpr uint32_t remain_len = tile_elems % unroll_size; @@ -1074,8 +1112,9 @@ struct dropout_op_t< /// @brief Is the random number generator and dropout op functor. /// Generate the mask data and get input from matAcc, do the scaling and zero -/// out, update the output in place, dump the mask buffer to memory. 
Used in -/// epilogue::tile_op or chained_tile_op. +/// out, update the output in place, dump the mask buffer to memory. The mask +/// has the same layout as the output. Used in epilogue::tile_op or +/// chained_tile_op. /// @tparam dtype_mask Is the mask data type. /// @tparam arch_tag Is the hardware architecture tag. template @@ -1149,12 +1188,9 @@ struct rng_dropout_op_t< // calculate the scale internally float scale = 1.f / (1.f - args.prob); uint32_t threshold = uint32_t(args.prob * float(4294967296)); - xetla_vector rand_offset_v = xetla_load_global< - uint64_t, - 1, - data_size::default_size, - cache_hint::cached, - cache_hint::cached>(args.rand_offset_ptr, 0); + xetla_vector rand_offset_v = + xetla_load_global( + args.rand_offset_ptr, 0); uint64_t rand_offset = rand_offset_v[0]; uint64_t rand_subseq = uint64_t(coord.y) << 32 | uint64_t(coord.x); rand_gen.init(args.rand_seed, rand_subseq, rand_offset); @@ -1162,16 +1198,18 @@ struct rng_dropout_op_t< mem_desc_mask_t mem_desc_mask(args.mask_base, args.mask_shape, coord); mask_out_tile_t mask_out; mask_out_payload_t mask_out_payload(mem_desc_mask); - + if constexpr (tile_elems / random_len != 0) { #pragma unroll - for (uint32_t i = 0; i < tile_elems / random_len; i++) { - auto out_sub = matAcc.reg.xetla_select(i * random_len); - auto mask_sub = mask_out.reg.xetla_select(i * random_len); - xetla_vector rand_val = rand_gen.rand(); - xetla_mask mask_flag = rand_val < threshold; - out_sub *= scale; - out_sub.xetla_merge(0, mask_flag); - mask_sub.xetla_merge(1, 0, mask_flag); + for (uint32_t i = 0; i < tile_elems / random_len; i++) { + auto out_sub = matAcc.reg.xetla_select(i * random_len); + auto mask_sub = + mask_out.reg.xetla_select(i * random_len); + xetla_vector rand_val = rand_gen.rand(); + xetla_mask mask_flag = rand_val < threshold; + out_sub *= scale; + out_sub.xetla_merge(0, mask_flag); + mask_sub.xetla_merge(1, 0, mask_flag); + } } if constexpr (tile_elems % random_len != 0) { constexpr uint32_t remain_len = tile_elems % random_len; diff --git a/tests/integration/data_transformer/common.hpp b/tests/integration/data_transformer/common.hpp index c15991aad..b4309c0da 100644 --- a/tests/integration/data_transformer/common.hpp +++ b/tests/integration/data_transformer/common.hpp @@ -18,21 +18,13 @@ #include #include "xetla.hpp" +#ifdef _WIN32 +#include "utils/windows_functions.hpp" +#endif + using namespace gpu::xetla; using namespace cl::sycl; -namespace { -// abs for floating point types is non-standard and has been deprecated. -// Please use fabs instead. [-Wdeprecated-declarations] -template -inline T _abs(const T& v) { - if constexpr (is_floating_point_v) - return fabs(v); - else - return abs(v); -}; -} // namespace - template int data_transformer_result_validate( data_type_in* in_device, @@ -59,7 +51,7 @@ int data_transformer_result_validate( int idx = i * mat_n + j; cpu_max = - (cpu_max > _abs(in[idx])) ? cpu_max : _abs((data_type_acc)in[idx]); + (cpu_max > fabs(in[idx])) ? 
cpu_max : abs((data_type_acc)in[idx]); res = out[idx]; @@ -71,7 +63,7 @@ int data_transformer_result_validate( : (data_type_out)(in[j * mat_m + i]); } - if (_abs(res - ref) > _abs(0.01 * res)) { + if (abs(res - ref) > abs(0.01 * res)) { std::cout << "i: " << i << " j: " << j << " idx: " << idx << " in: " << in[idx] << " cpu: " << ref << " gpu: " << res << std::endl; @@ -86,7 +78,7 @@ int data_transformer_result_validate( cpu_max = cpu_max * scale[0]; if (need_fp8_op) { - if (_abs(cpu_max - amax_ptr[0]) > _abs(0.01 * cpu_max)) { + if (abs(cpu_max - amax_ptr[0]) > abs(0.01 * cpu_max)) { std::cout << "cpu_max: " << cpu_max << " gpu_max: " << amax_ptr[0] << std::endl; return 1; @@ -122,7 +114,6 @@ class TestBase { using data_type_in = float; using data_type_out = bf16; using data_type_acc = float; - static constexpr gpu_arch gpu_arch = gpu_arch::XeHpc; }; class Test_fp32tobf16_128_64 : public TestBase { diff --git a/tests/integration/fmha/fmha.cpp b/tests/integration/fmha/fmha.cpp index d4f0207d5..1921cf206 100644 --- a/tests/integration/fmha/fmha.cpp +++ b/tests/integration/fmha/fmha.cpp @@ -14,127 +14,171 @@ * limitations under the License. *******************************************************************************/ #include +#include +#include #include +#include #include "fmha_forward.hpp" #include "fmha_forward_policy.h" #include "xetla.hpp" -using FMHA_T = fp16; -using policy_t = stage0; -// using policy_t = stage0; +const auto IS_VERBOSE = false; + +struct test_params_t { + // Q: [FxBxNxH] or [BxFxMxH] ; similar for K/V/O + // BIAS: [1/B, 1/N, 1/F, T] + bool kUseBias; + bool kSeqLast; + uint32_t bs; + uint32_t hn; + uint32_t hs; + uint32_t qlen; + uint32_t klen; -constexpr uint32_t num_batches = 1; -constexpr uint32_t num_heads = 32; -constexpr uint32_t head_size = 128; -constexpr uint32_t num_queries = 1024; -// constexpr uint32_t num_queries = 1; -constexpr uint32_t num_keys = 1024; -constexpr float softmax_scale = 0.125; + static std::vector cases() { + std::vector ret; + std::vector> shapes{ + {1, 32, 64, 1, 33}, + {1, 32, 64, 34, 34}, + {1, 32, 64, 1023, 1023}, + + {1, 32, 128, 1, 33}, + {1, 32, 128, 1, 1023}, + {1, 32, 128, 1, 16384}, + {1, 32, 128, 34, 34}, + {1, 32, 128, 34, 1023}, + {1, 32, 128, 1023, 1023}, + }; + for (auto [bs, hn, hs, qlen, klen] : shapes) + for (auto kUseBias : {false, true}) + for (auto kSeqLast : {false, true}) + ret.emplace_back(kUseBias, kSeqLast, bs, hn, hs, qlen, klen); + return ret; + } -// Q: [FxBxNxH] or [BxFxMxH] -// similar for K/V/O -constexpr bool kSeqLast = true; + std::string to_string() const { + std::vector params; + params.push_back(std::string("kUseBias") + (kUseBias ? "ON" : "OFF")); + params.push_back(std::string("kSeqLast") + (kSeqLast ? 
"ON" : "OFF")); + params.push_back("bs" + std::to_string(bs)); + params.push_back("hn" + std::to_string(hn)); + params.push_back("hs" + std::to_string(hs)); + params.push_back("qlen" + std::to_string(qlen)); + params.push_back("klen" + std::to_string(klen)); + return std::accumulate( + std::next(params.begin()), + params.end(), + params[0], + [](std::string a, std::string b) { return a + '_' + b; }); + } +}; + +using FMHA_T = fp16; +// using FMHA_T = bf16; -template +template int fma_result_validate( + const test_params_t& p, FMHA_T* q_device, FMHA_T* k_device, FMHA_T* v_device, FMHA_T* DST_device, + FMHA_T* BIAS_device, sycl::queue& queue) { - auto Q_ptr = alloc_host_and_copy( - q_device, num_batches * num_heads * head_size * num_queries, queue); - auto K_ptr = alloc_host_and_copy( - k_device, num_batches * num_heads * head_size * num_keys, queue); - auto V_ptr = alloc_host_and_copy( - v_device, num_batches * num_heads * head_size * num_keys, queue); - auto DST_ptr = alloc_host_and_copy( - DST_device, num_batches * num_heads * head_size * num_queries, queue); + const auto bs = p.bs; + const auto hn = p.hn; + const auto hs = p.hs; + const auto qlen = p.qlen; + const auto klen = p.klen; + const auto klen_pad32 = (klen + 31) / 32 * 32; + const float softmax_scale = 1.f / std::sqrt(p.hs); + auto Q_ptr = + alloc_host_and_copy(q_device, bs * hn * hs * qlen, queue); + auto K_ptr = + alloc_host_and_copy(k_device, bs * hn * hs * klen, queue); + auto V_ptr = + alloc_host_and_copy(v_device, bs * hn * hs * klen, queue); + auto DST_ptr = + alloc_host_and_copy(DST_device, bs * hn * hs * qlen, queue); + auto BIAS_ptr = kUseBias ? alloc_host_and_copy( + BIAS_device, bs * 1 * qlen * klen_pad32, queue) + : nullptr; - std::vector gold_SP( - num_batches * num_heads * num_queries * num_keys, 0); - for (uint32_t gid = 0; gid < num_batches * num_heads; gid++) { - uint32_t batch_id = gid / num_heads; // get batch idx - uint32_t head_id = gid % num_heads; // get head idx + std::vector gold_SP(bs * hn * qlen * klen, 0); + for (uint32_t gid = 0; gid < bs * hn; gid++) { + uint32_t batch_id = gid / hn; // get batch idx + uint32_t head_id = gid % hn; // get head idx const auto Q_cur = kSeqLast - ? Q_ptr + batch_id * head_size * num_heads + head_size * head_id - : Q_ptr + batch_id * num_queries * head_size * num_heads + - head_size * head_id; + ? Q_ptr + batch_id * hs * hn + hs * head_id + : Q_ptr + batch_id * qlen * hs * hn + hs * head_id; const auto K_cur = kSeqLast - ? K_ptr + batch_id * head_size * num_heads + head_size * head_id - : K_ptr + batch_id * num_keys * head_size * num_heads + - head_size * head_id; - const auto gold_cur = gold_SP.data() + gid * num_queries * num_keys; + ? K_ptr + batch_id * hs * hn + hs * head_id + : K_ptr + batch_id * klen * hs * hn + hs * head_id; + const auto gold_cur = gold_SP.data() + gid * qlen * klen; + const auto BIAS_cur = + kUseBias ? BIAS_ptr + batch_id * qlen * klen_pad32 : nullptr; - auto Q_tmp = std::unique_ptr(new FMHA_T[num_queries * head_size]); - for (uint32_t i = 0; i < num_queries; ++i) + auto Q_tmp = std::unique_ptr(new FMHA_T[qlen * hs]); + for (uint32_t i = 0; i < qlen; ++i) std::copy_n( - Q_cur + i * head_size * num_heads * (kSeqLast ? num_batches : 1), - head_size, - Q_tmp.get() + i * head_size); - auto K_tmp = std::unique_ptr(new FMHA_T[num_keys * head_size]); - for (uint32_t i = 0; i < num_keys; ++i) - for (uint32_t j = 0; j < head_size; ++j) - K_tmp[j * num_keys + i] = - K_cur[i * head_size * num_heads * (kSeqLast ? 
num_batches : 1) + j]; + Q_cur + i * hs * hn * (kSeqLast ? bs : 1), hs, Q_tmp.get() + i * hs); + auto K_tmp = std::unique_ptr(new FMHA_T[klen * hs]); + for (uint32_t i = 0; i < klen; ++i) + for (uint32_t j = 0; j < hs; ++j) + K_tmp[j * klen + i] = K_cur[i * hs * hn * (kSeqLast ? bs : 1) + j]; get_gemm_gold( - num_queries, - num_keys, - head_size, + qlen, + klen, + hs, mem_layout::row_major, mem_layout::row_major, Q_tmp.get(), K_tmp.get(), gold_cur); - for (uint32_t i = 0; i < num_queries; i++) - for (uint32_t j = 0; j < num_keys; j++) - gold_cur[i * num_keys + j] *= - softmax_scale; // TODO(Yi): pass scale + mask - for (uint32_t i = 0; i < num_queries; i++) { + for (uint32_t i = 0; i < qlen; i++) + for (uint32_t j = 0; j < klen; j++) { + gold_cur[i * klen + j] *= softmax_scale; + if constexpr (kUseBias) + gold_cur[i * klen + j] += BIAS_cur[i * klen_pad32 + j]; + } + for (uint32_t i = 0; i < qlen; i++) { accum_t row_max = -INFINITY; accum_t exp_sum = 0; - for (uint32_t j = 0; j < num_keys; j++) - row_max = max(row_max, gold_cur[i * num_keys + j]); - for (uint32_t j = 0; j < num_keys; j++) { - gold_cur[i * num_keys + j] = - std::exp(gold_cur[i * num_keys + j] - row_max); - exp_sum += gold_cur[i * num_keys + j]; + for (uint32_t j = 0; j < klen; j++) + row_max = max(row_max, gold_cur[i * klen + j]); + for (uint32_t j = 0; j < klen; j++) { + gold_cur[i * klen + j] = std::exp(gold_cur[i * klen + j] - row_max); + exp_sum += gold_cur[i * klen + j]; } - for (uint32_t j = 0; j < num_keys; j++) - gold_cur[i * num_keys + j] /= exp_sum; + for (uint32_t j = 0; j < klen; j++) + gold_cur[i * klen + j] /= exp_sum; } } - std::vector gold_DST( - num_batches * num_queries * num_heads * head_size, 0); + std::vector gold_DST(bs * qlen * hn * hs, 0); // second gemm on host - for (uint32_t gid = 0; gid < num_batches * num_heads; gid++) { - uint32_t batch_id = gid / num_heads; // get batch idx - uint32_t head_id = gid % num_heads; // get head idx + for (uint32_t gid = 0; gid < bs * hn; gid++) { + uint32_t batch_id = gid / hn; // get batch idx + uint32_t head_id = gid % hn; // get head idx - // TODO const auto V_cur = kSeqLast - ? V_ptr + batch_id * head_size * num_heads + head_size * head_id - : V_ptr + batch_id * num_keys * head_size * num_heads + - head_size * head_id; - const auto P_cur = gold_SP.data() + gid * num_queries * num_keys; - auto dst_cur = - std::unique_ptr(new accum_t[num_queries * head_size]); - std::fill_n(dst_cur.get(), num_queries * head_size, 0); - auto V_tmp = std::unique_ptr(new FMHA_T[num_keys * head_size]); - for (uint32_t i = 0; i < num_keys; ++i) + ? V_ptr + batch_id * hs * hn + hs * head_id + : V_ptr + batch_id * klen * hs * hn + hs * head_id; + const auto P_cur = gold_SP.data() + gid * qlen * klen; + auto dst_cur = std::unique_ptr(new accum_t[qlen * hs]); + std::fill_n(dst_cur.get(), qlen * hs, 0); + auto V_tmp = std::unique_ptr(new FMHA_T[klen * hs]); + for (uint32_t i = 0; i < klen; ++i) std::copy_n( - V_cur + i * head_size * num_heads * (kSeqLast ? num_batches : 1), - head_size, - V_tmp.get() + i * head_size); + V_cur + i * hs * hn * (kSeqLast ? 
bs : 1), hs, V_tmp.get() + i * hs); get_gemm_gold( - num_queries, - head_size, - num_keys, + qlen, + hs, + klen, mem_layout::row_major, mem_layout::row_major, P_cur, @@ -142,43 +186,62 @@ int fma_result_validate( dst_cur.get()); // permute 0213 - const auto gold_cur = gold_DST.data() + - batch_id * num_queries * num_heads * head_size + head_id * head_size; - for (uint32_t i = 0; i < num_queries; ++i) + const auto gold_cur = + gold_DST.data() + batch_id * qlen * hn * hs + head_id * hs; + for (uint32_t i = 0; i < qlen; ++i) std::copy_n( - dst_cur.get() + i * head_size, - head_size, - gold_cur + i * num_heads * head_size * (kSeqLast ? num_batches : 1)); + dst_cur.get() + i * hs, + hs, + gold_cur + i * hn * hs * (kSeqLast ? bs : 1)); } buff_cmp::buff_vals data( // DST_ptr, - num_queries * num_heads * num_batches, - head_size, - head_size); + qlen * hn * bs, + hs, + hs); buff_cmp::buff_vals other( - gold_DST.data(), - num_queries * num_heads * num_batches, - head_size, - head_size); - bool result = buff_cmp::xetla_buff_cmp(data, other, "fmha validation"); + gold_DST.data(), qlen * hn * bs, hs, hs); + bool result = buff_cmp::xetla_buff_cmp( + data, other, IS_VERBOSE ? "fmha validation" : ""); free(Q_ptr); free(K_ptr); free(V_ptr); free(DST_ptr); + if (BIAS_ptr) + free(BIAS_ptr); - std::cout << ((!result) ? "FAILED\n" : "PASSED\n"); + if (IS_VERBOSE || !result) + std::cout << (result ? "PASSED\n" : "FAILED\n"); return result ? 0 : 1; } -void fmha_run(uint32_t iter, uint32_t warmup = 10) { +template +void fmha_run_( + const test_params_t& p, + uint32_t iter, + uint32_t warmup, + bool b, + Ts... bs) { + return b ? fmha_run_(p, iter, warmup, bs...) + : fmha_run_(p, iter, warmup, bs...); +} + +template +void fmha_run_(const test_params_t& p, uint32_t iter, uint32_t warmup) { + const auto bs = p.bs; + const auto hn = p.hn; + const auto hs = p.hs; + const auto qlen = p.qlen; + const auto klen = p.klen; + const auto klen_pad32 = (klen + 31) / 32 * 32; + const float softmax_scale = 1.f / std::sqrt(p.hs); using fmha_forward_op_t = gpu::xetla::fmha::fmha_forward_t< policy_t, FMHA_T, gpu_arch::XeLpg, - // gpu_arch::XeHpg, - false, false, + kUseBias, false, kSeqLast, false, @@ -191,26 +254,27 @@ void fmha_run(uint32_t iter, uint32_t warmup = 10) { auto context = queue.get_info(); auto device = queue.get_info(); - print_device_details(device); + if (IS_VERBOSE) + print_device_details(device); auto Q = alloc_device_and_init( - num_batches * num_heads * head_size * num_queries, + bs * hn * hs * qlen, [](FMHA_T* data, size_t idx) { - data[idx] = static_cast(random_float()); + data[idx] = static_cast(idx % 11); }, queue, device, context); auto K = alloc_device_and_init( - num_batches * num_heads * head_size * num_keys, + bs * hn * hs * klen, [](FMHA_T* data, size_t idx) { - data[idx] = static_cast(random_float()); + data[idx] = static_cast(idx % 11); }, queue, device, context); auto V = alloc_device_and_init( - num_batches * num_heads * head_size * num_keys, + bs * hn * hs * klen, [](FMHA_T* data, size_t idx) { data[idx] = static_cast(random_float()); }, @@ -218,34 +282,45 @@ void fmha_run(uint32_t iter, uint32_t warmup = 10) { device, context); auto DST = alloc_device_and_init( - num_batches * num_heads * head_size * num_queries, + bs * hn * hs * qlen, [](FMHA_T* data, size_t idx) { data[idx] = static_cast(9999); }, queue, device, context); + auto BIAS = kUseBias // bias / attention mask + ? 
alloc_device_and_init( + bs * 1 * qlen * klen_pad32, + [=](FMHA_T* data, size_t idx) { + data[idx] = + static_cast(random_float()) * softmax_scale * p.hs; + }, + queue, + device, + context) + : nullptr; auto L = alloc_device_and_init( // log sum exp - num_batches * num_heads * num_keys, + bs * hn * klen, [](accum_t* data, size_t idx) { data[idx] = static_cast(9999); }, queue, device, context); - sycl::nd_range<3> nd_range = - fmha_forward_op_t::get_nd_range(num_batches * num_heads, num_queries); + sycl::nd_range<3> nd_range = fmha_forward_op_t::get_nd_range(bs * hn, qlen); fmha_forward_op_t::check_slm_size(queue.get_info()); - std::cout << "slm_size:\t" << fmha_forward_op_t::get_slm_size() << std::endl; - std::cout << "global_size:\t" << nd_range.get_global_range()[0] << ",\t" - << nd_range.get_global_range()[1] << ",\t" - << nd_range.get_global_range()[2] << std::endl; - std::cout << "local_size:\t" << nd_range.get_local_range()[0] << ",\t" - << nd_range.get_local_range()[1] << ",\t" - << nd_range.get_local_range()[2] << std::endl; - constexpr int64_t qk_ops = static_cast(2) * num_batches * num_heads * - head_size * num_queries * num_keys; - constexpr int64_t pv_ops = static_cast(2) * num_batches * num_heads * - head_size * num_queries * num_keys; + if (IS_VERBOSE) { + std::cout << "slm_size:\t" << fmha_forward_op_t::get_slm_size() + << std::endl; + std::cout << "global_size:\t" << nd_range.get_global_range()[0] << ",\t" + << nd_range.get_global_range()[1] << ",\t" + << nd_range.get_global_range()[2] << std::endl; + std::cout << "local_size:\t" << nd_range.get_local_range()[0] << ",\t" + << nd_range.get_local_range()[1] << ",\t" + << nd_range.get_local_range()[2] << std::endl; + } + const int64_t qk_ops = static_cast(2) * bs * hn * hs * qlen * klen; + const int64_t pv_ops = static_cast(2) * bs * hn * hs * qlen * klen; - int64_t ops = qk_ops + pv_ops; + const int64_t ops = qk_ops + pv_ops; profiling_helper prof("gemm_universal", ops, "gflops"); for (uint32_t i = 0; i < iter + warmup; i++) { if (i >= warmup) { @@ -258,23 +333,23 @@ void fmha_run(uint32_t iter, uint32_t warmup = 10) { K, V, nullptr, - nullptr, + BIAS, nullptr, DST, L, - num_batches, - num_heads, - num_heads, // num_kv_heads - head_size, - num_queries, - num_keys, - -1, - -1, - -1, + bs, + hn, + hn, // num_kv_heads + hs, + qlen, + klen, + kUseBias ? klen_pad32 * qlen : 0, + kUseBias ? 0 : 0, // broadcast on N (head num) + kUseBias ? klen_pad32 : 0, softmax_scale, 0, 0, - 0, + kUseBias ? klen_pad32 : 0, (uint64_t)0, (uint64_t)0); fmha_forward_op_t{}(item, kern_args); @@ -288,16 +363,78 @@ void fmha_run(uint32_t iter, uint32_t warmup = 10) { } } // performance - prof.print_profiling_result(profiling_selector::GPU); + prof.print_profiling_result(profiling_selector::GPU, IS_VERBOSE); - ASSERT_EQ(0, fma_result_validate(Q, K, V, DST, queue)); + ASSERT_EQ( + 0, + (fma_result_validate( + p, Q, K, V, DST, BIAS, queue))); free(Q, context); free(K, context); free(V, context); free(DST, context); + if (BIAS) + free(BIAS, context); + if (L) + free(L, context); +} +template +void fmha_dispatch_policy(const test_params_t& p, Args... 
args) { + if (p.hs <= 64) { + if (p.qlen < 64) { + // for short query length + return fmha_run_>(p, args...); + } else { + // for long query length + return fmha_run_>(p, args...); + } + } else if (p.hs <= 128) { + if (p.qlen == 1) { + // for extremely short query length + if (p.klen < 512) { + return fmha_run_>(p, args...); + } else { + return fmha_run_>(p, args...); + } + } else if (p.qlen < 64) { + // for short query length + if (p.klen < 512) { + return fmha_run_>(p, args...); + } else { + return fmha_run_>(p, args...); + } + } else { + return fmha_run_>(p, args...); + } + } else { + std::cout << "Larger hs to be tested...\n"; + GTEST_FAIL(); + return; + } } -int main() { - fmha_run(5, 2); +void fmha_run(const test_params_t& p, uint32_t iter, uint32_t warmup = 10) { + return fmha_dispatch_policy(p, iter, warmup, p.kUseBias, p.kSeqLast); +} + +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::ValuesIn; + +class FMHATest : public TestWithParam { + protected: + FMHATest() {} + ~FMHATest() {} + void SetUp() override {} + void TearDown() override {} +}; +TEST_P(FMHATest, ) { + test_params_t p = TestWithParam::GetParam(); + fmha_run(p, 5, 3); } +INSTANTIATE_TEST_SUITE_P( + XeTLA, + FMHATest, + ValuesIn(test_params_t::cases()), + [](TestParamInfo tpi) { return tpi.param.to_string(); }); diff --git a/tests/integration/fmha/fmha_forward.hpp b/tests/integration/fmha/fmha_forward.hpp index e0964370e..6231c6e58 100644 --- a/tests/integration/fmha/fmha_forward.hpp +++ b/tests/integration/fmha/fmha_forward.hpp @@ -115,25 +115,31 @@ class fmha_forward_t { static constexpr uint32_t accum_step = fmha_policy::accum_step; static constexpr uint32_t stages = fmha_policy::stages; static constexpr uint32_t sync_freq = fmha_policy::sync_freq; + static constexpr uint32_t kBr = fmha_policy::kBr; + static constexpr uint32_t kBc = fmha_policy::kBc; + static constexpr uint32_t kHm = fmha_policy::kHm; + static constexpr uint32_t kSgBr = fmha_policy::kSgBr; + static constexpr uint32_t kSgBc = fmha_policy::kSgBc; + static constexpr uint32_t kSgHm = fmha_policy::kSgHm; - using comp_attr = group::compute_attr_t; + using comp_attr = std::conditional_t< + std::is_same_v && (arch_tag < gpu_arch::XeHpc), + group::compute_attr_t, + group::compute_attr_t>; using knobs = group::perf_tuning_knob_t; + + // use fpu when M==1 even if xmx is available + static constexpr bool _use_xmx = arch_tag >= gpu_arch::XeHpg && kSgBr != 1; using compute_policy_BrBc = std::conditional_t< - (arch_tag >= gpu_arch::XeHpg), + _use_xmx, group::compute_policy_default_xmx, group::compute_policy_default_fpu>; - // TODO: add k slicing + // TODO(Yi): add k slicing? 
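The _use_xmx switch above is ordinary compile-time dispatch: the systolic (xmx) compute policy is chosen only when the architecture offers it and the per-subgroup tile is taller than one row; with M == 1 the fpu path is used even on xmx-capable parts. A minimal standalone sketch of the pattern, where the enum, policy types, and trait name are placeholders rather than XeTLA's own:

#include <type_traits>

enum class arch { lpg, hpg, hpc };
struct fpu_policy {};
struct xmx_policy {};

// Pick the xmx path only when the arch has it and the subgroup tile
// is taller than one row; otherwise fall back to fpu.
template <arch tag, int sg_rows>
struct pick_compute {
  static constexpr bool use_xmx = (tag >= arch::hpg) && (sg_rows != 1);
  using type = std::conditional_t<use_xmx, xmx_policy, fpu_policy>;
};

static_assert(std::is_same_v<pick_compute<arch::hpg, 1>::type, fpu_policy>,
              "single-row tiles take the fpu path even on xmx-capable parts");
static_assert(std::is_same_v<pick_compute<arch::hpc, 8>::type, xmx_policy>,
              "multi-row tiles on newer parts use xmx");

comp_attr a few lines up is selected with the same std::conditional_t mechanism, keyed on the input element type and on arch_tag < gpu_arch::XeHpc.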
using compute_policy_BrBm = std::conditional_t< - (arch_tag >= gpu_arch::XeHpg), + _use_xmx, group::compute_policy_default_xmx, group::compute_policy_default_fpu>; // ---------------- // Tile shape and Threads // ---------------- // - static constexpr uint32_t kBr = fmha_policy::kBr; - static constexpr uint32_t kBc = fmha_policy::kBc; - static constexpr uint32_t kHm = fmha_policy::kHm; - static constexpr uint32_t kSgBr = fmha_policy::kSgBr; - static constexpr uint32_t kSgBc = fmha_policy::kSgBc; - static constexpr uint32_t kSgHm = fmha_policy::kSgHm; using tile_shape_BrBc = group::tile_shape_t; using tile_shape_BrHm = group::tile_shape_t; diff --git a/tests/integration/fmha/fmha_utils.h b/tests/integration/fmha/fmha_utils.h index 9c070da27..fc1c11909 100644 --- a/tests/integration/fmha/fmha_utils.h +++ b/tests/integration/fmha/fmha_utils.h @@ -255,7 +255,7 @@ struct bias_add_op_t { tile_load(bias, bias_payload); #pragma unroll - for (int i = 0; i < tile_size_y / block_size_y; i++) { + for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) { #pragma unroll for (int j = 0; j < num_block_x; j++) { auto dst_reg = @@ -264,7 +264,7 @@ struct bias_add_op_t { (i * num_block_x + j) * block_elems) .xetla_format(); #pragma unroll - for (int row_i = 0; row_i < block_size_y; row_i++) { + for (uint32_t row_i = 0; row_i < block_size_y; row_i++) { auto src_reg = bias.reg.xetla_select(j * block_size_x); dst_reg.row(row_i) = diff --git a/tests/integration/gemm/CMakeLists.txt b/tests/integration/gemm/CMakeLists.txt index 880bde753..cb9c535c5 100644 --- a/tests/integration/gemm/CMakeLists.txt +++ b/tests/integration/gemm/CMakeLists.txt @@ -1,7 +1,7 @@ include_directories(${CMAKE_SOURCE_DIR}/tests/integration/gemm) add_subdirectory(bf16) -add_subdirectory(bf16_stream_k) +add_subdirectory(stream_k) add_subdirectory(fp16) add_subdirectory(fp32) add_subdirectory(int8_quantization) diff --git a/tests/integration/gemm/fp16/common.hpp b/tests/integration/gemm/fp16/common.hpp index 2e4e03626..7e7896f3e 100644 --- a/tests/integration/gemm/fp16/common.hpp +++ b/tests/integration/gemm/fp16/common.hpp @@ -45,44 +45,43 @@ class TestBase { mem_layout_a_str + "_" + mem_layout_b_str; return name; } - static constexpr mma_engine engine = mma_engine::xmx; + static constexpr mma_engine engine = mma_engine::fpu; static constexpr gpu_arch gpu_arch = gpu_arch::XeHpg; }; class Test0 : public TestBase { public: - static constexpr size_t mat_m = 256; + static constexpr size_t mat_m = 1; static constexpr size_t mat_n = 256; static constexpr size_t mat_k = 256; static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 256; + static constexpr size_t wg_n = 32; static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 32; - static constexpr size_t sg_k = 32; + static constexpr size_t sg_k = 16; static constexpr uint32_t global_kslicing = 1; static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::col_major; using data_type_a = fp16; using data_type_b = fp16; using data_type_c = fp16; using data_type_acc = float; - static constexpr mma_engine engine = mma_engine::fpu; }; class Test1 : public TestBase { public: - static constexpr size_t mat_m = 256; + static constexpr size_t mat_m = 1; static constexpr size_t mat_n = 256; static constexpr size_t mat_k = 256; - static constexpr size_t wg_m = 256; - static constexpr size_t wg_n = 256; - static constexpr 
size_t sg_m = 32; - static constexpr size_t sg_n = 64; + static constexpr size_t wg_m = 1; + static constexpr size_t wg_n = 32; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; static constexpr uint32_t global_kslicing = 1; static constexpr uint32_t local_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::col_major; + static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = fp16; using data_type_b = fp16; @@ -94,9 +93,9 @@ class Test2 : public TestBase { static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; static constexpr size_t mat_k = 256; - static constexpr size_t wg_m = 16; + static constexpr size_t wg_m = 1; static constexpr size_t wg_n = 32; - static constexpr size_t sg_m = 8; + static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 16; static constexpr size_t sg_k = 16; static constexpr uint32_t global_kslicing = 1; @@ -105,65 +104,69 @@ class Test2 : public TestBase { static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = fp16; using data_type_b = fp16; - using data_type_c = float; + using data_type_c = fp16; using data_type_acc = float; }; class Test3 : public TestBase { public: - static constexpr size_t mat_m = 192; + static constexpr size_t mat_m = 256; static constexpr size_t mat_n = 256; static constexpr size_t mat_k = 256; - static constexpr size_t wg_m = 192; - static constexpr size_t wg_n = 256; - static constexpr size_t sg_m = 24; - static constexpr size_t sg_n = 64; - static constexpr size_t sg_k = 32; + static constexpr size_t wg_m = 1; + static constexpr size_t wg_n = 32; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 16; + static constexpr size_t sg_k = 16; static constexpr uint32_t global_kslicing = 1; static constexpr uint32_t local_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::col_major; - static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::col_major; using data_type_a = fp16; using data_type_b = fp16; - using data_type_c = float; + using data_type_c = fp16; using data_type_acc = float; }; + class Test4 : public TestBase { public: - static constexpr size_t mat_m = 256; - static constexpr size_t mat_n = 256; - static constexpr size_t mat_k = 256; - static constexpr size_t wg_m = 256; - static constexpr size_t wg_n = 256; - static constexpr size_t sg_m = 32; - static constexpr size_t sg_n = 64; + static constexpr size_t mat_m = 1024; + static constexpr size_t mat_n = 4096; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 16 * 1; + static constexpr size_t wg_n = 32 * 32; + static constexpr size_t sg_m = 16; + static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 32; static constexpr uint32_t global_kslicing = 1; static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::col_major; + static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = fp16; using data_type_b = fp16; - using data_type_c = float; + using data_type_c = fp16; using data_type_acc = float; + static constexpr mma_engine engine = mma_engine::xmx; }; + class Test5 : public TestBase { public: - static constexpr size_t mat_m = 256; - static constexpr size_t mat_n = 
256; - static constexpr size_t mat_k = 256; - static constexpr size_t wg_m = 256; - static constexpr size_t wg_n = 256; - static constexpr size_t sg_m = 32; - static constexpr size_t sg_n = 64; - static constexpr size_t sg_k = 16; + static constexpr size_t mat_m = 1024; + static constexpr size_t mat_n = 4096; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 32; + static constexpr size_t wg_n = 32 * 4; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 32; + static constexpr size_t sg_k = 32; static constexpr uint32_t global_kslicing = 1; static constexpr uint32_t local_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::col_major; - static constexpr mem_layout layout_b = mem_layout::col_major; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = fp16; using data_type_b = fp16; - using data_type_c = float; + using data_type_c = fp16; using data_type_acc = float; + static constexpr mma_engine engine = mma_engine::fpu; }; class Test6 : public TestBase { public: diff --git a/tests/integration/gemm/fp16/kernel_func.hpp b/tests/integration/gemm/fp16/kernel_func.hpp index 8e3ef2ffe..26a78ff91 100644 --- a/tests/integration/gemm/fp16/kernel_func.hpp +++ b/tests/integration/gemm/fp16/kernel_func.hpp @@ -41,12 +41,12 @@ template < gpu_arch gpu_arch> struct fp16_gemm_test_func { using tile_shape = tile_shape_t; - static constexpr uint32_t periodic_sync_interval = 8; - static constexpr uint32_t prefetch_distance = 3; + static constexpr uint32_t periodic_sync_interval = 0 ; //8; + static constexpr uint32_t prefetch_distance = 0 ;//256 / (sg_k * sizeof(dtype_a)); using compute_attr = typename std::conditional< (engine == mma_engine::fpu), - compute_attr_t, + compute_attr_t, compute_attr_t>::type; using perf_tuning_knob = perf_tuning_knob_t; @@ -76,6 +76,9 @@ struct fp16_gemm_test_func { using gemm_op_t = gemm_universal_t; + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + static const char* func_name() { return "fp16_gemm_test_func"; } diff --git a/tests/integration/gemm/fp16/main.cpp b/tests/integration/gemm/fp16/main.cpp index 98e66fc7b..400d13276 100644 --- a/tests/integration/gemm/fp16/main.cpp +++ b/tests/integration/gemm/fp16/main.cpp @@ -33,24 +33,24 @@ TYPED_TEST_P(fp16_gemm_test, esimd) { } REGISTER_TYPED_TEST_SUITE_P(fp16_gemm_test, esimd); using tests = ::testing::Types< - Test0, - Test1, - Test2, - Test3, - Test4, - Test5, - Test6, - Test7, - Test8, - Test9, - Test10, - Test11, - Test12, - Test13, - Test14, - Test15, - Test16, - Test17, - Test18, - Test19>; + Test4>; + // Test1, + // Test2, + // Test3>; + // Test4, + // Test5, + // Test6, + // Test7, + // Test8, + // Test9, + // Test10, + // Test11, + // Test12, + // Test13, + // Test14, + // Test15, + // Test16, + // Test17, + // Test18, + // Test19>; INSTANTIATE_TYPED_TEST_SUITE_P(fp16_gemm_test_suite, fp16_gemm_test, tests); diff --git a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp index 84b80958e..a8e4da602 100644 --- a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp +++ b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp @@ -16,10 +16,10 @@ #include #include "xetla.hpp" -// #define UT_DEBUG 1 +#define UT_DEBUG 1 using namespace gpu::xetla; // The number of times the kernel is 
executed -constexpr int ITER = 100; +constexpr int ITER = 200; enum optional_feature { NONE, ACT_SHUFFLE }; @@ -51,519 +51,446 @@ class act_shuf_feature_next_token { class test1_xehpg { public: // Extract the parameters required by different test cases - static constexpr size_t mat_m = 8; - static constexpr size_t mat_n = 4096 * 1; - static constexpr size_t mat_k = 4096 * 1; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 32 * 4; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 32; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 16; - - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 2; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - static constexpr mma_engine mma_eng = mma_engine::xmx; - static constexpr gpu_arch arch = gpu_arch::XeHpg; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; - -class test1_gpu_xelpg { - public: - static constexpr size_t mat_m = 8; - static constexpr size_t mat_n = 4096 * 1; + static constexpr size_t mat_m = 1; + static constexpr size_t mat_n = 4096 * 3; static constexpr size_t mat_k = 4096 * 1; static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 32 * 2; + static constexpr size_t wg_n = 64 * 8; static constexpr size_t sg_m = 1; - static constexpr size_t sg_n = 32; + static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 16; + static constexpr size_t dequant_s = 128; - static constexpr size_t local_kslicing = 1; + static constexpr size_t local_kslicing = 8; static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeHpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; - static constexpr mma_engine mma_eng = mma_engine::fpu; - static constexpr gpu_arch arch = gpu_arch::XeLpg; }; -class test2_gpu_xelpg { +class test2_xehpg { public: - static constexpr size_t mat_m = 32; + // Extract the parameters required by different test cases + static constexpr size_t mat_m = 1; static constexpr size_t mat_n = 4096 * 1; - static constexpr size_t mat_k = 4096 * 1; + static constexpr size_t mat_k = 4096 * 3; static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 32 * 2; + static constexpr size_t wg_n = 32 * 4; static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 32; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 16; + static constexpr size_t sg_k = 32; + static constexpr size_t dequant_s = 128; - static constexpr size_t local_kslicing = 1; + static constexpr size_t local_kslicing = 16; static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeHpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; - static constexpr mma_engine mma_eng = mma_engine::fpu; - static constexpr gpu_arch arch = gpu_arch::XeLpg; }; -class test3_gpu_xelpg { + +class test3_xehpg { public: - static constexpr size_t mat_m = 1024; + // Extract the parameters required by different test 
cases + static constexpr size_t mat_m = 1; static constexpr size_t mat_n = 4096 * 1; static constexpr size_t mat_k = 4096 * 1; static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 32 * 2; + static constexpr size_t wg_n = 32 * 4; static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 16; + static constexpr size_t dequant_s = 128; - static constexpr size_t local_kslicing = 1; + static constexpr size_t local_kslicing = 16; static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeHpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; - static constexpr mma_engine mma_eng = mma_engine::fpu; - static constexpr gpu_arch arch = gpu_arch::XeLpg; }; -class test4_gpu_xelpg { + +class test4_xehpg { public: + // Extract the parameters required by different test cases static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 4096 * 1; - static constexpr size_t mat_k = 4096 * 1; + static constexpr size_t mat_n = 3072 * 1; + static constexpr size_t mat_k = 32064 * 1; static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 32 * 4; + static constexpr size_t wg_n = 32 * 2; static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 16; + static constexpr size_t dequant_s = 128; static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 2; + static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeHpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; - static constexpr mma_engine mma_eng = mma_engine::fpu; - static constexpr gpu_arch arch = gpu_arch::XeLpg; }; -class test5_gpu_xelpg { +class test5_xehpg { public: + // Extract the parameters required by different test cases static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 4096 * 3; - static constexpr size_t mat_k = 4096 * 1; + static constexpr size_t mat_n = 3072 * 1; + static constexpr size_t mat_k = 9216 * 1; static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 32 * 4; + static constexpr size_t wg_n = 32 * 2; static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 16; + static constexpr size_t dequant_s = 128; static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 2; + static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeHpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; - static constexpr mma_engine mma_eng = mma_engine::fpu; - static constexpr gpu_arch arch = gpu_arch::XeLpg; }; -class test6_gpu_xelpg { +class test6_xehpg { public: + // Extract the parameters required by different test cases static constexpr size_t mat_m = 1; - static constexpr 
size_t mat_n = 4096 * 1; - static constexpr size_t mat_k = 4096 * 3; + static constexpr size_t mat_n = 3072 * 1; + static constexpr size_t mat_k = 3072 * 1; static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 32 * 4; + static constexpr size_t wg_n = 32 * 2; static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 16; + static constexpr size_t dequant_s = 128; static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 2; + static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeHpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; - static constexpr mma_engine mma_eng = mma_engine::fpu; - static constexpr gpu_arch arch = gpu_arch::XeLpg; }; -class test7_gpu_xelpg { + +class test7_xehpg { public: + // Extract the parameters required by different test cases static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 4096 * 1; - static constexpr size_t mat_k = 12288; + static constexpr size_t mat_n = 16384 * 1; + static constexpr size_t mat_k = 3072 * 1; static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 32 * 4; + static constexpr size_t wg_n = 64 * 8; static constexpr size_t sg_m = 1; - static constexpr size_t sg_n = 32; + static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 16; + static constexpr size_t dequant_s = 128; static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 2; + static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeHpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; - static constexpr mma_engine mma_eng = mma_engine::fpu; - static constexpr gpu_arch arch = gpu_arch::XeLpg; }; -class test8_gpu_xelpg { +class test8_xehpg { public: + // Extract the parameters required by different test cases static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 12288; - static constexpr size_t mat_k = 4096 * 1; + static constexpr size_t mat_n = 3072 * 1; + static constexpr size_t mat_k = 8192 * 1; static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 32 * 4; + static constexpr size_t wg_n = 32 * 2; static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 16; + static constexpr size_t dequant_s = 128; static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 2; + static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeHpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; - static constexpr mma_engine mma_eng = mma_engine::fpu; - static constexpr gpu_arch arch = gpu_arch::XeLpg; }; -class test9_gpu_xelpg { +class test9_xehpg { public: + // Extract the parameters required by different 
test cases static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 4096 * 1; - static constexpr size_t mat_k = 1024 * 1; + static constexpr size_t mat_n = 32064 * 1; + static constexpr size_t mat_k = 3072 * 1; static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 32 * 4; + static constexpr size_t wg_n = 64 * 8 * 2; static constexpr size_t sg_m = 1; - static constexpr size_t sg_n = 32; + static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 16; + static constexpr size_t dequant_s = 128; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 2; + static constexpr size_t local_kslicing = 4; + static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeHpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; - static constexpr mma_engine mma_eng = mma_engine::fpu; - static constexpr gpu_arch arch = gpu_arch::XeLpg; }; -class test10_gpu_xelpg { + +class test1_xelpg { public: + // Extract the parameters required by different test cases static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 1024 * 1; + static constexpr size_t mat_n = 4096 * 3; static constexpr size_t mat_k = 4096 * 1; static constexpr size_t wg_m = 1; static constexpr size_t wg_n = 32 * 4; static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 16; + static constexpr size_t dequant_s = 128; static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 2; + static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; static constexpr mma_engine mma_eng = mma_engine::fpu; static constexpr gpu_arch arch = gpu_arch::XeLpg; -}; - -class t1 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1024; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 32; - static constexpr size_t wg_n = 32; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 32; - - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; - -class t2 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1024; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 32; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 32; - - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout 
layout_b = mem_layout::row_major; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; }; -class t3 { +class test2_xelpg { public: // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1024; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 16; - static constexpr size_t wg_n = 32; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; + static constexpr size_t mat_m = 1; + static constexpr size_t mat_n = 4096 * 1; + static constexpr size_t mat_k = 4096 * 3; + static constexpr size_t wg_m = 1; + static constexpr size_t wg_n = 32 * 2; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 32; + static constexpr size_t dequant_s = 128; - static constexpr size_t local_kslicing = 8; + static constexpr size_t local_kslicing = 16; static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeLpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; }; -class qkv1 { +class test3_xelpg { public: // Extract the parameters required by different test cases - static constexpr size_t mat_m = 8; - static constexpr size_t mat_n = 12288; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; + static constexpr size_t mat_m = 1; + static constexpr size_t mat_n = 4096 * 1; + static constexpr size_t mat_k = 4096 * 1; + static constexpr size_t wg_m = 1; + static constexpr size_t wg_n = 32 * 2; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; + static constexpr size_t dequant_s = 128; static constexpr size_t local_kslicing = 8; static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeLpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; }; -class qkv2 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 8; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv3 { +class test4_xelpg { public: // Extract the parameters required by different test cases - static constexpr size_t mat_m = 8; - static constexpr size_t mat_n = 11008; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr 
size_t sg_m = 8; - static constexpr size_t sg_n = 16; + static constexpr size_t mat_m = 1; + static constexpr size_t mat_n = 3072 * 1; + static constexpr size_t mat_k = 32064 * 1; + static constexpr size_t wg_m = 1; + static constexpr size_t wg_n = 32 * 2; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; + static constexpr size_t dequant_s = 128; - static constexpr size_t local_kslicing = 8; + static constexpr size_t local_kslicing = 16; static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeLpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; }; -class qkv4 { +class test5_xelpg { public: // Extract the parameters required by different test cases static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 11008; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; + static constexpr size_t mat_n = 3072 * 1; + static constexpr size_t mat_k = 9216 * 1; + static constexpr size_t wg_m = 1; + static constexpr size_t wg_n = 32 * 2; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 32; + static constexpr size_t dequant_s = 128; - static constexpr size_t local_kslicing = 4; + static constexpr size_t local_kslicing = 16; static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeLpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; }; -class qkv5 { +class test6_xelpg { public: // Extract the parameters required by different test cases - static constexpr size_t mat_m = 8; - static constexpr size_t mat_n = 151936; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; + static constexpr size_t mat_m = 1; + static constexpr size_t mat_n = 3072 * 1; + static constexpr size_t mat_k = 3072 * 1; + static constexpr size_t wg_m = 1; + static constexpr size_t wg_n = 32 * 2; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; + static constexpr size_t dequant_s = 128; static constexpr size_t local_kslicing = 8; static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeLpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; }; -class qkv6 { + +class test7_xelpg { public: // Extract the parameters required by different test cases static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 12288; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static 
constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; + static constexpr size_t mat_n = 16384 * 1; + static constexpr size_t mat_k = 3072 * 1; + static constexpr size_t wg_m = 1; + static constexpr size_t wg_n = 32 * 2; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; + static constexpr size_t dequant_s = 128; static constexpr size_t local_kslicing = 8; static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeLpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; }; -class qkv7 { + +class test8_xelpg { public: // Extract the parameters required by different test cases static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; + static constexpr size_t mat_n = 3072 * 1; + static constexpr size_t mat_k = 8192 * 1; + static constexpr size_t wg_m = 1; + static constexpr size_t wg_n = 32 * 2; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; + static constexpr size_t dequant_s = 128; static constexpr size_t local_kslicing = 8; static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeLpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; }; -class qkv8 { + +class test9_xelpg { public: // Extract the parameters required by different test cases static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 11008; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; + static constexpr size_t mat_n = 32064 * 1; + static constexpr size_t mat_k = 3072 * 1; + static constexpr size_t wg_m = 1; + static constexpr size_t wg_n = 32 * 4; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; + static constexpr size_t dequant_s = 128; static constexpr size_t local_kslicing = 8; static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeLpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; }; -class qkv9 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 11008; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t local_kslicing 
= 4; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv10 { +class test10_xelpg { public: // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 151936; - static constexpr size_t mat_k = 4096; + static constexpr size_t mat_m = 32; + static constexpr size_t mat_n = 4096 * 3; + static constexpr size_t mat_k = 4096 * 3; static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; + static constexpr size_t wg_n = 32 * 1; static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; + static constexpr size_t sg_n = 32; static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; + static constexpr size_t dequant_s = 128; static constexpr size_t local_kslicing = 8; static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeLpg; using data_type_a = fp16; using data_type_b = int4x2; using data_type_c = fp16; @@ -664,7 +591,7 @@ void dequantize_gemm_run(int iter) { using tile_shape = xetla::group::tile_shape_t; - static constexpr uint32_t periodic_sync_interval = 1; + static constexpr uint32_t periodic_sync_interval = 0; static constexpr uint32_t prefetch_distance = 0; using mem_desc_a_t = xetla::mem_desc_t< @@ -904,7 +831,7 @@ void dequantize_gemm_run(int iter) { size_t ops = 2 * matrix_m * matrix_n * matrix_k + matrix_m * matrix_n; profiling_helper prof("dequantize_gemm", ops, "gflops"); - int constexpr warm = 10; + int constexpr warm = 100; try { for (int i = 0; i < iter + warm; i++) { if (i >= warm) @@ -1000,7 +927,7 @@ void dequantize_gemm_run(int iter) { size_t ops = 2 * matrix_m * matrix_n * matrix_k + matrix_m * matrix_n; profiling_helper prof("dequantize_gemm", ops, "gflops"); - int constexpr warm = 10; + int constexpr warm = 0; try { for (int i = 0; i < iter + warm; i++) { if (i >= warm) @@ -1090,17 +1017,7 @@ TYPED_TEST_P(dequantize_gemm_test, esimd) { } REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_test, esimd); -using tests = ::testing::Types< - test1_gpu_xelpg, - test2_gpu_xelpg, - test3_gpu_xelpg, - test4_gpu_xelpg, - test5_gpu_xelpg, - test6_gpu_xelpg, - test7_gpu_xelpg, - test8_gpu_xelpg, - test9_gpu_xelpg, - test10_gpu_xelpg>; +using tests = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P( dequantize_gemm_test_suite, diff --git a/tests/integration/gemm/bf16_stream_k/CMakeLists.txt b/tests/integration/gemm/stream_k/CMakeLists.txt similarity index 100% rename from tests/integration/gemm/bf16_stream_k/CMakeLists.txt rename to tests/integration/gemm/stream_k/CMakeLists.txt diff --git a/tests/integration/gemm/bf16_stream_k/main.cpp b/tests/integration/gemm/stream_k/main.cpp similarity index 94% rename from tests/integration/gemm/bf16_stream_k/main.cpp rename to tests/integration/gemm/stream_k/main.cpp index 8df90d529..b983a2d41 100644 --- a/tests/integration/gemm/bf16_stream_k/main.cpp +++ b/tests/integration/gemm/stream_k/main.cpp @@ -57,8 +57,7 @@ int gemm_result_validate( }); // BiasAdd for (uint32_t i = 0; i < gold_C.size(); ++i) { - uint32_t col = gold_C.size() % n; - gold_C[i] += D[col]; + gold_C[i] += D[i % n]; } } @@ -259,7 +258,9 @@ void 
stream_k_gemm_run(uint32_t iter) { static constexpr uint32_t periodic_sync_interval = 4; static constexpr uint32_t prefetch_distance = 4; - // Micro-kernel configuration + constexpr gpu_arch arch_tag = gpu_arch::XeHpc; + + // Micro-kernel configuration using gemm_config = typename xetla::group::gemm_selector_t< data_type_a, // input datatype for A data_type_b, // input datatype for B @@ -273,7 +274,7 @@ void stream_k_gemm_run(uint32_t iter) { tile_shape, // computation tile shape sg_tile_k, // elements in each iteration mma_engine::xmx, // compute engine - gpu_arch::XeHpc, + arch_tag, prefetch_distance, periodic_sync_interval> // GPU arch, prefetch stages, periodic sync // frequency @@ -291,8 +292,7 @@ void stream_k_gemm_run(uint32_t iter) { // bias_add_op_t using mem_desc_bias_t = xetla:: mem_desc_t; - using bias_op_t = - xetla::subgroup::bias_add_op_t; + using bias_op_t = xetla::subgroup::bias_add_op_t; using tile_op_t = xetla::subgroup::chained_tile_op_t< xetla::subgroup::relu_op_t, // apply elementwise ReLU bias_op_t // apply elementwise BiasAdd @@ -300,8 +300,8 @@ void stream_k_gemm_run(uint32_t iter) { using epilogue_policy_t = typename std::conditional< postop_enable == 0, - xetla::group::epilogue_policy_default, - xetla::group::epilogue_policy_tile_op>::type; + xetla::group::epilogue_policy_default, + xetla::group::epilogue_policy_tile_op>::type; using epilogue_t = xetla::group::epilogue_t< epilogue_policy_t, @@ -309,7 +309,7 @@ void stream_k_gemm_run(uint32_t iter) { mem_desc_t>; using dispatch_stream_k = - gpu::xetla::kernel::dispatch_policy_stream_k; + gpu::xetla::kernel::dispatch_policy_stream_k; using gemm_op_t = xetla::kernel:: gemm_universal_t; @@ -326,11 +326,12 @@ void stream_k_gemm_run(uint32_t iter) { sg_tile_n, avail_xecores); - static const std::string env_set_str = - "SYCL_PROGRAM_COMPILE_OPTIONS= -vc-codegen -doubleGRF " - "-vc-disable-indvars-opt -Xfinalizer ' -printregusage -enableBCR " - "-DPASTokenReduction '"; - putenv(const_cast(env_set_str.c_str())); + setenv( + "SYCL_PROGRAM_COMPILE_OPTIONS", + " -vc-codegen -doubleGRF -vc-disable-indvars-opt " + " -Xfinalizer '-printregusage -enableBCR -DPASTokenReduction '", + 1); + // Define and initialize the data required for the calculation auto A = alloc_device_and_init( size_a, @@ -387,7 +388,7 @@ void stream_k_gemm_run(uint32_t iter) { using epilogue_args_t = typename epilogue_t::arguments_t; uint32_t warmup = 0; - int64_t ops = 2 * static_cast(matrix_m) * matrix_n * matrix_k; + long ops = 2 * static_cast(matrix_m) * matrix_n * matrix_k; profiling_helper prof("stream_k_universal_gemm", ops, "gflops"); if constexpr (postop_enable) { @@ -410,9 +411,9 @@ void stream_k_gemm_run(uint32_t iter) { matrix_k, matrix_n, A, - matrix_k, + (mem_layout_a == mem_layout::row_major) ? matrix_k : matrix_m, B, - matrix_n, + (mem_layout_b == mem_layout::row_major) ? matrix_n : matrix_k, C, matrix_n, Acc, @@ -461,9 +462,9 @@ void stream_k_gemm_run(uint32_t iter) { matrix_k, matrix_n, A, - matrix_k, + (mem_layout_a == mem_layout::row_major) ? matrix_k : matrix_m, B, - matrix_n, + (mem_layout_b == mem_layout::row_major) ? 
matrix_n : matrix_k, C, matrix_n, Acc, @@ -502,8 +503,7 @@ void stream_k_gemm_run(uint32_t iter) { } } - static const std::string env_unset_str = "SYCL_PROGRAM_COMPILE_OPTIONS="; - putenv(const_cast(env_unset_str.c_str())); + unsetenv("SYCL_PROGRAM_COMPILE_OPTIONS"); ASSERT_EQ( 0, @@ -542,6 +542,7 @@ TYPED_TEST_P(stream_k_gemm_test, esimd) { } REGISTER_TYPED_TEST_SUITE_P(stream_k_gemm_test, esimd); + using tests = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P( diff --git a/tests/integration/vector_add/int32_1d/kernel_func.hpp b/tests/integration/vector_add/int32_1d/kernel_func.hpp index b841f7b5d..a46277c5a 100644 --- a/tests/integration/vector_add/int32_1d/kernel_func.hpp +++ b/tests/integration/vector_add/int32_1d/kernel_func.hpp @@ -58,7 +58,6 @@ KERNEL_FUNC inline void vector_add_func( xetla_vector ivector2 = xetla_load_global< dtype, SIMD, - data_size::default_size, cache_hint::uncached, cache_hint::uncached>(b, offset); diff --git a/tests/integration/vector_add/tf32_1d/kernel_func.hpp b/tests/integration/vector_add/tf32_1d/kernel_func.hpp index 75b238dcc..40935fa3a 100644 --- a/tests/integration/vector_add/tf32_1d/kernel_func.hpp +++ b/tests/integration/vector_add/tf32_1d/kernel_func.hpp @@ -58,7 +58,6 @@ KERNEL_FUNC inline void vector_add_func( xetla_vector ivector2 = xetla_load_global< dtype, SIMD, - data_size::default_size, cache_hint::uncached, cache_hint::uncached>(b, offset); //// tf32 convert to fp32 -> vadd -> fp32 convert to tf32 -> write out ///// diff --git a/tests/unit/global_load_store/kernel_func.hpp b/tests/unit/global_load_store/kernel_func.hpp index 54fecef70..9f3b7e570 100644 --- a/tests/unit/global_load_store/kernel_func.hpp +++ b/tests/unit/global_load_store/kernel_func.hpp @@ -74,8 +74,7 @@ struct global_load_block_cache { [[maybe_unused]] dtype* c) { uint64_t offset = 0; xetla_vector A_load_vec = - xetla_load_global( - a, offset); + xetla_load_global(a, offset); xetla_store_global(b, offset, A_load_vec); } }; @@ -90,8 +89,7 @@ struct global_store_block_cache { uint64_t offset = 0; xetla_vector A_load_vec = xetla_load_global(a, offset); - xetla_store_global( - b, offset, A_load_vec); + xetla_store_global(b, offset, A_load_vec); } }; @@ -217,7 +215,7 @@ struct global_load_store_scatter_nelt2 { xetla_vector_gen(0, 1); offsets = offsets * sizeof(dtype); - xetla_vector A_load_vec = xetla_load_global< + xetla_vector A_load_vec = xetla_load_global< dtype, 2, data_size::default_size, diff --git a/tests/utils/buff_compare.hpp b/tests/utils/buff_compare.hpp index c1214f9d7..bbdfac6f4 100644 --- a/tests/utils/buff_compare.hpp +++ b/tests/utils/buff_compare.hpp @@ -237,6 +237,7 @@ bool _handle_fp_types( [[maybe_unused]] std::string name, size_t ulp_tol, double abs_tol) { + const bool verbose = name != ""; if (std::is_same, gpu::xetla::bf16>::value) { if (ulp_tol == 0) ulp_tol = 8; @@ -281,15 +282,17 @@ bool _handle_fp_types( size_t aulpidx = std::max_element(aulpte.begin(), aulpte.end()) - aulpte.begin(); - std::cout << "\t" - << "max absolute ULP diff:\n"; - std::cout << "\t\t" - << "data_idx: " << data.idx_mapping[aulpidx] - << " gold_idx: " << other.idx_mapping[aulpidx] - << " abserr: " << (float)aulpte[aulpidx] << std::endl; - std::cout << "\t\t" - << "data_val: " << ulp_data[aulpidx] - << " gold_val: " << (float)ulp_other[aulpidx] << std::endl; + if (verbose) { + std::cout << "\t" + << "max absolute ULP diff:\n"; + std::cout << "\t\t" + << "data_idx: " << data.idx_mapping[aulpidx] + << " gold_idx: " << other.idx_mapping[aulpidx] + << " abserr: " << 
(float)aulpte[aulpidx] << std::endl; + std::cout << "\t\t" + << "data_val: " << ulp_data[aulpidx] + << " gold_val: " << (float)ulp_other[aulpidx] << std::endl; + } size_t ulp_threshold = ulp_tol; double small_num_threshold = abs_tol; @@ -319,7 +322,8 @@ bool _handle_fp_types( float fail_rate = diff_elems_count / ((float)ulp_data.size()) * 100; float pass_rate = 100 - fail_rate; - std::cout << "\tpass rate: " << pass_rate << "%\n"; + if (verbose || fail_rate != 0) + std::cout << "\tpass rate: " << pass_rate << "%\n"; return flag; } @@ -374,6 +378,7 @@ bool xetla_buff_cmp( std::cout << "ERROR: buffer size or shape mismatch!\n"; return false; } + const bool verbose = name != ""; using dtype1 = typename T1::type; using dtype2 = typename T2::type; @@ -383,26 +388,27 @@ bool xetla_buff_cmp( std::max_element(diff.rte.begin(), diff.rte.end()) - diff.rte.begin(); unsigned aidx = std::max_element(diff.ate.begin(), diff.ate.end()) - diff.ate.begin(); - - std::cout << name << ":\n"; - std::cout << "\t" - << "max relative diff:\n"; - std::cout << "\t\t" - << "data_idx: " << data.idx_mapping[ridx] - << " gold_idx: " << other.idx_mapping[ridx] - << " relerr: " << diff.rte[ridx] << std::endl; - std::cout << "\t\t" - << "data_val: " << data.buff[ridx] - << " gold_val: " << other.buff[ridx] << std::endl; - std::cout << "\t" - << "max absolute diff:\n"; - std::cout << "\t\t" - << "data_idx: " << data.idx_mapping[aidx] - << " gold_idx: " << other.idx_mapping[aidx] - << " abserr: " << diff.ate[aidx] << std::endl; - std::cout << "\t\t" - << "data_val: " << data.buff[aidx] - << " gold_val: " << other.buff[aidx] << std::endl; + if (verbose) { + std::cout << name << ":\n"; + std::cout << "\t" + << "max relative diff:\n"; + std::cout << "\t\t" + << "data_idx: " << data.idx_mapping[ridx] + << " gold_idx: " << other.idx_mapping[ridx] + << " relerr: " << diff.rte[ridx] << std::endl; + std::cout << "\t\t" + << "data_val: " << data.buff[ridx] + << " gold_val: " << other.buff[ridx] << std::endl; + std::cout << "\t" + << "max absolute diff:\n"; + std::cout << "\t\t" + << "data_idx: " << data.idx_mapping[aidx] + << " gold_idx: " << other.idx_mapping[aidx] + << " abserr: " << diff.ate[aidx] << std::endl; + std::cout << "\t\t" + << "data_val: " << data.buff[aidx] + << " gold_val: " << other.buff[aidx] << std::endl; + } if constexpr ( std::is_floating_point_v != 0 || diff --git a/tests/utils/execution.hpp b/tests/utils/execution.hpp index bcdc28a09..46040dddd 100644 --- a/tests/utils/execution.hpp +++ b/tests/utils/execution.hpp @@ -98,6 +98,9 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) { device, context); + size_t ops = 2 * matrix_m * matrix_n * matrix_k; + profiling_helper prof("gemm", ops, "gflops"); + try { std::vector kernelId = {get_kernel_id()}; auto inputBundle = @@ -128,6 +131,8 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) { cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(arg); + int constexpr warm_up = 10; + int constexpr iters = 100; for (size_t i = 0; i < batch; i++) { auto A_ptr = A + i * size_a; auto B_ptr = B + i * size_b; @@ -147,31 +152,41 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) { result = test_result::skip; break; } - - auto e_esimd = queue.submit([&](handler& cgh) { - cgh.use_kernel_bundle(exeBundle); - cgh.parallel_for(nd_range, [=](nd_item<3> item) KERNEL_MAIN { - gpu::xetla::xetla_local_init(); - gpu::xetla::xetla_nbarrier_init(); - KERNEL::run( - item, - A_ptr, - B_ptr, - C_ptr, - matrix_m, - matrix_n, - matrix_k, - 
Acc_ptr, - Cnt_ptr); + for (int iter = 0; iter < iters + warm_up; iter++) { + if (iter >= warm_up) { + prof.cpu_start(); + } + auto e_esimd = queue.submit([&](handler& cgh) { + cgh.use_kernel_bundle(exeBundle); + cgh.parallel_for(nd_range, [=](nd_item<3> item) KERNEL_MAIN { + gpu::xetla::xetla_local_init(); + gpu::xetla::xetla_nbarrier_init(); + KERNEL::run( + item, + A_ptr, + B_ptr, + C_ptr, + matrix_m, + matrix_n, + matrix_k, + Acc_ptr, + Cnt_ptr); + }); }); - }); - e_esimd.wait(); + e_esimd.wait(); + if (iter >= warm_up) { + prof.cpu_end(); + prof.add_gpu_event(e_esimd); + } + } } } catch (cl::sycl::exception const& e) { std::cout << "SYCL exception caught: " << e.what() << '\n'; result = test_result::fail; } + // performance + prof.print_profiling_result(profiling_selector::GPU); // validation if (result == test_result::complete) { validate_func vfunc; diff --git a/tests/utils/profiling.hpp b/tests/utils/profiling.hpp index 143cdb537..6d0f6998e 100644 --- a/tests/utils/profiling.hpp +++ b/tests/utils/profiling.hpp @@ -156,9 +156,14 @@ class profiling_helper { profiling_statistics& stat, int scaling_ratio, string label = "[kernel time]", - string device = "GPU") { - vector value = {stat.max, stat.min, stat.median, stat.mean}; - vector desc = {"minimum ", "maximum ", "median ", "mean "}; + string device = "GPU", + bool is_verbose = true) { + auto value = is_verbose + ? vector{stat.max, stat.min, stat.median, stat.mean} + : vector{stat.min}; // min time for max performance + auto desc = is_verbose + ? vector{"minimum ", "maximum ", "median ", "mean "} + : vector{"maximum "}; string unit = ""; string perf_string = ""; for (uint32_t i = 0; i < value.size(); i++) { @@ -181,32 +186,38 @@ class profiling_helper { vector& time, profiling_statistics stat, string label = "[kernel time]", - string device = "GPU") { + string device = "GPU", + bool is_verbose = true) { get_statistics(time, stat); - std::cout << "============= Profiling for " << label << " " - << "=============" << std::endl; - print_statistics(kernel_id, stat, label, device); - std::cout << "======================================================" - << std::endl; + if (is_verbose) { + std::cout << "============= Profiling for " << label << " " + << "=============" << std::endl; + print_statistics(kernel_id, stat, label, device); + std::cout << "======================================================" + << std::endl; + } if (this->work_amount[kernel_id] != 0) { - std::cout << "============== " << label << " " << work_name[kernel_id] - << " ================== " << std::endl; + if (is_verbose) + std::cout << "============== " << label << " " << work_name[kernel_id] + << " ================== " << std::endl; // Different performance data correspond to different scaling ratios if (this->work_name[kernel_id] == "gflops") { - print_performance(kernel_id, stat, 1000000, label, device); + print_performance(kernel_id, stat, 1000000, label, device, is_verbose); } else if (this->work_name[kernel_id] == "mhashs") { - print_performance(kernel_id, stat, 1000, label, device); + print_performance(kernel_id, stat, 1000, label, device, is_verbose); } else if (this->work_name[kernel_id] == "GB/s") { - print_performance(kernel_id, stat, 1000000, label, device); + print_performance(kernel_id, stat, 1000000, label, device, is_verbose); } else { std::cout << "Not sure how much workload scales" << std::endl; } - std::cout << "======================================================" - << std::endl; + if (is_verbose) + std::cout << 
"======================================================" + << std::endl; } - std::cout << std::endl; + if (is_verbose) + std::cout << std::endl; } void set_time_vecs() { @@ -288,19 +299,32 @@ class profiling_helper { gpu_event_vec[kernel_id].push_back(gpu_event); } - void print_profiling_result(profiling_selector selector) { + void print_profiling_result( + profiling_selector selector, + bool is_verbose = true) { write_performance_metrics_into_report(); for (uint32_t i = 0; i < kernel_nums; i++) { - std::cout << "\n***************** PROFILING FOR KERNEL" << i - << " ***********************" << std::endl; + if (is_verbose) + std::cout << "\n***************** PROFILING FOR KERNEL" << i + << " ***********************" << std::endl; if (selector != profiling_selector::CPU) { get_gpu_time_from_events(i); print_profiling_data( - i, gpu_time_vec[i], gpu_statistics, "[kernel time]", "GPU"); + i, + gpu_time_vec[i], + gpu_statistics, + "[kernel time]", + "GPU", + is_verbose); } if (selector != profiling_selector::GPU) { print_profiling_data( - i, cpu_time_vec[i], cpu_statistics, "[Wall time]", "CPU"); + i, + cpu_time_vec[i], + cpu_statistics, + "[Wall time]", + "CPU", + is_verbose); } } } diff --git a/tests/utils/windows_functions.hpp b/tests/utils/windows_functions.hpp new file mode 100644 index 000000000..ec05671e3 --- /dev/null +++ b/tests/utils/windows_functions.hpp @@ -0,0 +1,16 @@ +#pragma once +#include + +int setenv(const char *name, const char *value, int overwrite) { + int errcode = 0; + if (!overwrite) { + size_t envsize = 0; + errcode = getenv_s(&envsize, NULL, 0, name); + if (errcode || envsize) return errcode; + } + return _putenv_s(name, value); +} + +int unsetenv(const char *name) { + return _putenv_s(name, ""); +}