From 2df98f5658bd14a5f23c9b51ea1266086d343c4d Mon Sep 17 00:00:00 2001 From: jnke2016 Date: Thu, 26 Sep 2024 12:14:15 -0700 Subject: [PATCH] ensure the graph properties match when symmetrizing --- cpp/src/c_api/graph_mg.cpp | 8 ++++++++ cpp/src/c_api/graph_sg.cpp | 16 ++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/cpp/src/c_api/graph_mg.cpp b/cpp/src/c_api/graph_mg.cpp index 530547bd915..ed6cf4250f2 100644 --- a/cpp/src/c_api/graph_mg.cpp +++ b/cpp/src/c_api/graph_mg.cpp @@ -378,6 +378,14 @@ extern "C" cugraph_error_code_t cugraph_graph_create_mg( if (weight_type == cugraph_data_type_id_t::NTYPES) weight_type = p_weights[i]->type_; } + if (symmetrize == TRUE){ + CAPI_EXPECTS( + (properties->is_symmetric == TRUE), + CUGRAPH_INVALID_INPUT, + "Invalid input arguments: The graph property must be symmetric if symmetrize is set to True.", + *error); + } + CAPI_EXPECTS(p_src[i]->type_ == vertex_type, CUGRAPH_INVALID_INPUT, "Invalid input arguments: all vertex types must match", diff --git a/cpp/src/c_api/graph_sg.cpp b/cpp/src/c_api/graph_sg.cpp index 52c3df5d07a..63d7ae43c49 100644 --- a/cpp/src/c_api/graph_sg.cpp +++ b/cpp/src/c_api/graph_sg.cpp @@ -581,6 +581,14 @@ extern "C" cugraph_error_code_t cugraph_graph_create_sg( auto p_edge_type_ids = reinterpret_cast(edge_type_ids); + if (symmetrize == TRUE){ + CAPI_EXPECTS( + (properties->is_symmetric == TRUE), + CUGRAPH_INVALID_INPUT, + "Invalid input arguments: The graph property must be symmetric if symmetrize is set to True.", + *error); + } + CAPI_EXPECTS(p_src->size_ == p_dst->size_, CUGRAPH_INVALID_INPUT, "Invalid input arguments: src size != dst size.", @@ -749,6 +757,14 @@ cugraph_error_code_t cugraph_graph_create_sg_from_csr( weight_type = cugraph_data_type_id_t::FLOAT32; } + if (symmetrize == TRUE){ + CAPI_EXPECTS( + (properties->is_symmetric == TRUE), + CUGRAPH_INVALID_INPUT, + "Invalid input arguments: The graph property must be symmetric if symmetrize is set to True.", + *error); + } + CAPI_EXPECTS( (edge_type_ids == nullptr && edge_ids == nullptr) || (edge_type_ids != nullptr && edge_ids != nullptr),