Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
Xetla Bug fix (#256)
Browse files Browse the repository at this point in the history
1 Restore first token configuration
2 Bug fixes for arch_config
3 Bug fixes for fmha
4 Synchronized part of the code with innersource
5 cmake compilation parameters are the same as ipex
6 FP16 UT bugfix dtype_mma_a and dtype_mma_b should be fp16
7 Updated policy for int4 and default FPU
8 FP16 gemm MatB col_major bugfix

---------

Co-authored-by: Ding, Yi1 <yi1.ding@intel.com>
  • Loading branch information
sunjiweiswift and DDEle authored May 16, 2024
1 parent b13e02f commit e5510c6
Show file tree
Hide file tree
Showing 54 changed files with 2,147 additions and 1,257 deletions.
43 changes: 33 additions & 10 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ if (NOT CMAKE_BUILD_TYPE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
endif()
if(UNIX)
set(CMAKE_C_COMPILER icx)
set(CMAKE_CXX_COMPILER icpx)
else() # Windows
# Force CMake to use icx-cl rather than the default C++ compiler/linker
# (needed on Windows only)
Expand All @@ -24,7 +26,7 @@ include(CTest)
enable_testing()

if(UNIX)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/tools/cmake")
list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/tools/cmake")
endif()
find_package(MKL CONFIG REQUIRED)
message(STATUS "MKL_VERSION=${MKL_VERSION}")
Expand All @@ -33,7 +35,7 @@ message(STATUS "MKL_IMPORTED_TARGETS=${MKL_IMPORTED_TARGETS}")
# debug option
message(STATUS "'DEBUG' is set to " ${DEBUG})
if (${DEBUG})
add_compile_options(-debug=minimal -Rno-debug-disables-optimization -DDEBUG=${DEBUG})
add_compile_options(-debug=minimal -Rno-debug-disables-optimization -DDEBUG=${DEBUG})
endif ()

# log message print
Expand All @@ -43,20 +45,41 @@ if (${LOG} STREQUAL "on")
add_definitions(-DLOG_PRINT)
endif ()

# For large registers mode, enable 256 registers for kernels
set(XETLA_OFFLINE_OPTIONS "-doubleGRF")
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -vc-disable-indvars-opt")
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -vc-codegen")
# Enable bank conflict reduction.
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -enableBCR")
# Optimization to reduce the tokens used for DPAS instruction.
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -DPASTokenReduction")

# AOT device
set(AOT_DEVICE "" CACHE STRING "Set device list for AOT build")
set(USE_AOT_DEVLIST "" CACHE STRING "Set device list for AOT build")
if (USE_AOT_DEVLIST)
add_compile_options(-fsycl-targets=spir64_gen)
add_link_options(-fsycl-targets=spir64_gen)
# For registers usage verbose at AOT
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -printregusage")
set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -Xs "-options '${XETLA_OFFLINE_OPTIONS}' -device '${USE_AOT_DEVLIST}'")
else()
set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -Xs "${XETLA_OFFLINE_OPTIONS}")
endif()

add_compile_options(-fsycl -fsycl-device-code-split=per_kernel)
add_compile_options(-Wall -Wextra -Werror)

include(ProcessorCount)
ProcessorCount(nproc)
add_link_options(-fsycl -fsycl-device-code-split=per_kernel -fsycl-max-parallel-link-jobs=${nproc})
add_link_options(${XETLA_KERNEL_FLAGS})

add_compile_options(-fsycl)
add_link_options(-fsycl)
if(UNIX)
if (AOT_DEVICE)
add_compile_options(-fsycl-targets=spir64_gen)
add_link_options(-fsycl-targets=spir64_gen -Xs "-device ${AOT_DEVICE}") # MTL
endif()
add_compile_options(-fp-model=precise -Wall -Wextra -Werror)
add_compile_options(-fp-model=precise)
add_link_options(-lmkl_intel_lp64 -lmkl_sequential -lmkl_core -lpthread -lm)
link_libraries(-lgtest -lgtest_main)
else() # Windows
add_compile_options(/fp:precise)
add_compile_options(/EHsc)
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
add_compile_options(/MDd)
Expand Down
21 changes: 1 addition & 20 deletions examples/11_stream_k_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,25 +1,6 @@
set(TARGET stream_k_gemm)

set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -fsycl)
set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -fsycl-targets=spir64_gen)

