Skip to content

Commit

Permalink
remove unnecessary copy
Browse files Browse the repository at this point in the history
  • Loading branch information
jnke2016 committed Sep 25, 2024
1 parent 6081978 commit 73b3ffe
Showing 1 changed file with 5 additions and 14 deletions.
19 changes: 5 additions & 14 deletions cpp/src/c_api/neighbor_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -867,18 +867,9 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor {
handle_.get_stream());

std::optional<rmm::device_uvector<label_t>> start_vertex_labels{std::nullopt};
std::optional<rmm::device_uvector<size_t>> start_vertex_offsets{std::nullopt};
std::optional<rmm::device_uvector<label_t>> label_to_comm_rank{std::nullopt};

if (start_vertex_offsets_ != nullptr) {
start_vertex_offsets =
rmm::device_uvector<size_t>{start_vertex_offsets_->size_, handle_.get_stream()};
raft::copy(start_vertex_offsets->data(),
start_vertex_offsets_->as_type<size_t>(),
start_vertex_offsets_->size_,
handle_.get_stream());


// Get the number of labels on each GPU
auto num_local_labels = start_vertex_offsets_->size_ - 1;
auto global_labels = cugraph::host_scalar_allgather(
Expand All @@ -889,16 +880,16 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor {

// Compute the global start_vertex_label_offsets
cugraph::detail::transform_increment(handle_.get_stream(),
(*start_vertex_offsets).begin(),
(*start_vertex_offsets).size(),
start_vertex_offsets_->as_type<size_t>(),
start_vertex_offsets_->size_,
global_labels[handle_.get_comms().get_rank()]);

// Retrieve the start_vertex_labels
start_vertex_labels =
cugraph::detail::convert_starting_vertex_offsets_to_labels(
handle_,
raft::device_span<size_t const>{(*start_vertex_offsets).data(),
(*start_vertex_offsets).size()});
raft::device_span<size_t const>{start_vertex_offsets_->as_type<size_t>(),
start_vertex_offsets_->size_});
}

if constexpr (multi_gpu) {
Expand Down Expand Up @@ -1246,7 +1237,7 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor {
raft::device_span<vertex_t const>{
vertex_type_offsets.data(), vertex_type_offsets.size()},

start_vertex_offsets ? start_vertex_offsets->size() : size_t{1},
start_vertex_offsets_ ? start_vertex_offsets_->size_ : size_t{1},
hop ? fan_out_->size_ : size_t{1},
size_t{1},
size_t{1},
Expand Down

0 comments on commit 73b3ffe

Please sign in to comment.