Skip to content

Commit

Permalink
Fix device communicator dependency (dmlc#9346)
Browse files Browse the repository at this point in the history
  • Loading branch information
rongou authored Jun 29, 2023
1 parent f479871 commit f90771e
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 119 deletions.
6 changes: 3 additions & 3 deletions src/collective/communicator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
old_world_size = communicator_->GetWorldSize();
#ifdef XGBOOST_USE_NCCL
if (type_ != CommunicatorType::kFederated) {
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get()));
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal));
} else {
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
}
#else
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
#endif
}
return device_communicator_.get();
Expand Down
33 changes: 14 additions & 19 deletions src/collective/device_communicator_adapter.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -11,59 +11,53 @@ namespace collective {

class DeviceCommunicatorAdapter : public DeviceCommunicator {
public:
DeviceCommunicatorAdapter(int device_ordinal, Communicator *communicator)
: device_ordinal_{device_ordinal}, communicator_{communicator} {
explicit DeviceCommunicatorAdapter(int device_ordinal)
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
if (communicator_ == nullptr) {
LOG(FATAL) << "Communicator cannot be null.";
}
}

~DeviceCommunicatorAdapter() override = default;

void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}

dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto size = count * GetTypeSize(data_type);
host_buffer_.reserve(size);
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
communicator_->AllReduce(host_buffer_.data(), count, data_type, op);
Allreduce(host_buffer_.data(), count, data_type, op);
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
}

void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}

dh::safe_cuda(cudaSetDevice(device_ordinal_));
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();

segments->clear();
segments->resize(world_size, 0);
segments->at(rank) = length_bytes;
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
Operation::kMax);
segments->resize(world_size_, 0);
segments->at(rank_) = length_bytes;
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);

host_buffer_.reserve(total_bytes);
size_t offset = 0;
for (int32_t i = 0; i < world_size; ++i) {
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
if (i == rank) {
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank),
if (i == rank_) {
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
cudaMemcpyDefault));
}
communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i);
Broadcast(host_buffer_.data() + offset, as_bytes, i);
offset += as_bytes;
}
dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
Expand All @@ -76,7 +70,8 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {

private:
int const device_ordinal_;
Communicator *communicator_;
int const world_size_;
int const rank_;
/// Host buffer used to call communicator functions.
std::vector<char> host_buffer_{};
};
Expand Down
52 changes: 21 additions & 31 deletions src/collective/nccl_device_communicator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,24 @@
namespace xgboost {
namespace collective {

NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, Communicator *communicator)
: device_ordinal_{device_ordinal}, communicator_{communicator} {
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal)
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
if (communicator_ == nullptr) {
LOG(FATAL) << "Communicator cannot be null.";
}

int32_t const rank = communicator_->GetRank();
int32_t const world = communicator_->GetWorldSize();

if (world == 1) {
if (world_size_ == 1) {
return;
}

std::vector<uint64_t> uuids(world * kUuidLength, 0);
std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
auto s_this_uuid = s_uuid.subspan(rank_ * kUuidLength, kUuidLength);
GetCudaUUID(s_this_uuid);

// TODO(rongou): replace this with allgather.
communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);

std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world_size_);
size_t j = 0;
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
Expand All @@ -41,18 +34,18 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, Communicator
auto iter = std::unique(converted.begin(), converted.end());
auto n_uniques = std::distance(converted.begin(), iter);

CHECK_EQ(n_uniques, world)
CHECK_EQ(n_uniques, world_size_)
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";

nccl_unique_id_ = GetUniqueId();
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
}

NcclDeviceCommunicator::~NcclDeviceCommunicator() {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}
if (cuda_stream_) {
Expand Down Expand Up @@ -139,9 +132,8 @@ void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func,

void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
DataType data_type, Operation op) {
auto const world_size = communicator_->GetWorldSize();
auto const size = count * GetTypeSize(data_type);
dh::caching_device_vector<char> buffer(size * world_size);
dh::caching_device_vector<char> buffer(size * world_size_);
auto *device_buffer = buffer.data().get();

// First gather data from all the workers.
Expand All @@ -152,15 +144,15 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
auto *out_buffer = static_cast<char *>(send_receive_buffer);
switch (op) {
case Operation::kBitwiseAND:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size, size,
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size,
cuda_stream_);
break;
case Operation::kBitwiseOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size, size,
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size,
cuda_stream_);
break;
case Operation::kBitwiseXOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size, size,
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size,
cuda_stream_);
break;
default:
Expand All @@ -170,7 +162,7 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si

void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
DataType data_type, Operation op) {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}

Expand All @@ -189,24 +181,22 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}

dh::safe_cuda(cudaSetDevice(device_ordinal_));
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();

segments->clear();
segments->resize(world_size, 0);
segments->at(rank) = length_bytes;
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
segments->resize(world_size_, 0);
segments->at(rank_) = length_bytes;
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);

size_t offset = 0;
dh::safe_nccl(ncclGroupStart());
for (int32_t i = 0; i < world_size; ++i) {
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
ncclChar, i, nccl_comm_, cuda_stream_));
Expand All @@ -216,7 +206,7 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
}

void NcclDeviceCommunicator::Synchronize() {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
Expand Down
10 changes: 5 additions & 5 deletions src/collective/nccl_device_communicator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace collective {

class NcclDeviceCommunicator : public DeviceCommunicator {
public:
NcclDeviceCommunicator(int device_ordinal, Communicator *communicator);
explicit NcclDeviceCommunicator(int device_ordinal);
~NcclDeviceCommunicator() override;
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override;
Expand Down Expand Up @@ -49,19 +49,19 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
ncclUniqueId GetUniqueId() {
static const int kRootRank = 0;
ncclUniqueId id;
if (communicator_->GetRank() == kRootRank) {
if (rank_ == kRootRank) {
dh::safe_nccl(ncclGetUniqueId(&id));
}
communicator_->Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId),
static_cast<int>(kRootRank));
Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
return id;
}

void BitwiseAllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op);

int const device_ordinal_;
Communicator *communicator_;
int const world_size_;
int const rank_;
ncclComm_t nccl_comm_{};
cudaStream_t cuda_stream_{};
ncclUniqueId nccl_unique_id_{};
Expand Down
7 changes: 1 addition & 6 deletions tests/cpp/collective/test_nccl_device_communicator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,7 @@ namespace xgboost {
namespace collective {

TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
auto construct = []() { NcclDeviceCommunicator comm{-1, nullptr}; };
EXPECT_THROW(construct(), dmlc::Error);
}

TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidCommunicator) {
auto construct = []() { NcclDeviceCommunicator comm{0, nullptr}; };
auto construct = []() { NcclDeviceCommunicator comm{-1}; };
EXPECT_THROW(construct(), dmlc::Error);
}

Expand Down
9 changes: 8 additions & 1 deletion tests/cpp/plugin/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,14 @@ class ServerForTest {
}

~ServerForTest() {
using namespace std::chrono_literals;
while (!server_) {
std::this_thread::sleep_for(100ms);
}
server_->Shutdown();
while (!server_thread_) {
std::this_thread::sleep_for(100ms);
}
server_thread_->join();
}

Expand All @@ -56,7 +63,7 @@ class BaseFederatedTest : public ::testing::Test {

void TearDown() override { server_.reset(nullptr); }

static int constexpr kWorldSize{3};
static int constexpr kWorldSize{2};
std::unique_ptr<ServerForTest> server_;
};

Expand Down
Loading

0 comments on commit f90771e

Please sign in to comment.