Skip to content

Commit

Permalink
add exit condition
Browse files Browse the repository at this point in the history
  • Loading branch information
jnke2016 committed Sep 9, 2024
1 parent 4c1c610 commit fe35c80
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions cpp/src/sampling/neighbor_sampling_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,14 @@ neighbor_sample_impl(
std::vector<size_t> level_sizes{};
int32_t hop{0};
int32_t edge_type_id_max{1}; // A value of 1 translate to homogeneous neighbor sample
int32_t num_edge_type_per_hop{0};

auto cur_graph_view = modified_graph_view ? *modified_graph_view : graph_view;

if (heterogeneous_fan_out) {
num_edge_type_per_hop = std::get<0>(*heterogeneous_fan_out).back() - 1;
}

while(true) {
int32_t k_level{0};
if (fan_out) {
Expand All @@ -210,8 +216,11 @@ neighbor_sample_impl(
break;
}
} else if (heterogeneous_fan_out) {
// initially edge type
edge_type_id_max = std::get<0>(*heterogeneous_fan_out).back() - 1;
if (num_edge_type_per_hop == 0) {
break;
}
edge_type_id_max = std::get<0>(*heterogeneous_fan_out).back() - 1;

}

for (int i = 0; i < edge_type_id_max; i++) {
Expand All @@ -223,7 +232,10 @@ neighbor_sample_impl(
auto k_level_size = (std::get<1>(*heterogeneous_fan_out)[i + 1] - std::get<1>(*heterogeneous_fan_out)[i]);
if (k_level_size > hop) {
k_level = i + hop;
} // otherwise, k_level = 0
} else { // otherwise, k_level = 0
--num_edge_type_per_hop ;

}
}
rmm::device_uvector<vertex_t> srcs(0, handle.get_stream());
rmm::device_uvector<vertex_t> dsts(0, handle.get_stream());
Expand Down

0 comments on commit fe35c80

Please sign in to comment.