# disable loop invariance optimization, this is for performance
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -vc-disable-indvars-opt")
# For large registers mode, enable 256 registers for kernels
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -doubleGRF")
# For registers usage verbose at AOT
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -printregusage")
# Enable bank conflict reduction.
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -enableBCR")
# Optimization to reduce the tokens used for DPAS instruction.
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -DPASTokenReduction")

set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -Xs)
set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} "-device pvc -options '${XETLA_OFFLINE_OPTIONS} ' ")

#build test
add_executable(${TARGET} stream_k_gemm.cpp)
target_link_options(${TARGET} PRIVATE ${XETLA_KERNEL_FLAGS})
# Disable vector combine, to remove redundant loads and stores
#target_compile_options(${TARGET} PRIVATE -mllvm -disable-vector-combine -fsycl -fsycl-targets=spir64_gen)

# target_compile_options(${TARGET} PRIVATE -mllvm -disable-vector-combine -fsycl -fsycl-targets=spir64_gen)
54 changes: 35 additions & 19 deletions include/common/core/arch_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeLpg>

template <gpu_arch arch_tag>
inline constexpr bool arch_has_2d_load_store =
load_store_attr_t<msg_type::block_2d, arch_tag>::has_hw_block_2d;
load_store_attr_t<msg_type::block_2d, arch_tag>::has_hw_block_2d;

