Skip to content

Commit

Permalink
leverage raft span instead of raw pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
jnke2016 committed Oct 30, 2024
1 parent b159e1b commit 3b0c016
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 47 deletions.
9 changes: 3 additions & 6 deletions cpp/include/cugraph/detail/utility_wrappers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,10 @@ void scalar_fill(raft::handle_t const& handle, value_t* d_value, size_t size, va
* @param [in] handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator,
* and handles to various CUDA libraries) to run graph algorithms.
* @param[out] d_value device array to sort
* @param[in] size number of elements in array
*
*/
template <typename value_t>
void sort(raft::handle_t const& handle, value_t* d_value, size_t size);
void sort(raft::handle_t const& handle, raft::device_span<value_t> d_value);

/**
* @brief Keep unique element from a buffer
Expand All @@ -87,12 +86,11 @@ void sort(raft::handle_t const& handle, value_t* d_value, size_t size);
* @param [in] handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator,
* and handles to various CUDA libraries) to run graph algorithms.
* @param[in] d_value device array of unique elements.
* @param[in] size number of elements in array
* @return the number of unique elements.
*
*/
template <typename value_t>
size_t unique(raft::handle_t const& handle, value_t* d_value, size_t size);
size_t unique(raft::handle_t const& handle, raft::device_span<value_t> d_value);

/**
* @brief Increment the values of a buffer by a constant value
Expand All @@ -102,12 +100,11 @@ size_t unique(raft::handle_t const& handle, value_t* d_value, size_t size);
* @param [in] handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator,
* and handles to various CUDA libraries) to run graph algorithms.
* @param[out] d_value device array to update
* @param[in] size number of elements in array
* @param[in] value value to be added to each element of the buffer
*
*/
template <typename value_t>
void transform_increment(raft::handle_t const& handle, value_t* d_value, size_t size, size_t value);
void transform_increment(rmm::cuda_stream_view const& stream_view, raft::device_span<value_t> d_value, value_t value);

