diff --git a/cpp/src/c_api/neighbor_sampling.cpp b/cpp/src/c_api/neighbor_sampling.cpp index ebee588309b..782ceb9d403 100644 --- a/cpp/src/c_api/neighbor_sampling.cpp +++ b/cpp/src/c_api/neighbor_sampling.cpp @@ -867,18 +867,9 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { handle_.get_stream()); std::optional> start_vertex_labels{std::nullopt}; - std::optional> start_vertex_offsets{std::nullopt}; std::optional> label_to_comm_rank{std::nullopt}; if (start_vertex_offsets_ != nullptr) { - start_vertex_offsets = - rmm::device_uvector{start_vertex_offsets_->size_, handle_.get_stream()}; - raft::copy(start_vertex_offsets->data(), - start_vertex_offsets_->as_type(), - start_vertex_offsets_->size_, - handle_.get_stream()); - - // Get the number of labels on each GPU auto num_local_labels = start_vertex_offsets_->size_ - 1; auto global_labels = cugraph::host_scalar_allgather( @@ -889,16 +880,16 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { // Compute the global start_vertex_label_offsets cugraph::detail::transform_increment(handle_.get_stream(), - (*start_vertex_offsets).begin(), - (*start_vertex_offsets).size(), + start_vertex_offsets_->as_type(), + start_vertex_offsets_->size_, global_labels[handle_.get_comms().get_rank()]); // Retrieve the start_vertex_labels start_vertex_labels = cugraph::detail::convert_starting_vertex_offsets_to_labels( handle_, - raft::device_span{(*start_vertex_offsets).data(), - (*start_vertex_offsets).size()}); + raft::device_span{start_vertex_offsets_->as_type(), + start_vertex_offsets_->size_}); } if constexpr (multi_gpu) { @@ -1246,7 +1237,7 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { raft::device_span{ vertex_type_offsets.data(), vertex_type_offsets.size()}, - start_vertex_offsets ? start_vertex_offsets->size() : size_t{1}, + start_vertex_offsets_ ? start_vertex_offsets_->size_ : size_t{1}, hop ? fan_out_->size_ : size_t{1}, size_t{1}, size_t{1},