diff --git a/cpp/src/c_api/neighbor_sampling.cpp b/cpp/src/c_api/neighbor_sampling.cpp index 77414028d7b..d708b6fbcc2 100644 --- a/cpp/src/c_api/neighbor_sampling.cpp +++ b/cpp/src/c_api/neighbor_sampling.cpp @@ -416,7 +416,7 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { (options_.compression_type_ == cugraph_compression_type_t::COO); if (options_.renumber_results_) { - if (fan_out_ != nullptr) { + if (num_edge_types_ == 1) { // homogeneous renumbering if (options_.compression_type_ == cugraph_compression_type_t::COO) { // COO @@ -442,13 +442,14 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { ? std::make_optional(raft::device_span{ start_vertices_->as_type(), start_vertices_->size_}) : std::nullopt, - options_.retain_seeds_ ? std::make_optional(raft::device_span{ - label_offsets_->as_type(), label_offsets_->size_}) + options_.retain_seeds_ ? (is_deprecated_api_? std::make_optional(raft::device_span{ + label_offsets_->as_type(), label_offsets_->size_}) : std::make_optional(raft::device_span{ + start_vertex_offsets_->as_type(), start_vertex_offsets_->size_})) : std::nullopt, offsets ? std::make_optional( raft::device_span{offsets->data(), offsets->size()}) : std::nullopt, - edge_label ? edge_label->size() : size_t{1}, // FIXME: update edge_label + edge_label ? edge_label->size() : size_t{1}, // FIXME: update edge_label ? hop ? fan_out_->size_ : size_t{1}, src_is_major, do_expensive_check_); @@ -463,6 +464,8 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { rmm::device_uvector output_major_offsets(0, handle_.get_stream()); rmm::device_uvector output_renumber_map(0, handle_.get_stream()); + // FIXME: Update this function to handle the new API with 'starting_vertex_offsets' + // and indices. std::tie(majors, output_major_offsets, minors, @@ -484,8 +487,9 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { ? std::make_optional(raft::device_span{ start_vertices_->as_type(), start_vertices_->size_}) : std::nullopt, - options_.retain_seeds_ ? std::make_optional(raft::device_span{ - label_offsets_->as_type(), label_offsets_->size_}) + options_.retain_seeds_ ? (is_deprecated_api_? std::make_optional(raft::device_span{ + label_offsets_->as_type(), label_offsets_->size_}) : std::make_optional(raft::device_span{ + start_vertex_offsets_->as_type(), start_vertex_offsets_->size_})) : std::nullopt, offsets ? std::make_optional( raft::device_span{offsets->data(), offsets->size()}) @@ -505,7 +509,7 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { hop.reset(); offsets.reset(); - } else { // heterogeneous + } else { // heterogeneous renumbering rmm::device_uvector vertex_type_offsets(graph_view.local_vertex_partition_range_size(), handle_.get_stream()); @@ -523,6 +527,8 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { // extract the edge_type from label_type_hop_offsets std::optional> label_type_hop_offsets{std::nullopt}; + // FIXME: Update this function to handle the new API with 'starting_vertex_offsets' + // and indices. std::tie(output_majors, minors, wgt, @@ -546,7 +552,7 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { start_vertices_->as_type(), start_vertices_->size_}) : std::nullopt, options_.retain_seeds_ ? std::make_optional(raft::device_span{ - label_offsets_->as_type(), label_offsets_->size_}) + start_vertex_offsets_->as_type(), start_vertex_offsets_->size_}) : std::nullopt, offsets ? std::make_optional( raft::device_span{offsets->data(), offsets->size()}) @@ -580,9 +586,6 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { CUGRAPH_FAIL("Can only use COO format if not renumbering"); } - if (num_edge_types_ != 1) { - CUGRAPH_FAIL("Can only use COO format for homogeneous neighborhood sampling"); - } std::tie(src, dst, wgt, edge_id, edge_type, label_hop_offsets) = cugraph::sort_sampled_edgelist(handle_, diff --git a/cpp/src/community/k_truss_impl.cuh b/cpp/src/community/k_truss_impl.cuh index 4facf093535..e052a892917 100644 --- a/cpp/src/community/k_truss_impl.cuh +++ b/cpp/src/community/k_truss_impl.cuh @@ -148,13 +148,13 @@ k_truss(raft::handle_t const& handle, std::nullopt, std::nullopt, cugraph::graph_properties_t{true, graph_view.is_multigraph()}, - false); + true); modified_graph_view = (*modified_graph).view(); } // 2. Find (k-1)-core and exclude edges that do not belong to (k-1)-core - /* + { auto cur_graph_view = modified_graph_view ? *modified_graph_view : graph_view; @@ -211,7 +211,6 @@ k_truss(raft::handle_t const& handle, renumber_map = std::move(tmp_renumber_map); } - */ // 3. Keep only the edges from a low-degree vertex to a high-degree vertex. @@ -278,7 +277,7 @@ k_truss(raft::handle_t const& handle, std::nullopt, std::nullopt, cugraph::graph_properties_t{false /* now asymmetric */, cur_graph_view.is_multigraph()}, - false); + true); modified_graph_view = (*modified_graph).view(); if (renumber_map) { // collapse renumber_map @@ -341,12 +340,8 @@ k_truss(raft::handle_t const& handle, edge_weight_view ? std::make_optional(*edge_weight_view) : std::nullopt, std::optional>{std::nullopt}, std::optional>{std::nullopt}, - /* std::make_optional( - raft::device_span((*renumber_map).data(), (*renumber_map).size())) - */ - std::optional>(std::nullopt) - ); + raft::device_span((*renumber_map).data(), (*renumber_map).size()))); std::tie(edgelist_srcs, edgelist_dsts, edgelist_wgts) = symmetrize_edgelist(handle, @@ -354,12 +349,6 @@ k_truss(raft::handle_t const& handle, std::move(edgelist_dsts), std::move(edgelist_wgts), false); - - raft::print_device_vector("edgelist_srcs", edgelist_srcs.data(), edgelist_srcs.size(), std::cout); - raft::print_device_vector("edgelist_dsts", edgelist_dsts.data(), edgelist_dsts.size(), std::cout); - - printf("\nK-TRUSS Successfully completed and the subgraph size = %d\n", edgelist_srcs.size()); - //raft::print_device_vector("edgelist_wgts", count_3_.data(), count_3_.size(), std::cout); return std::make_tuple( std::move(edgelist_srcs), std::move(edgelist_dsts), std::move(edgelist_wgts));