diff --git a/include/ppqsort.h b/include/ppqsort.h index 070a60d..c802c9a 100644 --- a/include/ppqsort.h +++ b/include/ppqsort.h @@ -64,6 +64,11 @@ namespace ppqsort { impl::call_sort(std::forward(policy), begin, end); } + template + void sort(ExecutionPolicy&& policy, RandomIt begin, RandomIt end, const int threads) { + impl::call_sort(std::forward(policy), begin, end, threads); + } + template void sort(RandomIt begin, RandomIt end, Compare comp) { impl::seq_ppqsort(begin, end, comp); @@ -73,4 +78,9 @@ namespace ppqsort { void sort(ExecutionPolicy&& policy, RandomIt begin, RandomIt end, Compare comp) { impl::call_sort(std::forward(policy), begin, end, comp); } + + template + void sort(ExecutionPolicy&& policy, RandomIt begin, RandomIt end, Compare comp, const int threads) { + impl::call_sort(std::forward(policy), begin, end, comp, threads); + } } \ No newline at end of file diff --git a/include/ppqsort/parallel/cpp/mainloop_par.h b/include/ppqsort/parallel/cpp/mainloop_par.h index 1476a11..50b40c8 100644 --- a/include/ppqsort/parallel/cpp/mainloop_par.h +++ b/include/ppqsort/parallel/cpp/mainloop_par.h @@ -18,6 +18,9 @@ namespace ppqsort::impl { namespace cpp { struct ThreadPools { + ThreadPools() = default; + explicit ThreadPools(const int threads) : partition(threads), tasks(threads) {} + ThreadPool<> partition; ThreadPool<> tasks; }; @@ -127,16 +130,15 @@ namespace ppqsort::impl { } } - template ::value_type>, bool Branchless = use_branchless::value_type, Compare>::value> - void par_ppqsort(RandomIt begin, RandomIt end, Compare comp = Compare()) { + void par_ppqsort(RandomIt begin, RandomIt end, Compare comp = Compare(), + int threads = static_cast(std::jthread::hardware_concurrency())) { if (begin == end) return; constexpr bool branchless = Force_branchless || Branchless; - int threads = static_cast(std::jthread::hardware_concurrency()); auto size = end - begin; if ((threads < 2) || (size < parameters::seq_threshold)) return seq_loop(begin, end, comp, log2(size)); @@ -144,7 +146,7 @@ namespace ppqsort::impl { int seq_thr = (end - begin + 1) / threads / parameters::par_thr_div; seq_thr = std::max(seq_thr, branchless ? parameters::insertion_threshold_primitive : parameters::insertion_threshold); - cpp::ThreadPools threadpools; + cpp::ThreadPools threadpools(threads); threadpools.tasks.push_task([begin, end, comp, seq_thr, threads, &threadpools] { cpp::par_loop(begin, end, comp, log2(end - begin), @@ -154,4 +156,11 @@ namespace ppqsort::impl { threadpools.tasks.wait_and_stop(); threadpools.partition.wait_and_stop(); } + + template ::value_type>> + void par_ppqsort(RandomIt begin, RandomIt end, int threads) { + return par_ppqsort(begin, end, Compare(), threads); + } } \ No newline at end of file diff --git a/include/ppqsort/parallel/openmp/mainloop_par.h b/include/ppqsort/parallel/openmp/mainloop_par.h index 38095e5..afb7454 100644 --- a/include/ppqsort/parallel/openmp/mainloop_par.h +++ b/include/ppqsort/parallel/openmp/mainloop_par.h @@ -114,11 +114,11 @@ namespace ppqsort::impl { typename RandomIt, typename Compare = std::less::value_type>, bool Branchless = use_branchless::value_type, Compare>::value> - void par_ppqsort(RandomIt begin, RandomIt end, Compare comp = Compare()) { + void par_ppqsort(RandomIt begin, RandomIt end, + Compare comp = Compare(), const int threads = omp_get_max_threads()) { if (begin == end) return; constexpr bool branchless = Force_branchless || Branchless; - const int threads = omp_get_max_threads(); auto size = end - begin; if ((threads < 2) || (size < parameters::seq_threshold)) return seq_loop(begin, end, comp, log2(size)); @@ -138,4 +138,12 @@ namespace ppqsort::impl { } } } + + template ::value_type>> + void par_ppqsort(RandomIt begin, RandomIt end, const int threads) { + Compare comp = Compare(); + par_ppqsort(begin, end, comp, threads); + } } \ No newline at end of file