Skip to content

Commit

Permalink
Add BF16 datatype for collectives (#120)
Browse files Browse the repository at this point in the history
Co-authored-by: Lu, Fengqing <fengqing.lu@intel.com>
  • Loading branch information
Lu Teng and LuFinch authored Nov 16, 2023
1 parent 100a05e commit 123546a
Showing 1 changed file with 177 additions and 56 deletions.
233 changes: 177 additions & 56 deletions xla/service/gpu/ccl_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ void allreduce_dpcpp(se::gpu::GpuStreamHandle stream, int tensor_size,

if (reduction_size == 2) {
stream->submit([&](sycl::handler& cgh) {
IN_PTR(0); IN_PTR(1); OUT_PTR(0); OUT_PTR(1);
IN_PTR(0);
IN_PTR(1);
OUT_PTR(0);
OUT_PTR(1);

cgh.parallel_for<AllReduceKernel<T, Func, 2>>(
sycl::nd_range<1>(sycl::range<1>(group_size * num_workgroup),
Expand All @@ -110,104 +113,208 @@ void allreduce_dpcpp(se::gpu::GpuStreamHandle stream, int tensor_size,
});
} else if (reduction_size == 4) {
stream->submit([&](sycl::handler& cgh) {
IN_PTR(0); IN_PTR(1); IN_PTR(2); IN_PTR(3);
OUT_PTR(0); OUT_PTR(1); OUT_PTR(2); OUT_PTR(3);
IN_PTR(0);
IN_PTR(1);
IN_PTR(2);
IN_PTR(3);
OUT_PTR(0);
OUT_PTR(1);
OUT_PTR(2);
OUT_PTR(3);

cgh.parallel_for<AllReduceKernel<T, Func, 4>>(
sycl::nd_range<1>(sycl::range<1>(group_size * num_workgroup),
sycl::range<1>(group_size)),
[=](sycl::nd_item<1> item) {
const int index = item.get_global_linear_id();
if (index >= tensor_size) return;
ADD_DPCPP(0, 1); ADD_DPCPP(0, 2); ADD_DPCPP(0, 3);
COPY_OUT_DPCPP(1, 0); COPY_OUT_DPCPP(2, 0); COPY_OUT_DPCPP(3, 0);
ADD_DPCPP(0, 1);
ADD_DPCPP(0, 2);
ADD_DPCPP(0, 3);
COPY_OUT_DPCPP(1, 0);
COPY_OUT_DPCPP(2, 0);
COPY_OUT_DPCPP(3, 0);
});
});
} else if (reduction_size == 6) {
stream->submit([&](sycl::handler& cgh) {
IN_PTR(0); IN_PTR(1); IN_PTR(2); IN_PTR(3);
IN_PTR(4); IN_PTR(5);
OUT_PTR(0); OUT_PTR(1); OUT_PTR(2); OUT_PTR(3);
OUT_PTR(4); OUT_PTR(5);
IN_PTR(0);
IN_PTR(1);
IN_PTR(2);
IN_PTR(3);
IN_PTR(4);
IN_PTR(5);
OUT_PTR(0);
OUT_PTR(1);
OUT_PTR(2);
OUT_PTR(3);
OUT_PTR(4);
OUT_PTR(5);

cgh.parallel_for<AllReduceKernel<T, Func, 6>>(
sycl::nd_range<1>(sycl::range<1>(group_size * num_workgroup),
sycl::range<1>(group_size)),
[=](sycl::nd_item<1> item) {
const int index = item.get_global_linear_id();
if (index >= tensor_size) return;
ADD_DPCPP(0, 1); ADD_DPCPP(0, 2); ADD_DPCPP(0, 3);
ADD_DPCPP(0, 4); ADD_DPCPP(0, 5);
COPY_OUT_DPCPP(1, 0); COPY_OUT_DPCPP(2, 0); COPY_OUT_DPCPP(3, 0);
COPY_OUT_DPCPP(4, 0); COPY_OUT_DPCPP(5, 0);
ADD_DPCPP(0, 1);
ADD_DPCPP(0, 2);
ADD_DPCPP(0, 3);
ADD_DPCPP(0, 4);
ADD_DPCPP(0, 5);
COPY_OUT_DPCPP(1, 0);
COPY_OUT_DPCPP(2, 0);
COPY_OUT_DPCPP(3, 0);
COPY_OUT_DPCPP(4, 0);
COPY_OUT_DPCPP(5, 0);
});
});
} else if (reduction_size == 8) {
stream->submit([&](sycl::handler& cgh) {
IN_PTR(0); IN_PTR(1); IN_PTR(2); IN_PTR(3);
IN_PTR(4); IN_PTR(5); IN_PTR(6); IN_PTR(7);
OUT_PTR(0); OUT_PTR(1); OUT_PTR(2); OUT_PTR(3);
OUT_PTR(4); OUT_PTR(5); OUT_PTR(6); OUT_PTR(7);
IN_PTR(0);
IN_PTR(1);
IN_PTR(2);
IN_PTR(3);
IN_PTR(4);
IN_PTR(5);
IN_PTR(6);
IN_PTR(7);
OUT_PTR(0);
OUT_PTR(1);
OUT_PTR(2);
OUT_PTR(3);
OUT_PTR(4);
OUT_PTR(5);
OUT_PTR(6);
OUT_PTR(7);

cgh.parallel_for<AllReduceKernel<T, Func, 8>>(
sycl::nd_range<1>(sycl::range<1>(group_size * num_workgroup),
sycl::range<1>(group_size)),
[=](sycl::nd_item<1> item) {
const int index = item.get_global_linear_id();
if (index >= tensor_size) return;
ADD_DPCPP(0, 1); ADD_DPCPP(0, 2); ADD_DPCPP(0, 3);
ADD_DPCPP(0, 4); ADD_DPCPP(0, 5); ADD_DPCPP(0, 6); ADD_DPCPP(0, 7);
COPY_OUT_DPCPP(1, 0); COPY_OUT_DPCPP(2, 0); COPY_OUT_DPCPP(3, 0);
COPY_OUT_DPCPP(4, 0); COPY_OUT_DPCPP(5, 0); COPY_OUT_DPCPP(6, 0);
ADD_DPCPP(0, 1);
ADD_DPCPP(0, 2);
ADD_DPCPP(0, 3);
ADD_DPCPP(0, 4);
ADD_DPCPP(0, 5);
ADD_DPCPP(0, 6);
ADD_DPCPP(0, 7);
COPY_OUT_DPCPP(1, 0);
COPY_OUT_DPCPP(2, 0);
COPY_OUT_DPCPP(3, 0);
COPY_OUT_DPCPP(4, 0);
COPY_OUT_DPCPP(5, 0);
COPY_OUT_DPCPP(6, 0);
COPY_OUT_DPCPP(7, 0);
});
});
} else if (reduction_size == 10) {
stream->submit([&](sycl::handler& cgh) {
IN_PTR(0); IN_PTR(1); IN_PTR(2); IN_PTR(3);
IN_PTR(4); IN_PTR(5); IN_PTR(6); IN_PTR(7);
IN_PTR(8); IN_PTR(9);
OUT_PTR(0); OUT_PTR(1); OUT_PTR(2); OUT_PTR(3);
OUT_PTR(4); OUT_PTR(5); OUT_PTR(6); OUT_PTR(7);
OUT_PTR(8); OUT_PTR(9);;
IN_PTR(0);
IN_PTR(1);
IN_PTR(2);
IN_PTR(3);
IN_PTR(4);
IN_PTR(5);
IN_PTR(6);
IN_PTR(7);
IN_PTR(8);
IN_PTR(9);
OUT_PTR(0);
OUT_PTR(1);
OUT_PTR(2);
OUT_PTR(3);
OUT_PTR(4);
OUT_PTR(5);
OUT_PTR(6);
OUT_PTR(7);
OUT_PTR(8);
OUT_PTR(9);
;

cgh.parallel_for<AllReduceKernel<T, Func, 10>>(
sycl::nd_range<1>(sycl::range<1>(group_size * num_workgroup),
sycl::range<1>(group_size)),
[=](sycl::nd_item<1> item) {
const int index = item.get_global_linear_id();
if (index >= tensor_size) return;
ADD_DPCPP(0, 1); ADD_DPCPP(0, 2); ADD_DPCPP(0, 3);
ADD_DPCPP(0, 4); ADD_DPCPP(0, 5); ADD_DPCPP(0, 6);
ADD_DPCPP(0, 7); ADD_DPCPP(0, 8); ADD_DPCPP(0, 9);
COPY_OUT_DPCPP(1, 0); COPY_OUT_DPCPP(2, 0); COPY_OUT_DPCPP(3, 0);
COPY_OUT_DPCPP(4, 0); COPY_OUT_DPCPP(5, 0); COPY_OUT_DPCPP(6, 0);
COPY_OUT_DPCPP(7, 0); COPY_OUT_DPCPP(8, 0); COPY_OUT_DPCPP(9, 0);
ADD_DPCPP(0, 1);
ADD_DPCPP(0, 2);
ADD_DPCPP(0, 3);
ADD_DPCPP(0, 4);
ADD_DPCPP(0, 5);
ADD_DPCPP(0, 6);
ADD_DPCPP(0, 7);
ADD_DPCPP(0, 8);
ADD_DPCPP(0, 9);
COPY_OUT_DPCPP(1, 0);
COPY_OUT_DPCPP(2, 0);
COPY_OUT_DPCPP(3, 0);
COPY_OUT_DPCPP(4, 0);
COPY_OUT_DPCPP(5, 0);
COPY_OUT_DPCPP(6, 0);
COPY_OUT_DPCPP(7, 0);
COPY_OUT_DPCPP(8, 0);
COPY_OUT_DPCPP(9, 0);
});
});
} else if (reduction_size == 12) {
stream->submit([&](sycl::handler& cgh) {
IN_PTR(0); IN_PTR(1); IN_PTR(2); IN_PTR(3);
IN_PTR(4); IN_PTR(5); IN_PTR(6); IN_PTR(7);
IN_PTR(8); IN_PTR(9); IN_PTR(10); IN_PTR(11);
OUT_PTR(0); OUT_PTR(1); OUT_PTR(2); OUT_PTR(3);
OUT_PTR(4); OUT_PTR(5); OUT_PTR(6); OUT_PTR(7);
OUT_PTR(8); OUT_PTR(9); OUT_PTR(10); OUT_PTR(11);
IN_PTR(0);
IN_PTR(1);
IN_PTR(2);
IN_PTR(3);
IN_PTR(4);
IN_PTR(5);
IN_PTR(6);
IN_PTR(7);
IN_PTR(8);
IN_PTR(9);
IN_PTR(10);
IN_PTR(11);
OUT_PTR(0);
OUT_PTR(1);
OUT_PTR(2);
OUT_PTR(3);
OUT_PTR(4);
OUT_PTR(5);
OUT_PTR(6);
OUT_PTR(7);
OUT_PTR(8);
OUT_PTR(9);
OUT_PTR(10);
OUT_PTR(11);

cgh.parallel_for<AllReduceKernel<T, Func, 12>>(
sycl::nd_range<1>(sycl::range<1>(group_size * num_workgroup),
sycl::range<1>(group_size)),
[=](sycl::nd_item<1> item) {
const int index = item.get_global_linear_id();
if (index >= tensor_size) return;
ADD_DPCPP(0, 1); ADD_DPCPP(0, 2); ADD_DPCPP(0, 3);
ADD_DPCPP(0, 4); ADD_DPCPP(0, 5); ADD_DPCPP(0, 6);
ADD_DPCPP(0, 7); ADD_DPCPP(0, 8); ADD_DPCPP(0, 9);
ADD_DPCPP(0, 10); ADD_DPCPP(0, 11);
COPY_OUT_DPCPP(1, 0); COPY_OUT_DPCPP(2, 0); COPY_OUT_DPCPP(3, 0);
COPY_OUT_DPCPP(4, 0); COPY_OUT_DPCPP(5, 0); COPY_OUT_DPCPP(6, 0);
COPY_OUT_DPCPP(7, 0); COPY_OUT_DPCPP(8, 0); COPY_OUT_DPCPP(9, 0);
COPY_OUT_DPCPP(10, 0); COPY_OUT_DPCPP(11, 0);
ADD_DPCPP(0, 1);
ADD_DPCPP(0, 2);
ADD_DPCPP(0, 3);
ADD_DPCPP(0, 4);
ADD_DPCPP(0, 5);
ADD_DPCPP(0, 6);
ADD_DPCPP(0, 7);
ADD_DPCPP(0, 8);
ADD_DPCPP(0, 9);
ADD_DPCPP(0, 10);
ADD_DPCPP(0, 11);
COPY_OUT_DPCPP(1, 0);
COPY_OUT_DPCPP(2, 0);
COPY_OUT_DPCPP(3, 0);
COPY_OUT_DPCPP(4, 0);
COPY_OUT_DPCPP(5, 0);
COPY_OUT_DPCPP(6, 0);
COPY_OUT_DPCPP(7, 0);
COPY_OUT_DPCPP(8, 0);
COPY_OUT_DPCPP(9, 0);
COPY_OUT_DPCPP(10, 0);
COPY_OUT_DPCPP(11, 0);
});
});
} else {
Expand Down Expand Up @@ -299,7 +406,7 @@ void alltoall_dpcpp(se::gpu::GpuStreamHandle stream, int tensor_size,
template <typename T, typename Func>
struct ReduceScatterKernel;

template <typename T, typename Func>
template <typename T, typename Func, typename AccT = T>
void reducescatter_dpcpp(se::gpu::GpuStreamHandle stream, int tensor_size,
std::vector<Participant>& participants,
int reduction_size) {
Expand Down Expand Up @@ -327,11 +434,11 @@ void reducescatter_dpcpp(se::gpu::GpuStreamHandle stream, int tensor_size,
const int index = item.get_global_linear_id();
if (index >= tensor_size) return;
for (int i = 0; i < reduction_size; ++i) {
out[i][index] = Func()(in[0][index + tensor_size * i],
in[1][index + tensor_size * i]);
out[i][index] = T(Func()(AccT(in[0][index + tensor_size * i]),
AccT(in[1][index + tensor_size * i])));
for (int j = 2; j < reduction_size; ++j) {
out[i][index] =
Func()(out[i][index], in[j][index + tensor_size * i]);
out[i][index] = T(Func()(AccT(out[i][index]),
AccT(in[j][index + tensor_size * i])));
}
}
});
Expand Down Expand Up @@ -444,7 +551,6 @@ void sycl_allreduce(const void* send_buffer, void* recv_buffer,
else if (dtype == BF16)
allreduce_dpcpp<bfloat16, sycl::plus<float>, float>(
stream, element_count, p, comm->nranks);

else
LOG(FATAL) << "PrimitiveType "
<< primitive_util::LowercasePrimitiveTypeName(dtype)
Expand Down Expand Up @@ -476,7 +582,6 @@ void sycl_allreduce(const void* send_buffer, void* recv_buffer,
else if (dtype == BF16)
allreduce_dpcpp<bfloat16, sycl::multiplies<float>, float>(
stream, element_count, p, comm->nranks);

else
LOG(FATAL) << "PrimitiveType "
<< primitive_util::LowercasePrimitiveTypeName(dtype)
Expand All @@ -500,7 +605,6 @@ void sycl_allreduce(const void* send_buffer, void* recv_buffer,
else if (dtype == BF16)
allreduce_dpcpp<bfloat16, sycl::minimum<float>, float>(
stream, element_count, p, comm->nranks);

else
LOG(FATAL) << "PrimitiveType "
<< primitive_util::LowercasePrimitiveTypeName(dtype)
Expand All @@ -524,7 +628,6 @@ void sycl_allreduce(const void* send_buffer, void* recv_buffer,
else if (dtype == BF16)
allreduce_dpcpp<bfloat16, sycl::maximum<float>, float>(
stream, element_count, p, comm->nranks);

else
LOG(FATAL) << "PrimitiveType "
<< primitive_util::LowercasePrimitiveTypeName(dtype)
Expand Down Expand Up @@ -581,6 +684,8 @@ void sycl_allgather(const void* send_buffer, void* recv_buffer,
allgather_dpcpp<int32_t>(stream, element_count, p, comm->nranks);
else if (dtype == S64)
allgather_dpcpp<int64_t>(stream, element_count, p, comm->nranks);
else if (dtype == BF16)
allgather_dpcpp<bfloat16>(stream, element_count, p, comm->nranks);
else
LOG(FATAL) << "PrimitiveType "
<< primitive_util::LowercasePrimitiveTypeName(dtype)
Expand Down Expand Up @@ -632,6 +737,8 @@ void sycl_alltoall(std::vector<const void*> send_buffers,
alltoall_dpcpp<int32_t>(stream, element_count, p, comm->nranks);
else if (dtype == S64)
alltoall_dpcpp<int64_t>(stream, element_count, p, comm->nranks);
else if (dtype == BF16)
alltoall_dpcpp<bfloat16>(stream, element_count, p, comm->nranks);
else
LOG(FATAL) << "PrimitiveType "
<< primitive_util::LowercasePrimitiveTypeName(dtype)
Expand Down Expand Up @@ -700,6 +807,9 @@ void sycl_reduce_scatter(const void* send_buffer, void* recv_buffer,
reducescatter_dpcpp<std::complex<double>,
sycl::plus<std::complex<double>>>(
stream, element_count, p, comm->nranks);
else if (dtype == BF16)
reducescatter_dpcpp<bfloat16, sycl::plus<float>, float>(
stream, element_count, p, comm->nranks);
else
LOG(FATAL) << "PrimitiveType "
<< primitive_util::LowercasePrimitiveTypeName(dtype)
Expand Down Expand Up @@ -728,6 +838,9 @@ void sycl_reduce_scatter(const void* send_buffer, void* recv_buffer,
reducescatter_dpcpp<std::complex<double>,
sycl::multiplies<std::complex<double>>>(
stream, element_count, p, comm->nranks);
else if (dtype == BF16)
reducescatter_dpcpp<bfloat16, sycl::multiplies<float>, float>(
stream, element_count, p, comm->nranks);
else
LOG(FATAL) << "PrimitiveType "
<< primitive_util::LowercasePrimitiveTypeName(dtype)
Expand All @@ -748,6 +861,9 @@ void sycl_reduce_scatter(const void* send_buffer, void* recv_buffer,
else if (dtype == S64)
reducescatter_dpcpp<int64_t, sycl::minimum<int64_t>>(
stream, element_count, p, comm->nranks);
else if (dtype == BF16)
reducescatter_dpcpp<bfloat16, sycl::minimum<float>, float>(
stream, element_count, p, comm->nranks);
else
LOG(FATAL) << "PrimitiveType "
<< primitive_util::LowercasePrimitiveTypeName(dtype)
Expand All @@ -768,6 +884,9 @@ void sycl_reduce_scatter(const void* send_buffer, void* recv_buffer,
else if (dtype == S64)
reducescatter_dpcpp<int64_t, sycl::maximum<int64_t>>(
stream, element_count, p, comm->nranks);
else if (dtype == BF16)
reducescatter_dpcpp<bfloat16, sycl::maximum<float>, float>(
stream, element_count, p, comm->nranks);
else
LOG(FATAL) << "PrimitiveType "
<< primitive_util::LowercasePrimitiveTypeName(dtype)
Expand Down Expand Up @@ -828,6 +947,8 @@ void sycl_collective_permute(const void* send_buffer, void* recv_buffer,
permute_dpcpp<int32_t>(stream, element_count, p, comm->nranks);
else if (dtype == S64)
permute_dpcpp<int64_t>(stream, element_count, p, comm->nranks);
else if (dtype == BF16)
permute_dpcpp<bfloat16>(stream, element_count, p, comm->nranks);
else
LOG(FATAL) << "PrimitiveType "
<< primitive_util::LowercasePrimitiveTypeName(dtype)
Expand Down

0 comments on commit 123546a

Please sign in to comment.