/**
* @brief Fill a buffer with a sequence of values
Expand Down
9 changes: 5 additions & 4 deletions cpp/src/c_api/neighbor_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -884,9 +884,9 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor {
start_vertex_offsets_->size_});

// Compute the global start_vertex_label_offsets

cugraph::detail::transform_increment(handle_.get_stream(),
(label_t*)(*start_vertex_labels).data(),
(size_t)(*start_vertex_labels).size(),
raft::device_span<label_t>{(*start_vertex_labels).data(), (*start_vertex_labels).size()},
(label_t)global_labels[handle_.get_comms().get_rank()]
);
}
Expand All @@ -902,9 +902,10 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor {

// Get unique labels
// sort the start_vertex_labels
cugraph::detail::sort(handle_.get_stream(), unique_labels.begin(), unique_labels.size());
cugraph::detail::sort(handle_.get_stream(), raft::device_span<label_t>{unique_labels.data(), unique_labels.size()});

auto num_unique_labels = cugraph::detail::unique(
handle_.get_stream(), unique_labels.begin(), unique_labels.size());
handle_.get_stream(), raft::device_span<label_t const>{unique_labels.data(), unique_labels.size()});

(*local_label_to_comm_rank).resize(num_unique_labels, handle_.get_stream());

Expand Down
15 changes: 7 additions & 8 deletions cpp/src/detail/utility_wrappers_32.cu
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,10 @@ template void scalar_fill(raft::handle_t const& handle, size_t* d_value, size_t

template void scalar_fill(raft::handle_t const& handle, float* d_value, size_t size, float value);

template void sort(raft::handle_t const& handle, int32_t* d_value, size_t size);
template void sort(raft::handle_t const& handle, raft::device_span<int32_t> d_value);

template size_t unique(raft::handle_t const& handle, int32_t* d_value, size_t size);
template size_t unique(raft::handle_t const& handle, raft::device_span<int32_t> d_value);
template size_t unique(raft::handle_t const& handle, raft::device_span<uint32_t> d_value);

template void sequence_fill(rmm::cuda_stream_view const& stream_view,
int32_t* d_value,
Expand All @@ -78,14 +79,12 @@ template void sequence_fill(rmm::cuda_stream_view const& stream_view,
uint32_t start_value);

template void transform_increment(rmm::cuda_stream_view const& stream_view,
int32_t* d_value,
size_t size,
size_t value);
raft::device_span<int32_t> d_value,
int32_t value);

template void transform_increment(rmm::cuda_stream_view const& stream_view,
uint32_t* d_value,
size_t size,
size_t value);
raft::device_span<uint32_t> d_value,
uint32_t value);

template void stride_fill(rmm::cuda_stream_view const& stream_view,
int32_t* d_value,
Expand Down
15 changes: 7 additions & 8 deletions cpp/src/detail/utility_wrappers_64.cu
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ template void scalar_fill(raft::handle_t const& handle,

template void scalar_fill(raft::handle_t const& handle, double* d_value, size_t size, double value);

template void sort(raft::handle_t const& handle, int64_t* d_value, size_t size);
template void sort(raft::handle_t const& handle, raft::device_span<int64_t> d_value);

template size_t unique(raft::handle_t const& handle, int32_t* d_value, size_t size);
template size_t unique(raft::handle_t const& handle, raft::device_span<int64_t> d_value);
template size_t unique(raft::handle_t const& handle, raft::device_span<uint64_t> d_value);

template void sequence_fill(rmm::cuda_stream_view const& stream_view,
int64_t* d_value,
Expand All @@ -76,14 +77,12 @@ template void sequence_fill(rmm::cuda_stream_view const& stream_view,
uint64_t start_value);

template void transform_increment(rmm::cuda_stream_view const& stream_view,
int64_t* d_value,
size_t size,
size_t value);
raft::device_span<int64_t> d_value,
int64_t value);

template void transform_increment(rmm::cuda_stream_view const& stream_view,
uint64_t* d_value,
size_t size,
size_t value);
raft::device_span<uint64_t> d_value,
uint64_t value);

template void stride_fill(rmm::cuda_stream_view const& stream_view,
int64_t* d_value,
Expand Down
47 changes: 26 additions & 21 deletions cpp/src/detail/utility_wrappers_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,16 @@ void scalar_fill(raft::handle_t const& handle, value_t* d_value, size_t size, va
}

template <typename value_t>
void sort(raft::handle_t const& handle, value_t* d_value, size_t size)
void sort(raft::handle_t const& handle, raft::device_span<value_t> d_value)
{
thrust::sort(handle.get_thrust_policy(), d_value, d_value + size);
thrust::sort(handle.get_thrust_policy(), d_value.begin(), d_value.end());
}

template <typename value_t>
size_t unique(raft::handle_t const& handle, value_t* d_value, size_t size)
size_t unique(raft::handle_t const& handle, raft::device_span<value_t> d_value)
{
// auto unique_element_last = thrust::unique(handle.get_thrust_policy(), d_value, d_value + size);
auto unique_element_last = thrust::unique(handle.get_thrust_policy(), d_value, d_value + size);
// auto num_unique_element =
return thrust::distance(d_value, unique_element_last);
// masked_edgelist_srcs.resize(2* masked_edgelist_srcs.size(), handle.get_stream());
auto unique_element_last = thrust::unique(handle.get_thrust_policy(), d_value.begin(), d_value.end());
return thrust::distance(d_value.begin(), unique_element_last);
}

template <typename value_t>
Expand All @@ -89,20 +86,28 @@ void sequence_fill(rmm::cuda_stream_view const& stream_view,
thrust::sequence(rmm::exec_policy(stream_view), d_value, d_value + size, start_value);
}


template <typename value_t>
void transform_increment(rmm::cuda_stream_view const& stream_view,
value_t* d_value,
size_t size,
size_t incr)
{
thrust::transform(rmm::exec_policy(stream_view),
d_value,
d_value + size,
d_value,
cuda::proclaim_return_type<value_t>([incr] __device__(value_t value) {
return static_cast<value_t>(value + incr);
}));
}
void transform_increment(rmm::cuda_stream_view const& stream_view,
raft::device_span<value_t> d_value,
value_t incr)
{
thrust::transform(rmm::exec_policy(stream_view),
d_value.begin(),
d_value.end(),
d_value.begin(),
cuda::proclaim_return_type<value_t>([incr] __device__(value_t value) {
return static_cast<value_t>(value + incr);
}));
}


template <typename value_t>
void transform_increment_(rmm::cuda_stream_view const& stream_view,
value_t d_value)
{

}

template <typename value_t>
void stride_fill(rmm::cuda_stream_view const& stream_view,
Expand Down

0 comments on commit 3b0c016

Please sign in to comment.