Skip to content

Commit

Permalink
add sequential slice and subtensor ttm version
Browse files Browse the repository at this point in the history
  • Loading branch information
bassoy committed Apr 20, 2024
1 parent 2e2fff6 commit 149f3fc
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 2 deletions.
2 changes: 2 additions & 0 deletions include/tlib/detail/tags.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@

namespace tlib::parallel_policy
{
// Empty tag types; the ttm overloads take one of them as an unnamed leading parameter to select an implementation.
struct sequential_t {}; // no loop-level parallelism; the sequential ttm overloads call set_blas_threads_min()
struct threaded_gemm_t {}; // multithreaded gemm
struct omp_taskloop_t {}; // omp_taskloops with single threaded gemm
struct omp_forloop_t {}; // omp_for with single threaded gemm
struct omp_forloop_and_threaded_gemm_t {}; // omp_for with multi-threaded gemm
struct batched_gemm_t {}; // multithreaded batched gemm with collapsed loops

// Ready-made constant objects of the policy tag types, usable as call arguments for tag dispatch.
inline constexpr sequential_t sequential;
inline constexpr threaded_gemm_t threaded_gemm;
inline constexpr omp_taskloop_t omp_taskloop;
inline constexpr omp_forloop_t omp_forloop;
Expand Down
78 changes: 78 additions & 0 deletions include/tlib/detail/ttm.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,46 @@ inline void ttm(



/**
 * ttm overload selected by the tag triple
 * (parallel_policy::sequential_t, slicing_policy::slice_t, fusion_policy::none_t).
 *
 * set_blas_threads_min() is called first. Afterwards:
 *  - if is_case<8>(p,q,pia) does not hold, the call is forwarded to mtm_cm / mtm_rm;
 *  - otherwise a gemm kernel with pre-bound size arguments is passed to
 *    multiple_gemm_with_slices.
 * In both situations the "_cm" / "gemm_col_tr2" variant is used when pib[0] == 1,
 * and the "_rm" / "gemm_row" variant otherwise.
 *
 * NOTE(review): the parameter descriptions are inferred from names and usage - confirm.
 * @param q    used 1-based: indexes na, nc and wa as q-1
 * @param p    length of the pia range searched by inverse_mode (pia .. pia+p)
 * @param a    input tensor data, with na / wa / pia as its (presumably) extents / strides / layout
 * @param b    input matrix data, with nb / pib as its (presumably) extents / layout
 * @param c    output tensor data, with nc / wc as its (presumably) extents / strides
 */
template<class value_t, class size_t>
inline void ttm(
    parallel_policy::sequential_t, slicing_policy::slice_t, fusion_policy::none_t,
    unsigned const q, unsigned const p,
    const value_t *a, size_t const*const na, size_t const*const wa, size_t const*const pia,
    const value_t *b, size_t const*const nb, size_t const*const pib,
    value_t *c, size_t const*const nc, size_t const*const wc
    )
{
    namespace ph = std::placeholders;

    set_blas_threads_min();

    bool const b_is_col_major = (pib[0] == 1);

    // Every configuration except case 8 is handled by a single matrix-times-matrix call.
    if(!is_case<8>(p,q,pia)){
        if(b_is_col_major){
            mtm_cm(q, p, a, na, pia, b, nb, c, nc );
        }
        else{
            mtm_rm(q, p, a, na, pia, b, nb, c, nc );
        }
        return;
    }

    // Case 8: run one gemm per slice.
    auto const qh = tlib::detail::inverse_mode(pia, pia+p, q);

    auto const n_lead = na[pia[0]-1];
    auto const m_q    = nc[q-1];
    auto const n_q    = na[q-1];
    auto const w_q    = wa[q-1];

    if(b_is_col_major){
        // kernel(x,y,z) -> gemm_col_tr2::run(x,y,z, ...sizes...), operands in a,b,c order
        auto kernel = std::bind(tlib::detail::gemm_col_tr2::run<value_t>, ph::_1, ph::_2, ph::_3,
                                n_lead, m_q, n_q,  w_q, m_q, w_q);
        multiple_gemm_with_slices(kernel, p, qh, a,na,wa,pia, b, c,nc,wc);
    }
    else{
        // kernel(x,y,z) -> gemm_row::run(y,x,z, ...sizes...), i.e. the first two operands are swapped (b,a,c)
        auto kernel = std::bind(tlib::detail::gemm_row::run<value_t>, ph::_2, ph::_1, ph::_3,
                                m_q, n_lead, n_q,  n_q, w_q, w_q);
        multiple_gemm_with_slices(kernel, p, qh, a,na,wa,pia, b, c,nc,wc);
    }
}



template<class value_t, class size_t>
inline void ttm(
parallel_policy::threaded_gemm_t, slicing_policy::slice_t, fusion_policy::none_t,
Expand Down Expand Up @@ -466,6 +506,44 @@ inline void ttm(



/**
 * ttm overload selected by the tag triple
 * (parallel_policy::sequential_t, slicing_policy::subtensor_t, fusion_policy::none_t).
 *
 * set_blas_threads_min() is called before any computation. Afterwards:
 *  - if is_case<8>(p,q,pia) does not hold, the call is forwarded to mtm_cm / mtm_rm;
 *  - otherwise a gemm kernel with pre-bound size arguments is passed to
 *    multiple_gemm_with_subtensors.
 * In both situations the "_cm" / "gemm_col_tr2" variant is used when pib[0] == 1,
 * and the "_rm" / "gemm_row" variant otherwise.
 *
 * NOTE(review): the parameter descriptions are inferred from names and usage - confirm.
 * @param q    used 1-based: indexes na and nc as q-1
 * @param p    length of the pia range searched by inverse_mode (pia .. pia+p)
 * @param a    input tensor data, with na / wa / pia as its (presumably) extents / strides / layout
 * @param b    input matrix data, with nb / pib as its (presumably) extents / layout
 * @param c    output tensor data, with nc / wc as its (presumably) extents / strides
 */
template<class value_t, class size_t>
inline void ttm(
    parallel_policy::sequential_t, slicing_policy::subtensor_t, fusion_policy::none_t,
    unsigned const q, unsigned const p,
    const value_t *a, size_t const*const na, size_t const*const wa, size_t const*const pia,
    const value_t *b, size_t const*const nb, size_t const*const pib,
    value_t *c, size_t const*const nc, size_t const*const wc
    )
{
    namespace ph = std::placeholders;

    bool const b_is_col_major = (pib[0] == 1);

    set_blas_threads_min();

    // Every configuration except case 8 is handled by a single matrix-times-matrix call.
    if(!is_case<8>(p,q,pia)){
        if(b_is_col_major){
            mtm_cm(q, p, a, na, pia, b, nb, c, nc );
        }
        else{
            mtm_rm(q, p, a, na, pia, b, nb, c, nc );
        }
        return;
    }

    // Case 8: run one gemm per subtensor.
    auto const qh   = tlib::detail::inverse_mode(pia, pia+p, q);
    auto const n_sub = product(na, pia, 1, qh); // result of product(...) serves as size and leading dimension below
    auto const m_q   = nc[q-1];
    auto const n_q   = na[q-1];

    if(b_is_col_major){
        // kernel(x,y,z) -> gemm_col_tr2::run(x,y,z, ...sizes...), operands in a,b,c order
        auto kernel = std::bind(tlib::detail::gemm_col_tr2::run<value_t>, ph::_1, ph::_2, ph::_3,
                                n_sub, m_q, n_q,  n_sub, m_q, n_sub);
        multiple_gemm_with_subtensors(kernel, p, qh, a,na,wa,pia, b, c,nc,wc);
    }
    else{
        // kernel(x,y,z) -> gemm_row::run(y,x,z, ...sizes...), i.e. the first two operands are swapped (b,a,c)
        auto kernel = std::bind(tlib::detail::gemm_row::run<value_t>, ph::_2, ph::_1, ph::_3,
                                m_q, n_sub, n_q,  n_q, n_sub, n_sub);
        multiple_gemm_with_subtensors(kernel, p, qh, a,na,wa,pia, b, c,nc,wc);
    }
}



template<class value_t, class size_t>
inline void ttm(
parallel_policy::threaded_gemm_t, slicing_policy::subtensor_t, fusion_policy::none_t,
Expand Down
4 changes: 2 additions & 2 deletions test/Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
CXX :=g++
#CXX :=clang++
#CXX :=clang++-12

CXX_FLAGS +=-Wextra -Wall -Wpedantic -O3 -std=c++20 -pthread -fopenmp
CXX_FLAGS +=-Wextra -Wall -Wpedantic -O3 -std=c++2a -pthread -fopenmp

ifeq ($(BLAS_FLAG), OPENBLAS)
include ../openblas.mk
Expand Down
30 changes: 30 additions & 0 deletions test/src/gtest_tlib_ttm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,21 @@ inline void check_tensor_times_matrix(const size_type init, const size_type step
}


// Exercises check_tensor_times_matrix with the sequential / slice / no-fusion policy tags
// for the trailing template arguments 2u, 3u and 4u (presumably the tensor order - confirm).
TEST(TensorTimesMatrix, SequentialSliceNoFusion)
{
    using real_t     = double;
    using index_t    = std::size_t;
    using par_tag    = tlib::parallel_policy::sequential_t;
    using slice_tag  = tlib::slicing_policy::slice_t;
    using fusion_tag = tlib::fusion_policy::none_t;

    // All runs use the same (init, step) arguments: (2u, 3).
    check_tensor_times_matrix<real_t, index_t, par_tag, slice_tag, fusion_tag, 2u>(2u,3);
    check_tensor_times_matrix<real_t, index_t, par_tag, slice_tag, fusion_tag, 3u>(2u,3);
    check_tensor_times_matrix<real_t, index_t, par_tag, slice_tag, fusion_tag, 4u>(2u,3);

    // The 5u variant is currently disabled:
    // check_tensor_times_matrix<real_t, index_t, par_tag, slice_tag, fusion_tag, 5u>(2u,3);
}


TEST(TensorTimesMatrix, ThreadedGemmSliceNoLoopFusion)
{
using value_type = double;
Expand Down Expand Up @@ -333,6 +348,21 @@ TEST(TensorTimesMatrix, OmpForLoopThreadedGemmSliceAllFusion)



// Exercises check_tensor_times_matrix with the sequential / subtensor / no-fusion policy tags
// for the trailing template arguments 2u, 3u and 4u (presumably the tensor order - confirm).
TEST(TensorTimesMatrix, SequentialSubtensorNoFusion)
{
    using real_t     = double;
    using index_t    = std::size_t;
    using par_tag    = tlib::parallel_policy::sequential_t;
    using slice_tag  = tlib::slicing_policy::subtensor_t;
    using fusion_tag = tlib::fusion_policy::none_t;

    // All runs use the same (init, step) arguments: (2u, 3).
    check_tensor_times_matrix<real_t, index_t, par_tag, slice_tag, fusion_tag, 2u>(2u,3);
    check_tensor_times_matrix<real_t, index_t, par_tag, slice_tag, fusion_tag, 3u>(2u,3);
    check_tensor_times_matrix<real_t, index_t, par_tag, slice_tag, fusion_tag, 4u>(2u,3);

    // The 5u variant is currently disabled:
    // check_tensor_times_matrix<real_t, index_t, par_tag, slice_tag, fusion_tag, 5u>(2u,3);
}


TEST(TensorTimesMatrix, OmpForLoopSubtensorOuterFusion)
{
using value_type = double;
Expand Down

0 comments on commit 149f3fc

Please sign in to comment.