Skip to content

Commit

Permalink
update type combination
Browse files Browse the repository at this point in the history
  • Loading branch information
jnke2016 committed Oct 22, 2024
1 parent 14e9a99 commit aebfd08
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 76 deletions.
61 changes: 61 additions & 0 deletions cpp/src/sampling/detail/shuffle_and_organize_output_mg_v32_e64.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "sampling/detail/shuffle_and_organize_output_impl.cuh"

namespace cugraph {
namespace detail {

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
std::optional<rmm::device_uvector<float>>,
std::optional<rmm::device_uvector<int64_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<size_t>>>
shuffle_and_organize_output(
raft::handle_t const& handle,
rmm::device_uvector<int32_t>&& majors,
rmm::device_uvector<int32_t>&& minors,
std::optional<rmm::device_uvector<float>>&& weights,
std::optional<rmm::device_uvector<int64_t>>&& edge_ids,
std::optional<rmm::device_uvector<int32_t>>&& edge_types,
std::optional<rmm::device_uvector<int32_t>>&& hops,
std::optional<rmm::device_uvector<int32_t>>&& labels,
std::optional<raft::device_span<int32_t const>> label_to_output_comm_rank);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
std::optional<rmm::device_uvector<double>>,
std::optional<rmm::device_uvector<int64_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<size_t>>>
shuffle_and_organize_output(
raft::handle_t const& handle,
rmm::device_uvector<int32_t>&& majors,
rmm::device_uvector<int32_t>&& minors,
std::optional<rmm::device_uvector<double>>&& weights,
std::optional<rmm::device_uvector<int64_t>>&& edge_ids,
std::optional<rmm::device_uvector<int32_t>>&& edge_types,
std::optional<rmm::device_uvector<int32_t>>&& hops,
std::optional<rmm::device_uvector<int32_t>>&& labels,
std::optional<raft::device_span<int32_t const>> label_to_output_comm_rank);

} // namespace detail
} // namespace cugraph
14 changes: 0 additions & 14 deletions cpp/tests/sampling/heterogeneous_biased_neighbor_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ class Tests_Heterogeneous_Biased_Neighbor_Sampling
auto edge_weight_view =
edge_weights ? std::make_optional((*edge_weights).view()) : std::nullopt;

std::optional<cugraph::edge_property_t<decltype(graph_view), bool>> edge_mask{std::nullopt};

constexpr float select_probability{0.05};

