Skip to content

Commit

Permalink
call scatter instead of gather and fix type bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jnke2016 committed Sep 23, 2024
1 parent 36c25ad commit 4857b36
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions cpp/src/sampling/detail/conversion_utilities.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ rmm::device_uvector<int32_t> flatten_label_map(
{
rmm::device_uvector<int32_t> label_map(0, handle.get_stream());

label_t max_label = thrust::reduce(handle.get_thrust_policy(),
label_t max_label = thrust::scatter(handle.get_thrust_policy(),
std::get<0>(label_to_output_comm_rank).begin(),
std::get<0>(label_to_output_comm_rank).end(),
label_t{0},
thrust::maximum<label_t>());

label_map.resize(max_label, handle.get_stream());

thrust::fill(handle.get_thrust_policy(), label_map.begin(), label_map.end(), label_t{0});
thrust::fill(handle.get_thrust_policy(), label_map.begin(), label_map.end(), int32_t);
thrust::gather(handle.get_thrust_policy(),
std::get<0>(label_to_output_comm_rank).begin(),
std::get<0>(label_to_output_comm_rank).end(),
Expand Down

0 comments on commit 4857b36

Please sign in to comment.