template <gpu_arch arch_tag>
struct load_store_attr_t<msg_type::block_1d, arch_tag> {
Expand Down Expand Up @@ -149,9 +149,19 @@ struct register_nums_t {
};

template <gpu_arch arch_tag>
struct register_bytes_t {
struct register_bytes_t;
template <>
struct register_bytes_t<gpu_arch::XeHpc> {
static constexpr uint32_t reg_in_bytes = 64;
};
template <>
struct register_bytes_t<gpu_arch::XeHpg> {
static constexpr uint32_t reg_in_bytes = 32;
};
template <>
struct register_bytes_t<gpu_arch::XeLpg> {
static constexpr uint32_t reg_in_bytes = 32;
};

template <grf_mode grf_num_mode, gpu_arch arch_tag>
struct register_attr_t {
Expand Down Expand Up @@ -188,41 +198,47 @@ struct mma_attr_t<arch_tag, m, std::enable_if_t<!arch_has_xmx<arch_tag>>> {
template <gpu_arch arch_tag>
struct arch_attr_t {};

template <gpu_arch arch_tag>
struct client_arch_attr_base_t {
template <>
struct arch_attr_t<gpu_arch::XeHpc> {
template <msg_type message_type = msg_type::block_2d>
using load_store_attr = load_store_attr_t<message_type, arch_tag>;
using load_store_attr = load_store_attr_t<message_type, gpu_arch::XeHpc>;

template <grf_mode grf_num_mode = grf_mode::normal>
using register_attr = register_attr_t<grf_num_mode, arch_tag>;
template <grf_mode grf_num_mode = grf_mode::double_grf>
using register_attr = register_attr_t<grf_num_mode, gpu_arch::XeHpc>;

using dpas_attr = dpas_attr_t<arch_tag>;
using dpas_attr = dpas_attr_t<gpu_arch::XeHpc>;

static constexpr uint32_t max_wg_num = 64;
static constexpr uint32_t local_mem_size = 64 * 1024;
static constexpr uint32_t local_mem_size = 128 * 1024;
};

template <>
struct arch_attr_t<gpu_arch::XeHpc> {
struct arch_attr_t<gpu_arch::XeHpg> {
template <msg_type message_type = msg_type::block_2d>
using load_store_attr = load_store_attr_t<message_type, gpu_arch::XeHpc>;
using load_store_attr = load_store_attr_t<message_type, gpu_arch::XeHpg>;

template <grf_mode grf_num_mode = grf_mode::double_grf>
using register_attr = register_attr_t<grf_num_mode, gpu_arch::XeHpc>;
using register_attr = register_attr_t<grf_num_mode, gpu_arch::XeHpg>;

using dpas_attr = dpas_attr_t<gpu_arch::XeHpc>;
using dpas_attr = dpas_attr_t<gpu_arch::XeHpg>;

static constexpr uint32_t max_wg_num = 64;
static constexpr uint32_t local_mem_size = 128 * 1024;
static constexpr uint32_t local_mem_size = 64 * 1024;
};

template <>
struct arch_attr_t<gpu_arch::XeHpg>
: public client_arch_attr_base_t<gpu_arch::XeHpg> {};
struct arch_attr_t<gpu_arch::XeLpg> {
template <msg_type message_type = msg_type::block_2d>
using load_store_attr = load_store_attr_t<message_type, gpu_arch::XeLpg>;

template <>
struct arch_attr_t<gpu_arch::XeLpg>
: public client_arch_attr_base_t<gpu_arch::XeLpg> {};
template <grf_mode grf_num_mode = grf_mode::double_grf>
using register_attr = register_attr_t<grf_num_mode, gpu_arch::XeLpg>;

using dpas_attr = dpas_attr_t<gpu_arch::XeLpg>;

static constexpr uint32_t max_wg_num = 64;
static constexpr uint32_t local_mem_size = 64 * 1024;
};

/// @} xetla_core_arch_config

Expand Down
31 changes: 30 additions & 1 deletion include/common/core/base_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,28 @@ using fp16 = sycl::half;
///
using tf32 = sycl::ext::intel::experimental::esimd::tfloat32;

/// @brief mx_fp4(E2M1) data packed as 8bits data type.
struct mx_fp4 {
uint8_t data;
operator uint8_t() const {
return data;
}
mx_fp4() = default;
mx_fp4(uint8_t val) {
data = val;
}
};

template <typename T>
struct get_packed_num {
static constexpr uint32_t value = 1;
};

template <>
struct get_packed_num<mx_fp4> {
static constexpr uint32_t value = 2;
};

template <typename T, typename = void>
struct is_host_callable : std::false_type {};
template <typename T>
Expand All @@ -66,7 +88,8 @@ struct is_host_callable<T, std::enable_if_t<T::host_callable == true>>
template <typename T>
struct is_internal_type {
static constexpr bool value = std::is_same<remove_const_t<T>, bf16>::value ||
std::is_same<remove_const_t<T>, tf32>::value;
std::is_same<remove_const_t<T>, tf32>::value ||
std::is_same<remove_const_t<T>, mx_fp4>::value;
};
template <typename T>
inline constexpr bool is_internal_type_v = is_internal_type<T>::value;
Expand Down Expand Up @@ -108,6 +131,12 @@ struct native_type {
using type = T;
};

/// @brief Set uint8_t as the native data type of mx_fp4.
template <>
struct native_type<mx_fp4> {
using type = uint8_t;
};

/// @brief Return the native data type of T
template <typename T>
using native_type_t = typename native_type<T>::type;
Expand Down
33 changes: 31 additions & 2 deletions include/common/core/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,44 @@ enum class msg_type : uint8_t {
// prefetch_1d = 5
};

/// L1 or L2 cache hint kinds.
/// L1, L2 or L3 cache hints.
enum class cache_hint : uint8_t {
none = 0,
/// load/store/atomic: do not cache data to cache;
uncached = 1,

// load: cache data to cache;
cached = 2,

/// store: write data into cache level and mark the cache line as "dirty".
/// Upon eviction, the "dirty" data will be written into the furthest
/// subsequent cache;
write_back = 3,

/// store: immediately write data to the subsequent furthest cache, marking
/// the cache line in the current cache as "not dirty";
write_through = 4,

/// load: cache data to cache using the evict-first policy to minimize cache
/// pollution caused by temporary streaming data that may only be accessed
/// once or twice;
/// store/atomic: same as write-through, but use the evict-first policy
/// to limit cache pollution by streaming;
streaming = 5,
read_invalidate = 6

/// load: asserts that the cache line containing the data will not be read
/// again until it’s overwritten, therefore the load operation can invalidate
/// the cache line and discard "dirty" data. If the assertion is violated
/// (the cache line is read again) then behavior is undefined.
read_invalidate = 6,

// TODO: Implement the verification of this enum in check_cache_hint().
/// load, L2 cache only, next gen GPU after Xe required: asserts that
/// the L2 cache line containing the data will not be written until all
/// invocations of the shader or kernel execution are finished.
/// If the assertion is violated (the cache line is written), the behavior
/// is undefined.
const_cached = 7
};

/// Data size or format to read or store
Expand Down
Loading

0 comments on commit e5510c6

Please sign in to comment.