// FIXME: Update the tests to initialize RngState and use it instead
Expand Down Expand Up @@ -251,12 +249,6 @@ TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_File, CheckInt32Int32Float)
override_File_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_File, CheckInt32Int64Float)
{
run_current_test<int32_t, int64_t, float>(
override_File_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_File, CheckInt64Int64Float)
{
run_current_test<int64_t, int64_t, float>(
Expand All @@ -269,12 +261,6 @@ TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int32Float)
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int64Float)
{
run_current_test<int32_t, int64_t, float>(
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt64Int64Float)
{
run_current_test<int64_t, int64_t, float>(
Expand Down
16 changes: 2 additions & 14 deletions cpp/tests/sampling/heterogeneous_uniform_neighbor_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ class Tests_Heterogeneous_Uniform_Neighbor_Sampling

// Generate the edge types

std::optional<cugraph::edge_property_t<decltype(graph_view), edge_type_t>> edge_types{
std::optional<cugraph::edge_property_t<decltype(graph_view), int32_t>> edge_types{
std::nullopt};

if (heterogeneous_uniform_neighbor_sampling_usecase.num_edge_types > 1) {
edge_types = cugraph::test::generate<decltype(graph_view), edge_type_t>::edge_property(
edge_types = cugraph::test::generate<decltype(graph_view), int32_t>::edge_property(
handle, graph_view, heterogeneous_uniform_neighbor_sampling_usecase.num_edge_types);
}

Expand Down Expand Up @@ -249,12 +249,6 @@ TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_File, CheckInt32Int32Float)
override_File_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_File, CheckInt32Int64Float)
{
run_current_test<int32_t, int64_t, float>(
override_File_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_File, CheckInt64Int64Float)
{
run_current_test<int64_t, int64_t, float>(
Expand All @@ -267,12 +261,6 @@ TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int32Float)
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int64Float)
{
run_current_test<int32_t, int64_t, float>(
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt64Int64Float)
{
run_current_test<int64_t, int64_t, float>(
Expand Down
12 changes: 0 additions & 12 deletions cpp/tests/sampling/homogeneous_biased_neighbor_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,6 @@ TEST_P(Tests_Homogeneous_Biased_Neighbor_Sampling_File, CheckInt32Int32Float)
override_File_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_Homogeneous_Biased_Neighbor_Sampling_File, CheckInt32Int64Float)
{
run_current_test<int32_t, int64_t, float>(
override_File_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_Homogeneous_Biased_Neighbor_Sampling_File, CheckInt64Int64Float)
{
run_current_test<int64_t, int64_t, float>(
Expand All @@ -251,12 +245,6 @@ TEST_P(Tests_Homogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int32Float)
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_Homogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int64Float)
{
run_current_test<int32_t, int64_t, float>(
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_Homogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt64Int64Float)
{
run_current_test<int64_t, int64_t, float>(
Expand Down
17 changes: 5 additions & 12 deletions cpp/tests/sampling/homogeneous_uniform_neighbor_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ class Tests_Homogeneous_Uniform_Neighbor_Sampling
graph_view.attach_edge_mask((*edge_mask).view());
}

// FIXME: Read a tuple of two edge mask and mask out if edge mask is set in either 1 (OR) and create
// a new one.
// No graph view can have two mask and perform OR in itself, and need to OR the mask
// manually by itself.

constexpr float select_probability{0.05};

// FIXME: Update the tests to initialize RngState and use it instead
Expand Down Expand Up @@ -231,12 +236,6 @@ TEST_P(Tests_Homogeneous_Uniform_Neighbor_Sampling_File, CheckInt32Int32Float)
override_File_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_Homogeneous_Uniform_Neighbor_Sampling_File, CheckInt32Int64Float)
{
run_current_test<int32_t, int64_t, float>(
override_File_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_Homogeneous_Uniform_Neighbor_Sampling_File, CheckInt64Int64Float)
{
run_current_test<int64_t, int64_t, float>(
Expand All @@ -249,12 +248,6 @@ TEST_P(Tests_Homogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int32Float)
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_Homogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int64Float)
{
run_current_test<int32_t, int64_t, float>(
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_Homogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt64Int64Float)
{
run_current_test<int64_t, int64_t, float>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,12 +338,6 @@ TEST_P(Tests_MGHeterogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int32Float
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_MGHeterogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int64Float)
{
run_current_test<int32_t, int64_t, float>(
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_MGHeterogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt64Int64Float)
{
run_current_test<int64_t, int64_t, float>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,12 +339,6 @@ TEST_P(Tests_MGHeterogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int32Floa
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_MGHeterogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int64Float)
{
run_current_test<int32_t, int64_t, float>(
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_MGHeterogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt64Int64Float)
{
run_current_test<int64_t, int64_t, float>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,12 +323,6 @@ TEST_P(Tests_MGHomogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int32Float)
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_MGHomogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int64Float)
{
run_current_test<int32_t, int64_t, float>(
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_MGHomogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt64Int64Float)
{
run_current_test<int64_t, int64_t, float>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,12 +323,6 @@ TEST_P(Tests_MGHomogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int32Float)
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_MGHomogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int64Float)
{
run_current_test<int32_t, int64_t, float>(
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
}

TEST_P(Tests_MGHomogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt64Int64Float)
{
run_current_test<int64_t, int64_t, float>(
Expand Down

0 comments on commit aebfd08

Please sign in to comment.