Skip to content

Commit

Permalink
Merge pull request #1 from arthw/update_warp
Browse files Browse the repository at this point in the history
[SYCL] Fix WARP_SIZE=16 bug of Intel GPU (ggerganov#8266) cherry-pick b549a1b
  • Loading branch information
arthw authored Jul 13, 2024
2 parents c5009e6 + 74e3185 commit aeaed61
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 70 deletions.
2 changes: 1 addition & 1 deletion ggml/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ if (GGML_SYCL)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
else()
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
endif()

file(GLOB GGML_HEADERS_SYCL "ggml-sycl/*.hpp")
Expand Down
21 changes: 20 additions & 1 deletion ggml/src/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,10 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const

const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
const int nthreads = block_size;
const int nwarps = nthreads / WARP_SIZE;
int nreduce = nwarps / WARP_SIZE;


float slope = 1.0f;

Expand All @@ -919,7 +923,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
slope = sycl::pow(base, float(exp));
}

float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols;
float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
float max_val = -INFINITY;

for (int col0 = 0; col0 < ncols; col0 += block_size) {
Expand All @@ -943,6 +947,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf[lane_id] = -INFINITY;
for (size_t i = 1; i < nreduce; i += 1)
buf[lane_id + i * WARP_SIZE] = -INFINITY;

}
item_ct1.barrier(sycl::access::fence_space::local_space);

Expand All @@ -952,6 +959,11 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
item_ct1.barrier(sycl::access::fence_space::local_space);

max_val = buf[lane_id];
for (size_t i = 1; i < nreduce; i += 1)
{
max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
}

max_val = warp_reduce_max(max_val, item_ct1);
}

Expand All @@ -975,6 +987,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
item_ct1.barrier(sycl::access::fence_space::local_space);
if (warp_id == 0) {
buf[lane_id] = 0.f;
for (size_t i = 1; i < nreduce; i += 1)
buf[lane_id + i * WARP_SIZE] = 0.f;

}
item_ct1.barrier(sycl::access::fence_space::local_space);

Expand All @@ -984,6 +999,10 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
item_ct1.barrier(sycl::access::fence_space::local_space);

tmp = buf[lane_id];
for (size_t i = 1; i < nreduce; i += 1)
{
tmp += buf[lane_id + i * WARP_SIZE];
}
tmp = warp_reduce_sum(tmp, item_ct1);
}

Expand Down
5 changes: 5 additions & 0 deletions ggml/src/ggml-sycl/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ void sycl_device_mgr::detect_all_sycl_device_list() try {
dpct::get_device_info(prop, device);
work_group_sizes.push_back(prop.get_max_work_group_size());
max_compute_units.push_back(prop.get_max_compute_units());
hw_familys.push_back(get_device_family(&device));
}
return;
} catch (sycl::exception const &exc) {
Expand Down Expand Up @@ -498,4 +499,8 @@ int ggml_sycl_device_info::get_device_id(int device_index) {
}
}

int ggml_sycl_device_info::hw_family(int device_id) {
return device_mgr->hw_familys[device_id];
}

//--ggml_sycl_device_info--
4 changes: 4 additions & 0 deletions ggml/src/ggml-sycl/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "dpct/helper.hpp"
#include "ggml-sycl.h"
#include "presets.hpp"
#include "sycl_hw.hpp"

#define GGML_COMMON_DECL_SYCL
#define GGML_COMMON_IMPL_SYCL
Expand Down Expand Up @@ -188,6 +189,8 @@ class sycl_device_mgr {
std::vector<sycl::device> devices;
std::vector<int> max_compute_units;
std::vector<int> work_group_sizes;
std::vector<int> hw_familys;

sycl::queue *first_queue;
std::vector<sycl::queue> _queues;
std::vector<sycl::context> ctxs;
Expand Down Expand Up @@ -236,6 +239,7 @@ struct ggml_sycl_device_info {
bool is_allowed_device(int device_id);
const char* devices_list();
int get_device_id(int device_index);
int hw_family(int device_id);
};

struct ggml_sycl_pool {
Expand Down
Loading

0 comments on commit aeaed61

Please sign in to comment.