From 43584f8076d7f649bbf1832dacac397f95453d84 Mon Sep 17 00:00:00 2001
From: LuFengqing
Date: Tue, 9 Apr 2024 11:51:32 +0800
Subject: [PATCH] [GPU] Optimize AllReduce SYCL kernel (#2667)

---
 itex/core/kernels/gpu/collective_ops.h | 412 ++++++++++++++++---------
 itex/core/utils/gpu_helper.h           |  79 +++--
 2 files changed, 321 insertions(+), 170 deletions(-)

diff --git a/itex/core/kernels/gpu/collective_ops.h b/itex/core/kernels/gpu/collective_ops.h
index 5e37f7dc4..53a37b2ed 100644
--- a/itex/core/kernels/gpu/collective_ops.h
+++ b/itex/core/kernels/gpu/collective_ops.h
@@ -23,6 +23,7 @@ limitations under the License.
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
+#include "itex/core/utils/gpu_helper.h"
 #include "itex/core/utils/mutex.h"
 #include "itex/core/utils/op_kernel.h"
 #include "itex/core/utils/strcat.h"
@@ -31,6 +32,8 @@ limitations under the License.
 namespace itex {
 
 constexpr int MAX_RANK_SIZE = 16;
+// After tuning, we found 8 gives the best performance and XeLink bandwidth
+// utilization.
+constexpr size_t VecBytes = 8;
 
 enum class ReductionOp { SUM = 0, MIN = 1, MAX = 2, PROD = 3 };
 
@@ -237,32 +240,6 @@ ITEX_GPUStream* CollectiveManager::GetCommStream(Participant* participant) {
   return comm_stream;
 }
 
-ITEX_GPUStream* CollectiveManager::GetCommStream(Collective* collective) {
-  std::sort(collective->participants.begin(), collective->participants.end(),
-            [](const std::unique_ptr<Participant>& a,
-               const std::unique_ptr<Participant>& b) -> bool {
-              return a->gpu_device_id < b->gpu_device_id;
-            });
-  return GetCommStream(collective->participants[0].get());
-}
-
-void StreamWaitStreamlist(ITEX_GPUStream* stream,
-                          CollectiveManager::Collective* collective) {
-  std::vector<sycl::event> event_list;
-  for (auto& participant : collective->participants) {
-    event_list.push_back(participant->comm_stream->ext_oneapi_submit_barrier());
-  }
-  stream->ext_oneapi_submit_barrier(event_list);
-}
-
-void StreamlistWaitStream(ITEX_GPUStream* stream,
-                          CollectiveManager::Collective* collective) {
-  sycl::event event = stream->ext_oneapi_submit_barrier();
-  for (auto& participant : collective->participants) {
-    participant->comm_stream->ext_oneapi_submit_barrier({event});
-  }
-}
-
 void CollectiveManager::AddToAllReduce(std::unique_ptr<Participant> participant,
                                        const Context& context,
                                        ReductionOp reduction_op) {
@@ -357,41 +334,153 @@ void CollectiveManager::AddParticipant(std::unique_ptr<Participant> participant,
   if (to_run != nullptr) RunCollective(to_run);
 }
 
-template <typename T, typename Func>
+template <typename T, typename Func, bool>
 struct AllReduceKernel;
 
 template <typename T, typename Func, typename AccT>
-void LaunchAllReduceKernel(ITEX_GPUStream* stream,
+void LaunchAllReduceKernel(ITEX_GPUStream* stream, size_t element_count,
                            const std::vector<const void*>& inputs,
-                           const std::vector<void*>& outputs,
-                           size_t num_elements) {
-  auto group_size =
-      (*stream)
-          .get_device()
-          .template get_info<sycl::info::device::max_work_group_size>();
-  auto num_workgroup = (num_elements + group_size - 1) / group_size;
-  int reduction_size = inputs.size();
-  if (reduction_size < MAX_RANK_SIZE) {
+                           const std::vector<void*>& outputs, int rank,
+                           int reduction_size) {
+  constexpr size_t VecSize = VecBytes / sizeof(T);
+  size_t vec_count = element_count / VecSize;
+  size_t vec_tail_element_count = element_count % VecSize;
+  size_t total_vec_count = vec_count + (vec_tail_element_count > 0 ? 1 : 0);
+
+  // Each rank all-reduces one sub-slice of the tensor. The last rank
+  // also reduces the tail vectors of the tensor.
+  size_t slice_vec_count = total_vec_count / reduction_size;
+  size_t tail_vec_count = total_vec_count % reduction_size;
+  size_t local_vec_count =
+      slice_vec_count + ((rank == (reduction_size - 1)) ? tail_vec_count : 0);
+
+  if (local_vec_count == 0) return;
+
+  auto device = stream->get_device();
+  size_t group_size =
+      device.template get_info<sycl::info::device::max_work_group_size>();
+
+  // Set max_workitems = HW_workgroup_num * max_workgroup_size.
+  int num_max_concurrent_workitem =
+      device.template get_info<sycl::ext::intel::info::device::gpu_slices>() *
+      device.template get_info<
+          sycl::ext::intel::info::device::gpu_subslices_per_slice>() *
+      group_size;
+  int num_workitem = local_vec_count <= num_max_concurrent_workitem
+                         ? local_vec_count
+                         : num_max_concurrent_workitem;
+  size_t num_vec_per_workitem = local_vec_count / num_workitem;
+  size_t num_tail_vec = local_vec_count % num_workitem;
+
+  int num_workgroup = (num_workitem + group_size - 1) / group_size;
+
+  if (reduction_size <= MAX_RANK_SIZE) {
     stream->submit([&](sycl::handler& cgh) {
-      const T* in_ptr[MAX_RANK_SIZE];
+      T* in_ptr[MAX_RANK_SIZE];
       T* out_ptr[MAX_RANK_SIZE];
       for (int i = 0; i < reduction_size; ++i) {
-        in_ptr[i] = static_cast<const T*>(inputs[i]);
-        out_ptr[i] = static_cast<T*>(outputs[i]);
+        in_ptr[i] = static_cast<T*>(const_cast<void*>(inputs[i])) +
+                    rank * slice_vec_count * VecSize;
+        out_ptr[i] =
+            static_cast<T*>(outputs[i]) + rank * slice_vec_count * VecSize;
       }
-      cgh.parallel_for<AllReduceKernel<T, Func>>(
-          sycl::nd_range<1>(sycl::range<1>(group_size * num_workgroup),
-                            sycl::range<1>(group_size)),
-          [=](sycl::nd_item<1> item) {
-            const int index = item.get_global_linear_id();
-            if (index >= num_elements) return;
-
-            AccT acc = AccT(in_ptr[0][index]);
-            for (int i = 1; i < reduction_size; ++i)
-              acc = Func()(acc, AccT(in_ptr[i][index]));
-            for (int i = 0; i < reduction_size; ++i) out_ptr[i][index] = T(acc);
-          });
+      // The last rank may need to process tail elements that can't form a
+      // full vector and therefore need a partial block store.
+      if (rank != (reduction_size - 1) || vec_tail_element_count == 0) {
+        cgh.parallel_for<AllReduceKernel<T, Func, false>>(
+            sycl::nd_range<1>(sycl::range<1>(group_size * num_workgroup),
+                              sycl::range<1>(group_size)),
+            [=](sycl::nd_item<1> item) {
+              const int index = item.get_global_linear_id();
+              if (index >= num_workitem) return;
+
+              for (size_t n = 0; n < num_vec_per_workitem; ++n) {
+                size_t offset = (num_workitem * n + index) * VecSize;
+                AlignedVector<AccT, VecSize, Func> result;
+                result.Load(*reinterpret_cast<AlignedVector<T, VecSize>*>(
+                    &(in_ptr[0][offset])));
+                for (int i = 1; i < reduction_size; ++i)
+                  result.Accumulate(
+                      *reinterpret_cast<AlignedVector<T, VecSize>*>(
+                          &(in_ptr[i][offset])));
+                for (int i = 0; i < reduction_size; ++i)
+                  result.Store(*reinterpret_cast<AlignedVector<T, VecSize>*>(
+                      &(out_ptr[i][offset])));
+              }
+
+              if (index < num_tail_vec) {
+                size_t offset =
+                    (num_workitem * num_vec_per_workitem + index) * VecSize;
+                AlignedVector<AccT, VecSize, Func> result;
+                result.Load(*reinterpret_cast<AlignedVector<T, VecSize>*>(
+                    &(in_ptr[0][offset])));
+                for (int i = 1; i < reduction_size; ++i)
+                  result.Accumulate(
+                      *reinterpret_cast<AlignedVector<T, VecSize>*>(
+                          &(in_ptr[i][offset])));
+                for (int i = 0; i < reduction_size; ++i)
+                  result.Store(*reinterpret_cast<AlignedVector<T, VecSize>*>(
+                      &(out_ptr[i][offset])));
+              }
+            });
+      } else {
+        cgh.parallel_for<AllReduceKernel<T, Func, true>>(
+            sycl::nd_range<1>(sycl::range<1>(group_size * num_workgroup),
+                              sycl::range<1>(group_size)),
+            [=](sycl::nd_item<1> item) {
+              const int index = item.get_global_linear_id();
+              if (index >= num_workitem) return;
+
+              for (size_t n = 0; n < num_vec_per_workitem; ++n) {
+                size_t offset = (num_workitem * n + index) * VecSize;
+                AlignedVector<AccT, VecSize, Func> result;
+                result.Load(*reinterpret_cast<AlignedVector<T, VecSize>*>(
+                    &(in_ptr[0][offset])));
+                for (int i = 1; i < reduction_size; ++i)
+                  result.Accumulate(
+                      *reinterpret_cast<AlignedVector<T, VecSize>*>(
+                          &(in_ptr[i][offset])));
+
+                if (local_vec_count > num_workitem ||
+                    index != (num_workitem - 1)) {
+                  for (int i = 0; i < reduction_size; ++i)
+                    result.Store(*reinterpret_cast<AlignedVector<T, VecSize>*>(
+                        &(out_ptr[i][offset])));
+                } else {  // The last workitem may process a partial vector.
+                  for (int i = 0; i < reduction_size; ++i)
+                    result.PartialStore(
+                        *reinterpret_cast<AlignedVector<T, VecSize>*>(
+                            &(out_ptr[i][offset])),
+                        vec_tail_element_count);
+                }
+              }
+
+              if (index < num_tail_vec) {
+                size_t offset =
+                    (num_workitem * num_vec_per_workitem + index) * VecSize;
+                AlignedVector<AccT, VecSize, Func> result;
+                result.Load(*reinterpret_cast<AlignedVector<T, VecSize>*>(
+                    &(in_ptr[0][offset])));
+                for (int i = 1; i < reduction_size; ++i)
+                  result.Accumulate(
+                      *reinterpret_cast<AlignedVector<T, VecSize>*>(
+                          &(in_ptr[i][offset])));
+
+                if (index != num_tail_vec - 1) {
+                  for (int i = 0; i < reduction_size; ++i)
+                    result.Store(*reinterpret_cast<AlignedVector<T, VecSize>*>(
+                        &(out_ptr[i][offset])));
+                } else {  // The last workitem may process a partial vector.
+                  for (int i = 0; i < reduction_size; ++i)
+                    result.PartialStore(
+                        *reinterpret_cast<AlignedVector<T, VecSize>*>(
+                            &(out_ptr[i][offset])),
+                        vec_tail_element_count);
+                }
+              }
+            });
+      }
     });
   } else {
     ITEX_LOG(FATAL) << "Reduction size " << reduction_size
@@ -400,8 +489,6 @@ void LaunchAllReduceKernel(ITEX_GPUStream* stream,
 }
 
 Status CollectiveManager::RunAllReduce(Collective* collective) {
-  ITEX_GPUStream* comm_stream = GetCommStream(collective);
-  StreamWaitStreamlist(comm_stream, collective);
   DataType data_type = collective->data_type;
   ReductionOp reduction_op = collective->reduction_op;
   auto num_elements = collective->participants[0]->input->NumElements();
@@ -422,103 +509,130 @@ Status CollectiveManager::RunAllReduce(Collective* collective) {
     inputs.push_back(participant->input->data());
     outputs.push_back(participant->output->data());
   }
 
-  if (reduction_op == ReductionOp::SUM) {
-    switch (data_type) {
-      case DT_BFLOAT16:
-        LaunchAllReduceKernel<Eigen::bfloat16, sycl::plus<float>, float>(
-            comm_stream, inputs, outputs, num_elements);
-        break;
-      case DT_HALF:
-        LaunchAllReduceKernel<Eigen::half, sycl::plus<float>, float>(
-            comm_stream, inputs, outputs, num_elements);
-        break;
-      case DT_FLOAT:
-        LaunchAllReduceKernel<float, sycl::plus<float>, float>(
-            comm_stream, inputs, outputs, num_elements);
-        break;
-      case DT_INT32:
-        LaunchAllReduceKernel<int, sycl::plus<int>, int>(comm_stream, inputs,
-                                                         outputs, num_elements);
-        break;
-      default:
-        return errors::InvalidArgument(
-            "Collective Allreduce unsupports datatype ",
-            DataTypeString(data_type));
-    }
-  } else if (reduction_op == ReductionOp::MIN) {
-    switch (data_type) {
-      case DT_BFLOAT16:
-        LaunchAllReduceKernel<Eigen::bfloat16, sycl::minimum<float>, float>(
-            comm_stream, inputs, outputs, num_elements);
-        break;
-      case DT_HALF:
-        LaunchAllReduceKernel<Eigen::half, sycl::minimum<float>, float>(
-            comm_stream, inputs, outputs, num_elements);
-        break;
-      case DT_FLOAT:
-        LaunchAllReduceKernel<float, sycl::minimum<float>, float>(
-            comm_stream, inputs, outputs, num_elements);
-        break;
-      case DT_INT32:
-        LaunchAllReduceKernel<int, sycl::minimum<int>, int>(
-            comm_stream, inputs, outputs, num_elements);
-        break;
-      default:
-        return errors::InvalidArgument(
-            "Collective Allreduce unsupports datatype ",
-            DataTypeString(data_type));
-    }
-  } else if (reduction_op == ReductionOp::MAX) {
-    switch (data_type) {
-      case DT_BFLOAT16:
-        LaunchAllReduceKernel<Eigen::bfloat16, sycl::maximum<float>, float>(
-            comm_stream, inputs, outputs, num_elements);
-        break;
-      case DT_HALF:
-        LaunchAllReduceKernel<Eigen::half, sycl::maximum<float>, float>(
-            comm_stream, inputs, outputs, num_elements);
-        break;
-      case DT_FLOAT:
-        LaunchAllReduceKernel<float, sycl::maximum<float>, float>(
-            comm_stream, inputs, outputs, num_elements);
-        break;
-      case DT_INT32:
-        LaunchAllReduceKernel<int, sycl::maximum<int>, int>(
-            comm_stream, inputs, outputs, num_elements);
-        break;
-      default:
-        return errors::InvalidArgument(
-            "Collective Allreduce unsupports datatype ",
-            DataTypeString(data_type));
-    }
-  } else if (reduction_op == ReductionOp::PROD) {
-    switch (data_type) {
-      case DT_BFLOAT16:
-        LaunchAllReduceKernel<Eigen::bfloat16, sycl::multiplies<float>, float>(
-            comm_stream, inputs, outputs, num_elements);
-        break;
-      case DT_HALF:
-        LaunchAllReduceKernel<Eigen::half, sycl::multiplies<float>, float>(
-            comm_stream, inputs, outputs, num_elements);
-        break;
-      case DT_FLOAT:
-        LaunchAllReduceKernel<float, sycl::multiplies<float>, float>(
-            comm_stream, inputs, outputs, num_elements);
-        break;
-      case DT_INT32:
-        LaunchAllReduceKernel<int, sycl::multiplies<int>, int>(
-            comm_stream, inputs, outputs, num_elements);
-        break;
-      default:
-        return errors::InvalidArgument(
-            "Collective Allreduce unsupports datatype ",
-            DataTypeString(data_type));
-    }
-  } else {
-    return errors::InvalidArgument(
-        "Collective Allreduce unsupports ReductionOp yet!");
+  std::vector<ITEX_GPUStream*> comm_streams;
+  std::vector<sycl::event> begin_events;
+  std::vector<sycl::event> end_events;
+  int reduction_size = collective->participants.size();
+  for (int i = 0; i < reduction_size; ++i) {
+    auto comm_stream = GetCommStream(collective->participants[i].get());
+    comm_streams.push_back(comm_stream);
+
+    // TODO(intel): use barrier instead of wait once the barrier bug is fixed.
+    comm_stream->wait();
+    // auto begin_event = comm_stream->ext_oneapi_submit_barrier();
+    // begin_events.push_back(begin_event);
+  }
+
+  for (int i = 0; i < reduction_size; ++i) {
+    auto comm_stream = comm_streams[i];
+    // comm_stream->ext_oneapi_submit_barrier(begin_events);
+
+    if (reduction_op == ReductionOp::SUM) {
+      switch (data_type) {
+        case DT_BFLOAT16:
+          LaunchAllReduceKernel<Eigen::bfloat16, sycl::plus<float>, float>(
+              comm_stream, num_elements, inputs, outputs, i, reduction_size);
+          break;
+        case DT_HALF:
+          LaunchAllReduceKernel<Eigen::half, sycl::plus<float>, float>(
+              comm_stream, num_elements, inputs, outputs, i, reduction_size);
+          break;
+        case DT_FLOAT:
+          LaunchAllReduceKernel<float, sycl::plus<float>, float>(
+              comm_stream, num_elements, inputs, outputs, i, reduction_size);
+          break;
+        case DT_INT32:
+          LaunchAllReduceKernel<int, sycl::plus<int>, int>(
+              comm_stream, num_elements, inputs, outputs, i, reduction_size);
+          break;
+        default:
+          return errors::InvalidArgument(
+              "Collective Allreduce unsupports datatype ",
+              DataTypeString(data_type));
+      }
+    } else if (reduction_op == ReductionOp::MIN) {
+      switch (data_type) {
+        case DT_BFLOAT16:
+          LaunchAllReduceKernel<Eigen::bfloat16, sycl::minimum<float>, float>(
+              comm_stream, num_elements, inputs, outputs, i, reduction_size);
+          break;
+        case DT_HALF:
+          LaunchAllReduceKernel<Eigen::half, sycl::minimum<float>, float>(
+              comm_stream, num_elements, inputs, outputs, i, reduction_size);
+          break;
+        case DT_FLOAT:
+          LaunchAllReduceKernel<float, sycl::minimum<float>, float>(
+              comm_stream, num_elements, inputs, outputs, i, reduction_size);
+          break;
+        case DT_INT32:
+          LaunchAllReduceKernel<int, sycl::minimum<int>, int>(
+              comm_stream, num_elements, inputs, outputs, i, reduction_size);
+          break;
+        default:
+          return errors::InvalidArgument(
+              "Collective Allreduce unsupports datatype ",
+              DataTypeString(data_type));
+      }
+    } else if (reduction_op == ReductionOp::MAX) {
+      switch (data_type) {
+        case DT_BFLOAT16:
+          LaunchAllReduceKernel<Eigen::bfloat16, sycl::maximum<float>, float>(
+              comm_stream, num_elements, inputs, outputs, i, reduction_size);
+          break;
+        case DT_HALF:
+          LaunchAllReduceKernel<Eigen::half, sycl::maximum<float>, float>(
+              comm_stream, num_elements, inputs, outputs, i, reduction_size);
+          break;
+        case DT_FLOAT:
+          LaunchAllReduceKernel<float, sycl::maximum<float>, float>(
+              comm_stream, num_elements, inputs, outputs, i, reduction_size);
+          break;
+        case DT_INT32:
+          LaunchAllReduceKernel<int, sycl::maximum<int>, int>(
+              comm_stream, num_elements, inputs, outputs, i, reduction_size);
+          break;
+        default:
+          return errors::InvalidArgument(
+              "Collective Allreduce unsupports datatype ",
+              DataTypeString(data_type));
+      }
+    } else if (reduction_op == ReductionOp::PROD) {
+      switch (data_type) {
+        case DT_BFLOAT16:
+          LaunchAllReduceKernel<Eigen::bfloat16, sycl::multiplies<float>,
+                                float>(comm_stream, num_elements, inputs,
+                                       outputs, i, reduction_size);
+          break;
+        case DT_HALF:
+          LaunchAllReduceKernel<Eigen::half, sycl::multiplies<float>, float>(
+              comm_stream, num_elements, inputs, outputs, i, reduction_size);
+          break;
+        case DT_FLOAT:
+          LaunchAllReduceKernel<float, sycl::multiplies<float>, float>(
+              comm_stream, num_elements, inputs, outputs, i, reduction_size);
+          break;
+        case DT_INT32:
+          LaunchAllReduceKernel<int, sycl::multiplies<int>, int>(
+              comm_stream, num_elements, inputs, outputs, i, reduction_size);
+          break;
+        default:
+          return errors::InvalidArgument(
+              "Collective Allreduce unsupports datatype ",
+              DataTypeString(data_type));
+      }
+    } else {
+      return errors::InvalidArgument(
+          "Collective Allreduce unsupports ReductionOp yet!");
+    }
+    // auto event = comm_stream->ext_oneapi_submit_barrier();
+    // end_events.push_back(event);
+  }
+
+  for (int i = 0; i < reduction_size; ++i) {
+    // TODO(intel): use barrier instead of wait once the barrier bug is fixed.
+    comm_streams[i]->wait();
+    // comm_streams[i]->ext_oneapi_submit_barrier(end_events);
   }
-  StreamlistWaitStream(comm_stream, collective);
   return Status::OK();
 }
diff --git a/itex/core/utils/gpu_helper.h b/itex/core/utils/gpu_helper.h
index 7fe9b0eed..8de432083 100644
--- a/itex/core/utils/gpu_helper.h
+++ b/itex/core/utils/gpu_helper.h
@@ -203,17 +203,19 @@ constexpr int NumBits(const unsigned int n) {
 
 // Represents an aligned array of N elements of T. Data pointers can be
 // reinterpreted as this type to generate vectorized loads/stores in a kernel.
-template <typename T, int N>
+template <typename T, int N, typename Func = sycl::plus<T>>
 class alignas(alignof(T) * N) AlignedVector {
  public:
   typedef T value_type;
-  static constexpr const int kSize = N;
+  static constexpr const uint32_t kSize = N;
 
   AlignedVector() = default;
 
   // Uniform initialization.
   explicit AlignedVector(value_type uniform) {
-    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) { values_[i] = uniform; }
+    UNROLL_ON_DEVICE for (uint32_t i = 0; i < kSize; ++i) {
+      values_[i] = uniform;
+    }
   }
   // Uniform initialization with explicit conversion.
   // Note: This is required for T=Eigen::half because it only supports explicit
@@ -223,13 +225,17 @@ class alignas(alignof(T) * N) AlignedVector {
                                 int>::type = 0>
   explicit AlignedVector(U uniform_u) {
     value_type uniform(uniform_u);
-    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) { values_[i] = uniform; }
+    UNROLL_ON_DEVICE for (uint32_t i = 0; i < kSize; ++i) {
+      values_[i] = uniform;
+    }
   }
   // Implicit conversion.
   template <typename U, typename std::enable_if<std::is_same<T, U>::value,
                                                 int>::type = 0>
   AlignedVector(const AlignedVector<U, N>& other) {
-    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) { values_[i] = other[i]; }
+    UNROLL_ON_DEVICE for (uint32_t i = 0; i < kSize; ++i) {
+      values_[i] = other[i];
+    }
   }
   // Explicit conversion.
   template <typename U, typename std::enable_if<!std::is_same<T, U>::value,
                                                 int>::type = 0>
   explicit AlignedVector(const AlignedVector<U, N>& other) {
-    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) {
+    UNROLL_ON_DEVICE for (uint32_t i = 0; i < kSize; ++i) {
       values_[i] = T(other[i]);
     }
   }
 
-  value_type& operator[](int i) { return values_[i]; }
-  const value_type& operator[](int i) const { return values_[i]; }
+  template <typename U>
+  void Load(const AlignedVector<U, N>& other) {
+    UNROLL_ON_DEVICE for (uint32_t i = 0; i < N; ++i) {
+      values_[i] = static_cast<T>(other[i]);
+    }
+  }
+
+  template <typename U>
+  void Accumulate(const AlignedVector<U, N>& other) {
+    UNROLL_ON_DEVICE for (uint32_t i = 0; i < N; ++i) {
+      values_[i] = Func()(values_[i], static_cast<T>(other[i]));
+    }
+  }
+
+  template <typename U>
+  void Store(AlignedVector<U, N>& other) {
+    UNROLL_ON_DEVICE for (uint32_t i = 0; i < N; ++i) {
+      other[i] = static_cast<U>(values_[i]);
+    }
+  }
+
+  template <typename U>
+  void PartialStore(AlignedVector<U, N>& other, uint32_t num,
+                    uint32_t offset = 0) {
+    UNROLL_ON_DEVICE for (uint32_t i = 0; i < N && i < num; ++i) {
+      other[i] = static_cast<U>(values_[i + offset]);
+    }
+  }
+
+  value_type& operator[](uint32_t i) { return values_[i]; }
+  const value_type& operator[](uint32_t i) const { return values_[i]; }
 
-#define DEFINE_BINARY_UPDATE_OPERATOR(op)                                      \
-  AlignedVector& operator op(const AlignedVector& rhs) {                       \
-    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) { values_[i] op rhs[i]; } \
-    return *this;                                                              \
+#define DEFINE_BINARY_UPDATE_OPERATOR(op)                   \
+  AlignedVector& operator op(const AlignedVector& rhs) {    \
+    UNROLL_ON_DEVICE for (uint32_t i = 0; i < kSize; ++i) { \
+      values_[i] op rhs[i];                                 \
+    }                                                       \
+    return *this;                                           \
   }
   DEFINE_BINARY_UPDATE_OPERATOR(+=)
   DEFINE_BINARY_UPDATE_OPERATOR(-=)
@@ -260,7 +297,7 @@ class alignas(alignof(T) * N) AlignedVector {
   friend AlignedVector operator op(const AlignedVector& lhs,                  \
                                    const AlignedVector& rhs) {                \
     AlignedVector ret;                                                        \
-    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) {                        \
+    UNROLL_ON_DEVICE for (uint32_t i = 0; i < kSize; ++i) {                   \
      ret[i] = lhs[i] op rhs[i];                                               \
     }                                                                         \
     return ret;                                                               \
@@ -271,14 +308,14 @@ class alignas(alignof(T) * N) AlignedVector {
   DEFINE_BINARY_OPERATOR(/)
 #undef DEFINE_BINARY_OPERATOR
 
-#define DEFINE_BINARY_FUNCTION(func)                      \
-  friend AlignedVector func(const AlignedVector& lhs,     \
-                            const AlignedVector& rhs) {   \
-    AlignedVector ret;                                    \
-    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) {    \
-      ret[i] = func(lhs[i], rhs[i]);                      \
-    }                                                     \
-    return ret;                                           \
+#define DEFINE_BINARY_FUNCTION(func)                        \
+  friend AlignedVector func(const AlignedVector& lhs,       \
+                            const AlignedVector& rhs) {     \
+    AlignedVector ret;                                      \
+    UNROLL_ON_DEVICE for (uint32_t i = 0; i < kSize; ++i) { \
+      ret[i] = func(lhs[i], rhs[i]);                        \
+    }                                                       \
+    return ret;                                             \
   }
   DEFINE_BINARY_FUNCTION(min)
   DEFINE_BINARY_FUNCTION(max)
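
Editor's note: the standalone sketch below is illustrative only and is not part of the patch. It reproduces, in plain host-side C++, the partitioning arithmetic that the new LaunchAllReduceKernel performs before launching work: the element count is grouped into VecBytes-sized vectors, each rank owns total_vec_count / reduction_size of them, and the last rank additionally owns the remainder plus any partial tail vector. The names VecBytes, slice_vec_count, tail_vec_count, and local_vec_count mirror the patch; the program itself (PrintPartition, the sample sizes, the build line) is a hypothetical example.

// allreduce_partition_demo.cc -- illustration of the patch's slicing math.
// Assumed build: g++ -std=c++14 allreduce_partition_demo.cc
#include <cstddef>
#include <cstdio>

// Same tuning constant as the patch: 8 bytes per vectorized access.
constexpr std::size_t VecBytes = 8;

// Prints how many VecBytes-sized vectors each rank reduces for a tensor of
// element_count elements of elem_size bytes, split across reduction_size ranks.
void PrintPartition(std::size_t element_count, std::size_t elem_size,
                    int reduction_size) {
  const std::size_t vec_size = VecBytes / elem_size;
  const std::size_t vec_count = element_count / vec_size;
  const std::size_t vec_tail_element_count = element_count % vec_size;
  // A partial trailing vector still has to be processed, so round up.
  const std::size_t total_vec_count =
      vec_count + (vec_tail_element_count > 0 ? 1 : 0);

  const std::size_t slice_vec_count = total_vec_count / reduction_size;
  const std::size_t tail_vec_count = total_vec_count % reduction_size;

  for (int rank = 0; rank < reduction_size; ++rank) {
    // The last rank also reduces the leftover vectors, including the partial
    // one when element_count is not a multiple of vec_size.
    const std::size_t local_vec_count =
        slice_vec_count + (rank == reduction_size - 1 ? tail_vec_count : 0);
    std::printf("rank %d: starts at vector %zu, reduces %zu vectors\n", rank,
                rank * slice_vec_count, local_vec_count);
  }
}

int main() {
  // Example: 1000003 float elements all-reduced across 4 ranks.
  PrintPartition(1000003, sizeof(float), 4);
  return 0;
}

With these inputs the sketch reports that ranks 0-2 each reduce 125000 two-element vectors while rank 3 reduces 125002 (the remainder plus the one-element tail), which is exactly the split the kernel's slice_vec_count / tail_vec_count logic produces.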