diff --git a/src/atlas/CMakeLists.txt b/src/atlas/CMakeLists.txt
index 60b93b9a1..016136191 100644
--- a/src/atlas/CMakeLists.txt
+++ b/src/atlas/CMakeLists.txt
@@ -709,6 +709,8 @@ linalg/sparse/SparseMatrixMultiply.h
 linalg/sparse/SparseMatrixMultiply.tcc
 linalg/sparse/SparseMatrixMultiply_EckitLinalg.h
 linalg/sparse/SparseMatrixMultiply_EckitLinalg.cc
+linalg/sparse/SparseMatrixMultiply_HicSparse.h
+linalg/sparse/SparseMatrixMultiply_HicSparse.cc
 linalg/sparse/SparseMatrixMultiply_OpenMP.h
 linalg/sparse/SparseMatrixMultiply_OpenMP.cc
 linalg/dense.h
@@ -1001,6 +1003,7 @@ ecbuild_add_library( TARGET atlas
     eckit_option
     atlas_io
     hic
+    hicsparse
     $<${atlas_HAVE_EIGEN}:Eigen3::Eigen>
     $<${atlas_HAVE_OMP_CXX}:OpenMP::OpenMP_CXX>
     $<${atlas_HAVE_GRIDTOOLS_STORAGE}:GridTools::gridtools>
diff --git a/src/atlas/linalg/sparse/Backend.cc b/src/atlas/linalg/sparse/Backend.cc
index b9084884a..c5f29e105 100644
--- a/src/atlas/linalg/sparse/Backend.cc
+++ b/src/atlas/linalg/sparse/Backend.cc
@@ -99,6 +99,13 @@ bool Backend::available() const {
     if (t == backend::openmp::type()) {
         return true;
     }
+    if (t == backend::hicsparse::type()) {
+#if ATLAS_HAVE_GPU
+        return true;
+#else
+        return false;
+#endif
+    }
     if (t == backend::eckit_linalg::type()) {
         if (has("backend")) {
 #if ATLAS_ECKIT_HAVE_ECKIT_585
diff --git a/src/atlas/linalg/sparse/Backend.h b/src/atlas/linalg/sparse/Backend.h
index fea22178f..698b327a8 100644
--- a/src/atlas/linalg/sparse/Backend.h
+++ b/src/atlas/linalg/sparse/Backend.h
@@ -43,6 +43,11 @@ struct eckit_linalg : Backend {
     static std::string type() { return "eckit_linalg"; }
     eckit_linalg(): Backend(type()) {}
 };
+
+struct hicsparse : Backend {
+    static std::string type() { return "hicsparse"; }
+    hicsparse(): Backend(type()) {}
+};
 
 }  // namespace backend
diff --git a/src/atlas/linalg/sparse/SparseMatrixMultiply.h b/src/atlas/linalg/sparse/SparseMatrixMultiply.h
index 9e337fe08..19268e39f 100644
--- a/src/atlas/linalg/sparse/SparseMatrixMultiply.h
+++ b/src/atlas/linalg/sparse/SparseMatrixMultiply.h
@@ -115,3 +115,4 @@ struct SparseMatrixMultiply {
 #include "SparseMatrixMultiply.tcc"
 #include "SparseMatrixMultiply_EckitLinalg.h"
 #include "SparseMatrixMultiply_OpenMP.h"
+#include "SparseMatrixMultiply_HicSparse.h"
diff --git a/src/atlas/linalg/sparse/SparseMatrixMultiply.tcc b/src/atlas/linalg/sparse/SparseMatrixMultiply.tcc
index 4b44b984c..e69b5c007 100644
--- a/src/atlas/linalg/sparse/SparseMatrixMultiply.tcc
+++ b/src/atlas/linalg/sparse/SparseMatrixMultiply.tcc
@@ -122,6 +122,9 @@ void sparse_matrix_multiply( const Matrix& matrix, const SourceView& src, Target
 #endif
         sparse::dispatch_sparse_matrix_multiply( matrix, src, tgt, indexing, util::Config("backend",type) );
     }
+    else if ( type == sparse::backend::hicsparse::type() ) {
+        sparse::dispatch_sparse_matrix_multiply( matrix, src, tgt, indexing, config );
+    }
     else {
         throw_NotImplemented( "sparse_matrix_multiply cannot be performed with unsupported backend [" + type + "]",
                               Here() );
@@ -160,6 +163,9 @@ void sparse_matrix_multiply_add( const Matrix& matrix, const SourceView& src, Ta
 #endif
         sparse::dispatch_sparse_matrix_multiply_add( matrix, src, tgt, indexing, util::Config("backend",type) );
     }
+    else if ( type == sparse::backend::hicsparse::type() ) {
+        sparse::dispatch_sparse_matrix_multiply_add( matrix, src, tgt, indexing, config );
+    }
     else {
         throw_NotImplemented( "sparse_matrix_multiply_add cannot be performed with unsupported backend [" + type + "]",
                               Here() );
diff --git a/src/atlas/linalg/sparse/SparseMatrixMultiply_HicSparse.cc b/src/atlas/linalg/sparse/SparseMatrixMultiply_HicSparse.cc
new file mode 100644
index 000000000..46b832b0c
--- /dev/null
+++ b/src/atlas/linalg/sparse/SparseMatrixMultiply_HicSparse.cc
@@ -0,0 +1,324 @@
+/*
+ * (C) Copyright 2024 ECMWF.
+ *
+ * This software is licensed under the terms of the Apache Licence Version 2.0
+ * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+ * In applying this licence, ECMWF does not waive the privileges and immunities
+ * granted to it by virtue of its status as an intergovernmental organisation
+ * nor does it submit to any jurisdiction.
+ */
+
+#include "atlas/linalg/sparse/SparseMatrixMultiply_HicSparse.h"
+
+#include "atlas/parallel/omp/omp.h"
+#include "atlas/runtime/Exception.h"
+
+#include "hic/hic.h"
+#include "hic/hic_library_types.h"
+#include "hic/hicsparse.h"
+
+namespace {
+
+class HicSparseHandleRAIIWrapper {
+public:
+    HicSparseHandleRAIIWrapper() { hicsparseCreate(&handle_); };
+    ~HicSparseHandleRAIIWrapper() { hicsparseDestroy(handle_); }
+    hicsparseHandle_t value() { return handle_; }
+private:
+    hicsparseHandle_t handle_;
+};
+
+hicsparseHandle_t getDefaultHicSparseHandle() {
+    static auto handle = HicSparseHandleRAIIWrapper();
+    return handle.value();
+}
+
+template <typename Index>
+constexpr hicsparseIndexType_t getHicsparseIndexType() {
+    using base_type = std::remove_const_t<Index>;
+    if constexpr (std::is_same_v<base_type, int>) {
+        return HICSPARSE_INDEX_32I;
+    } else {
+        static_assert(std::is_same_v<base_type, long>, "Unsupported index type");
+        return HICSPARSE_INDEX_64I;
+    }
+}
+
+template <typename Value>
+constexpr auto getHicsparseValueType() {
+    using base_type = std::remove_const_t<Value>;
+    if constexpr (std::is_same_v<base_type, float>) {
+        return HIC_R_32F;
+    } else {
+        static_assert(std::is_same_v<base_type, double>, "Unsupported value type");
+        return HIC_R_64F;
+    }
+}
+
+template <atlas::linalg::Indexing IndexLayout, typename Value>
+hicsparseOrder_t getHicsparseOrder(const atlas::linalg::View<Value, 2>& v) {
+    constexpr int row_idx = (IndexLayout == atlas::linalg::Indexing::layout_left) ? 0 : 1;
+    constexpr int col_idx = (IndexLayout == atlas::linalg::Indexing::layout_left) ? 1 : 0;
+
+    if (v.stride(row_idx) == 1) {
+        return HICSPARSE_ORDER_COL;
+    } else if (v.stride(col_idx) == 1) {
+        return HICSPARSE_ORDER_ROW;
+    } else {
+        atlas::throw_Exception("Unsupported dense matrix memory order", Here());
+        return HICSPARSE_ORDER_COL;
+    }
+}
+
+template <typename Value>
+int64_t getLeadingDimension(const atlas::linalg::View<Value, 2>& v) {
+    if (v.stride(0) == 1) {
+        return v.stride(1);
+    } else if (v.stride(1) == 1) {
+        return v.stride(0);
+    } else {
+        atlas::throw_Exception("Unsupported dense matrix memory order", Here());
+        return 0;
+    }
+}
+
+}
+
+namespace atlas {
+namespace linalg {
+namespace sparse {
+
+template <typename SourceValue, typename TargetValue>
+void hsSpMV(const SparseMatrix& W, const View<SourceValue, 1>& src, TargetValue beta, View<TargetValue, 1>& tgt) {
+    // Assume that src and tgt are device views
+
+    ATLAS_ASSERT(src.shape(0) >= W.cols());
+    ATLAS_ASSERT(tgt.shape(0) >= W.rows());
+
+    // Check if W is on the device and if not, copy it to the device
+    if (W.deviceNeedsUpdate()) {
+        W.updateDevice();
+    }
+
+    auto handle = getDefaultHicSparseHandle();
+
+    // Create a sparse matrix descriptor
+    hicsparseConstSpMatDescr_t matA;
+    HICSPARSE_CALL(hicsparseCreateConstCsr(
+        &matA,
+        W.rows(), W.cols(), W.nonZeros(),
+        W.device_outer(),   // row_offsets
+        W.device_inner(),   // column_indices
+        W.device_data(),    // values
+        getHicsparseIndexType<SparseMatrix::Index>(),
+        getHicsparseIndexType<SparseMatrix::Index>(),
+        HICSPARSE_INDEX_BASE_ZERO,
+        getHicsparseValueType<SparseMatrix::Scalar>()));
+
+    // Create dense vector descriptors
+    hicsparseConstDnVecDescr_t vecX;
+    HICSPARSE_CALL(hicsparseCreateConstDnVec(
+        &vecX,
+        static_cast<int64_t>(W.cols()),
+        src.data(),
+        getHicsparseValueType<typename View<SourceValue, 1>::value_type>()));
+
+    hicsparseDnVecDescr_t vecY;
+    HICSPARSE_CALL(hicsparseCreateDnVec(
+        &vecY,
+        W.rows(),
+        tgt.data(),
+        getHicsparseValueType<typename View<TargetValue, 1>::value_type>()));
+
+    using ComputeType = typename View<TargetValue, 1>::value_type;
+    constexpr auto compute_type = getHicsparseValueType<ComputeType>();
+
+    ComputeType alpha = 1;
+
+    // Determine buffer size
+    size_t bufferSize = 0;
+    HICSPARSE_CALL(hicsparseSpMV_bufferSize(
+        handle,
+        HICSPARSE_OPERATION_NON_TRANSPOSE,
+        &alpha,
+        matA,
+        vecX,
+        &beta,
+        vecY,
+        compute_type,
+        HICSPARSE_SPMV_ALG_DEFAULT,
+        &bufferSize));
+
+    // Allocate buffer
+    char* buffer;
+    HIC_CALL(hicMalloc(&buffer, bufferSize));
+
+    // Perform SpMV
+    HICSPARSE_CALL(hicsparseSpMV(
+        handle,
+        HICSPARSE_OPERATION_NON_TRANSPOSE,
+        &alpha,
+        matA,
+        vecX,
+        &beta,
+        vecY,
+        compute_type,
+        HICSPARSE_SPMV_ALG_DEFAULT,
+        buffer));
+
+    HIC_CALL(hicFree(buffer));
+    HICSPARSE_CALL(hicsparseDestroyDnVec(vecX));
+    HICSPARSE_CALL(hicsparseDestroyDnVec(vecY));
+    HICSPARSE_CALL(hicsparseDestroySpMat(matA));
+
+    HIC_CALL(hicDeviceSynchronize());
+}
+
+
+template <Indexing IndexLayout, typename SourceValue, typename TargetValue>
+void hsSpMM(const SparseMatrix& W, const View<SourceValue, 2>& src, TargetValue beta, View<TargetValue, 2>& tgt) {
+    // Assume that src and tgt are device views
+
+    constexpr int row_idx = (IndexLayout == Indexing::layout_left) ? 0 : 1;
+    constexpr int col_idx = (IndexLayout == Indexing::layout_left) ? 1 : 0;
+
+    ATLAS_ASSERT(src.shape(row_idx) >= W.cols());
+    ATLAS_ASSERT(tgt.shape(row_idx) >= W.rows());
+    ATLAS_ASSERT(src.shape(col_idx) == tgt.shape(col_idx));
+
+    // Check if W is on the device and if not, copy it to the device
+    if (W.deviceNeedsUpdate()) {
+        W.updateDevice();
+    }
+
+    auto handle = getDefaultHicSparseHandle();
+
+    // Create a sparse matrix descriptor
+    hicsparseConstSpMatDescr_t matA;
+    HICSPARSE_CALL(hicsparseCreateConstCsr(
+        &matA,
+        W.rows(), W.cols(), W.nonZeros(),
+        W.device_outer(),   // row_offsets
+        W.device_inner(),   // column_indices
+        W.device_data(),    // values
+        getHicsparseIndexType<SparseMatrix::Index>(),
+        getHicsparseIndexType<SparseMatrix::Index>(),
+        HICSPARSE_INDEX_BASE_ZERO,
+        getHicsparseValueType<SparseMatrix::Scalar>()));
+
+    // Create dense matrix descriptors
+    hicsparseConstDnMatDescr_t matB;
+    HICSPARSE_CALL(hicsparseCreateConstDnMat(
+        &matB,
+        W.cols(), src.shape(col_idx),
+        getLeadingDimension(src),
+        src.data(),
+        getHicsparseValueType<typename View<SourceValue, 2>::value_type>(),
+        getHicsparseOrder<IndexLayout>(src)));
+
+    hicsparseDnMatDescr_t matC;
+    HICSPARSE_CALL(hicsparseCreateDnMat(
+        &matC,
+        W.rows(), tgt.shape(col_idx),
+        getLeadingDimension(tgt),
+        tgt.data(),
+        getHicsparseValueType<typename View<TargetValue, 2>::value_type>(),
+        getHicsparseOrder<IndexLayout>(tgt)));
+
+    using ComputeType = typename View<TargetValue, 2>::value_type;
+    constexpr auto compute_type = getHicsparseValueType<ComputeType>();
+
+    ComputeType alpha = 1;
+
+    // Determine buffer size
+    size_t bufferSize = 0;
+    HICSPARSE_CALL(hicsparseSpMM_bufferSize(
+        handle,
+        HICSPARSE_OPERATION_NON_TRANSPOSE,
+        HICSPARSE_OPERATION_NON_TRANSPOSE,
+        &alpha,
+        matA,
+        matB,
+        &beta,
+        matC,
+        compute_type,
+        HICSPARSE_SPMM_ALG_DEFAULT,
+        &bufferSize));
+
+    // Allocate buffer
+    char* buffer;
+    HIC_CALL(hicMalloc(&buffer, bufferSize));
+
+    // Perform SpMM
+    HICSPARSE_CALL(hicsparseSpMM(
+        handle,
+        HICSPARSE_OPERATION_NON_TRANSPOSE,
+        HICSPARSE_OPERATION_NON_TRANSPOSE,
+        &alpha,
+        matA,
+        matB,
+        &beta,
+        matC,
+        compute_type,
+        HICSPARSE_SPMM_ALG_DEFAULT,
+        buffer));
+
+    HIC_CALL(hicFree(buffer));
+    HICSPARSE_CALL(hicsparseDestroyDnMat(matC));
+    HICSPARSE_CALL(hicsparseDestroyDnMat(matB));
+    HICSPARSE_CALL(hicsparseDestroySpMat(matA));
+
+    HIC_CALL(hicDeviceSynchronize());
+}
+
+void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_left, 1, double const, double>::multiply(
+    const SparseMatrix& W, const View<double const, 1>& src, View<double, 1>& tgt, const Configuration&) {
+    double beta = 0;
+    hsSpMV(W, src, beta, tgt);
+}
+
+void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_left, 1, double const, double>::multiply_add(
+    const SparseMatrix& W, const View<double const, 1>& src, View<double, 1>& tgt, const Configuration&) {
+    double beta = 1;
+    hsSpMV(W, src, beta, tgt);
+}
+
+void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_left, 2, double const, double>::multiply(
+    const SparseMatrix& W, const View<double const, 2>& src, View<double, 2>& tgt, const Configuration&) {
+    double beta = 0;
+    hsSpMM<Indexing::layout_left>(W, src, beta, tgt);
+}
+
+void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_left, 2, double const, double>::multiply_add(
+    const SparseMatrix& W, const View<double const, 2>& src, View<double, 2>& tgt, const Configuration&) {
+    double beta = 1;
+    hsSpMM<Indexing::layout_left>(W, src, beta, tgt);
+}
+
+void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_right, 1, double const, double>::multiply(
+    const SparseMatrix& W, const View<double const, 1>& src, View<double, 1>& tgt, const Configuration&) {
+    double beta = 0;
+    hsSpMV(W, src, beta, tgt);
+}
+
+void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_right, 1, double const, double>::multiply_add(
+    const SparseMatrix& W, const View<double const, 1>& src, View<double, 1>& tgt, const Configuration&) {
+    double beta = 1;
+    hsSpMV(W, src, beta, tgt);
+}
+
+void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_right, 2, double const, double>::multiply(
+    const SparseMatrix& W, const View<double const, 2>& src, View<double, 2>& tgt, const Configuration&) {
+    double beta = 0;
+    hsSpMM<Indexing::layout_right>(W, src, beta, tgt);
+}
+
+void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_right, 2, double const, double>::multiply_add(
+    const SparseMatrix& W, const View<double const, 2>& src, View<double, 2>& tgt, const Configuration&) {
+    double beta = 1;
+    hsSpMM<Indexing::layout_right>(W, src, beta, tgt);
+}
+
+}  // namespace sparse
+}  // namespace linalg
+}  // namespace atlas
\ No newline at end of file
diff --git a/src/atlas/linalg/sparse/SparseMatrixMultiply_HicSparse.h b/src/atlas/linalg/sparse/SparseMatrixMultiply_HicSparse.h
new file mode 100644
index 000000000..616a8d617
--- /dev/null
+++ b/src/atlas/linalg/sparse/SparseMatrixMultiply_HicSparse.h
@@ -0,0 +1,53 @@
+/*
+ * (C) Copyright 2024 ECMWF.
+ *
+ * This software is licensed under the terms of the Apache Licence Version 2.0
+ * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+ * In applying this licence, ECMWF does not waive the privileges and immunities
+ * granted to it by virtue of its status as an intergovernmental organisation
+ * nor does it submit to any jurisdiction.
+ */
+
+#pragma once
+
+#include "atlas/linalg/sparse/SparseMatrixMultiply.h"
+
+namespace atlas {
+namespace linalg {
+namespace sparse {
+
+template <>
+struct SparseMatrixMultiply<backend::hicsparse, Indexing::layout_left, 1, double const, double> {
+    static void multiply(const SparseMatrix& W, const View<double const, 1>& src, View<double, 1>& tgt,
+                         const Configuration&);
+    static void multiply_add(const SparseMatrix& W, const View<double const, 1>& src, View<double, 1>& tgt,
+                             const Configuration&);
+};
+
+template <>
+struct SparseMatrixMultiply<backend::hicsparse, Indexing::layout_left, 2, double const, double> {
+    static void multiply(const SparseMatrix& W, const View<double const, 2>& src, View<double, 2>& tgt,
+                         const Configuration&);
+    static void multiply_add(const SparseMatrix& W, const View<double const, 2>& src, View<double, 2>& tgt,
+                             const Configuration&);
+};
+
+template <>
+struct SparseMatrixMultiply<backend::hicsparse, Indexing::layout_right, 1, double const, double> {
+    static void multiply(const SparseMatrix& W, const View<double const, 1>& src, View<double, 1>& tgt,
+                         const Configuration&);
+    static void multiply_add(const SparseMatrix& W, const View<double const, 1>& src, View<double, 1>& tgt,
+                             const Configuration&);
+};
+
+template <>
+struct SparseMatrixMultiply<backend::hicsparse, Indexing::layout_right, 2, double const, double> {
+    static void multiply(const SparseMatrix& W, const View<double const, 2>& src, View<double, 2>& tgt,
+                         const Configuration&);
+    static void multiply_add(const SparseMatrix& W, const View<double const, 2>& src, View<double, 2>& tgt,
+                             const Configuration&);
+};
+
+}  // namespace sparse
+}  // namespace linalg
+}  // namespace atlas
\ No newline at end of file
diff --git a/src/tests/linalg/CMakeLists.txt b/src/tests/linalg/CMakeLists.txt
index d060dcc8a..714cc531c 100644
--- a/src/tests/linalg/CMakeLists.txt
+++ b/src/tests/linalg/CMakeLists.txt
@@ -14,6 +14,14 @@ ecbuild_add_test( TARGET atlas_test_linalg_sparse
   ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT}
 )
 
+ecbuild_add_test( TARGET atlas_test_linalg_sparse_gpu
+  SOURCES test_linalg_sparse_gpu.cc
+  LIBS atlas
+  ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT}
+  CONDITION atlas_HAVE_CUDA OR atlas_HAVE_HIP
+)
+
 ecbuild_add_test( TARGET atlas_test_linalg_dense
   SOURCES test_linalg_dense.cc
   LIBS atlas
diff --git a/src/tests/linalg/test_linalg_sparse.cc b/src/tests/linalg/test_linalg_sparse.cc
index 0e6f8c81f..943583b61 100644
--- a/src/tests/linalg/test_linalg_sparse.cc
+++ b/src/tests/linalg/test_linalg_sparse.cc
@@ -30,6 +30,7 @@ namespace test {
 // strings to be used in the tests
 static std::string eckit_linalg = sparse::backend::eckit_linalg::type();
 static std::string openmp       = sparse::backend::openmp::type();
+static std::string hicsparse    = sparse::backend::hicsparse::type();
 
 //----------------------------------------------------------------------------------------------------------------------
 
@@ -183,15 +184,21 @@ CASE("test backend functionalities") {
     EXPECT_EQ(sparse::current_backend().getString("backend", "undefined"), "undefined");
     EXPECT_EQ(sparse::default_backend(eckit_linalg).getString("backend"), "default");
 
+    sparse::current_backend(hicsparse);
+    EXPECT_EQ(sparse::current_backend().type(), "hicsparse");
+    EXPECT_EQ(sparse::current_backend().getString("backend", "undefined"), "undefined");
+
     sparse::default_backend(eckit_linalg).set("backend", "generic");
     EXPECT_EQ(sparse::default_backend(eckit_linalg).getString("backend"), "generic");
 
     const sparse::Backend backend_default      = sparse::Backend();
     const sparse::Backend backend_openmp       = sparse::backend::openmp();
     const sparse::Backend backend_eckit_linalg = sparse::backend::eckit_linalg();
-    EXPECT_EQ(backend_default.type(), openmp);
+    const sparse::Backend backend_hicsparse    = sparse::backend::hicsparse();
+    EXPECT_EQ(backend_default.type(), hicsparse);
     EXPECT_EQ(backend_openmp.type(), openmp);
     EXPECT_EQ(backend_eckit_linalg.type(), eckit_linalg);
+    EXPECT_EQ(backend_hicsparse.type(), hicsparse);
 
     EXPECT_EQ(std::string(backend_openmp), openmp);
     EXPECT_EQ(std::string(backend_eckit_linalg), eckit_linalg);
diff --git a/src/tests/linalg/test_linalg_sparse_gpu.cc b/src/tests/linalg/test_linalg_sparse_gpu.cc
new file mode 100644
index 000000000..ffd4af28f
--- /dev/null
+++ b/src/tests/linalg/test_linalg_sparse_gpu.cc
@@ -0,0 +1,271 @@
+/*
+ * (C) Copyright 2024- ECMWF.
+ *
+ * This software is licensed under the terms of the Apache Licence Version 2.0
+ * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+ * In applying this licence, ECMWF does not waive the privileges and immunities
+ * granted to it by virtue of its status as an intergovernmental organisation
+ * nor does it submit to any jurisdiction.
+ */
+
+#include <string>
+#include <vector>
+
+#include "eckit/linalg/Matrix.h"
+#include "eckit/linalg/Vector.h"
+
+#include "atlas/array.h"
+#include "atlas/linalg/sparse.h"
+
+#include "tests/AtlasTestEnvironment.h"
+
+
+using namespace atlas::linalg;
+
+namespace atlas {
+namespace test {
+
+//----------------------------------------------------------------------------------------------------------------------
+
+// strings to be used in the tests
+static std::string hicsparse = sparse::backend::hicsparse::type();
+
+//----------------------------------------------------------------------------------------------------------------------
+
+// Only reason to define these derived classes is for nicer constructors and convenience in the tests
+
+class Vector : public eckit::linalg::Vector {
+public:
+    using Scalar = eckit::linalg::Scalar;
+    using eckit::linalg::Vector::Vector;
+    Vector(const std::initializer_list<Scalar>& v): eckit::linalg::Vector::Vector(v.size()) {
+        size_t i = 0;
+        for (auto& s : v) {
+            operator[](i++) = s;
+        }
+    }
+};
+
+class Matrix : public eckit::linalg::Matrix {
+public:
+    using Scalar = eckit::linalg::Scalar;
+    using eckit::linalg::Matrix::Matrix;
+
+    Matrix(const std::initializer_list<std::vector<Scalar>>& m):
+        eckit::linalg::Matrix::Matrix(m.size(), m.size() ? m.begin()->size() : 0) {
+        size_t r = 0;
+        for (auto& row : m) {
+            for (size_t c = 0; c < cols(); ++c) {
+                operator()(r, c) = row[c];
+            }
+            ++r;
+        }
+    }
+};
+
+// 2D array constructable from eckit::linalg::Matrix
+// Indexing/memorylayout and data type can be customized for testing
+template <typename Value, Indexing indexing = Indexing::layout_left>
+struct ArrayMatrix {
+    array::ArrayView<Value, 2> view() {
+        array.syncHostDevice();
+        return array::make_view<Value, 2>(array);
+    }
+    array::ArrayView<Value, 2> device_view() {
+        array.syncHostDevice();
+        return array::make_device_view<Value, 2>(array);
+    }
+    void setHostNeedsUpdate(bool b) {
+        array.setHostNeedsUpdate(b);
+    }
+    ArrayMatrix(const eckit::linalg::Matrix& m): ArrayMatrix(m.rows(), m.cols()) {
+        auto view_ = array::make_view<Value, 2>(array);
+        for (int r = 0; r < m.rows(); ++r) {
+            for (int c = 0; c < m.cols(); ++c) {
+                auto& v = layout_left ? view_(r, c) : view_(c, r);
+                v = m(r, c);
+            }
+        }
+    }
+    ArrayMatrix(int r, int c): array(make_shape(r, c)) {}
+
+private:
+    static constexpr bool layout_left = (indexing == Indexing::layout_left);
+    static array::ArrayShape make_shape(int rows, int cols) {
+        return layout_left ? array::make_shape(rows, cols) : array::make_shape(cols, rows);
+    }
+    array::ArrayT<Value> array;
+};
+
+// 1D array constructable from eckit::linalg::Vector
+template <typename Value>
+struct ArrayVector {
+    array::ArrayView<Value, 1> view() {
+        array.syncHostDevice();
+        return array::make_view<Value, 1>(array);
+    }
+    array::ArrayView<const Value, 1> const_view() {
+        array.syncHostDevice();
+        return array::make_view<const Value, 1>(array);
+    }
+    array::ArrayView<Value, 1> device_view() {
+        array.syncHostDevice();
+        return array::make_device_view<Value, 1>(array);
+    }
+    void setHostNeedsUpdate(bool b) {
+        array.setHostNeedsUpdate(b);
+    }
+    ArrayVector(const eckit::linalg::Vector& v): ArrayVector(v.size()) {
+        auto view_ = array::make_view<Value, 1>(array);
+        for (int n = 0; n < v.size(); ++n) {
+            view_[n] = v[n];
+        }
+    }
+    ArrayVector(int size): array(size) {}
+
+private:
+    array::ArrayT<Value> array;
+};
+
+//----------------------------------------------------------------------------------------------------------------------
+
+template <typename T>
+void expect_equal(T* v, T* r, size_t s) {
+    EXPECT(is_approximately_equal(eckit::testing::make_view(v, s), eckit::testing::make_view(r, s), T(1.e-5)));
+}
+
+template <typename T1, typename T2>
+void expect_equal(const T1& v, const T2& r) {
+    expect_equal(v.data(), r.data(), std::min(v.size(), r.size()));
+}
+
+//----------------------------------------------------------------------------------------------------------------------
+
+CASE("sparse-matrix vector multiply (spmv) [backend=hicsparse]") {
+    // "square" matrix
+    // A =  2  . -3
+    //      .  2  .
+    //      .  .  2
+    // x = 1 2 3
+    // y = 1 2 3
+
+    sparse::current_backend(hicsparse);
+
+    SparseMatrix A{3, 3, {{0, 0, 2.}, {0, 2, -3.}, {1, 1, 2.}, {2, 2, 2.}}};
+
+    SECTION("View of atlas::Array [backend=hicsparse]") {
+        ArrayVector<double> x(Vector{1., 2., 3.});
+        ArrayVector<double> y(3);
+        const auto x_device_view = x.device_view();
+        auto y_device_view = y.device_view();
+        sparse_matrix_multiply(A, x_device_view, y_device_view);
+        y.setHostNeedsUpdate(true);
+        auto y_view = y.view();
+        expect_equal(y_view, Vector{-7., 4., 6.});
+        // sparse_matrix_multiply of sparse matrix and vector of non-matching sizes should fail
+        {
+            ArrayVector<double> x2(2);
+            auto x2_device_view = x2.device_view();
+            EXPECT_THROWS_AS(sparse_matrix_multiply(A, x2_device_view, y_device_view), eckit::AssertionFailed);
+        }
+    }
+
+    SECTION("sparse_matrix_multiply_add [backend=hicsparse]") {
+        ArrayVector<double> x(Vector{1., 2., 3.});
+        ArrayVector<double> y(Vector{4., 5., 6.});
+        auto x_device_view = x.device_view();
+        auto y_device_view = y.device_view();
+        sparse_matrix_multiply_add(A, x_device_view, y_device_view);
+        y.setHostNeedsUpdate(true);
+        auto y_view = y.view();
+        expect_equal(y_view, Vector{-3., 9., 12.});
+        // sparse_matrix_multiply of sparse matrix and vector of non-matching sizes should fail
+        {
+            ArrayVector<double> x2(2);
+            auto x2_device_view = x2.device_view();
+            EXPECT_THROWS_AS(sparse_matrix_multiply_add(A, x2_device_view, y_device_view), eckit::AssertionFailed);
+        }
+    }
+}
+
+CASE("sparse-matrix matrix multiply (spmm) [backend=hicsparse]") {
+    // "square"
+    // A =  2  . -3
+    //      .  2  .
+    //      .  .  2
+    // x = 1 2 3
+    // y = 1 2 3
+    sparse::current_backend(hicsparse);
+
+    SparseMatrix A{3, 3, {{0, 0, 2.}, {0, 2, -3.}, {1, 1, 2.}, {2, 2, 2.}}};
+    Matrix m{{1., 2.}, {3., 4.}, {5., 6.}};
+    Matrix c_exp{{-13., -14.}, {6., 8.}, {10., 12.}};
+
+    SECTION("View of atlas::Array PointsRight [backend=hicsparse]") {
+        ArrayMatrix<double, Indexing::layout_right> ma(m);
+        ArrayMatrix<double, Indexing::layout_right> c(3, 2);
+        auto ma_device_view = ma.device_view();
+        auto c_device_view = c.device_view();
+        sparse_matrix_multiply(A, ma_device_view, c_device_view, Indexing::layout_right);
+        c.setHostNeedsUpdate(true);
+        auto c_view = c.view();
+        expect_equal(c_view, ArrayMatrix<double, Indexing::layout_right>(c_exp).view());
+    }
+
+    SECTION("sparse_matrix_multiply [backend=hicsparse]") {
+        auto backend = sparse::backend::hicsparse();
+        ArrayMatrix<double> ma(m);
+        ArrayMatrix<double> c(3, 2);
+        auto ma_device_view = ma.device_view();
+        auto c_device_view = c.device_view();
+        sparse_matrix_multiply(A, ma_device_view, c_device_view, backend);
+        c.setHostNeedsUpdate(true);
+        auto c_view = c.view();
+        expect_equal(c_view, ArrayMatrix<double>(c_exp).view());
+    }
+
+    SECTION("SparseMatrixMultiply [backend=hicsparse] 1") {
+        auto spmm = SparseMatrixMultiply{sparse::backend::hicsparse()};
+        ArrayMatrix<double> ma(m);
+        ArrayMatrix<double> c(3, 2);
+        auto ma_device_view = ma.device_view();
+        auto c_device_view = c.device_view();
+        spmm(A, ma_device_view, c_device_view);
+        c.setHostNeedsUpdate(true);
+        auto c_view = c.view();
+        expect_equal(c_view, ArrayMatrix<double>(c_exp).view());
+    }
+
+    SECTION("SparseMatrixMultiply [backend=hicsparse] 2") {
+        auto spmm = SparseMatrixMultiply{hicsparse};
+        ArrayMatrix<double> ma(m);
+        ArrayMatrix<double> c(3, 2);
+        auto ma_device_view = ma.device_view();
+        auto c_device_view = c.device_view();
+        spmm(A, ma_device_view, c_device_view);
+        c.setHostNeedsUpdate(true);
+        auto c_view = c.view();
+        expect_equal(c_view, ArrayMatrix<double>(c_exp).view());
+    }
+
+    SECTION("sparse_matrix_multiply_add [backend=hicsparse]") {
+        ArrayMatrix<double> x(m);
+        ArrayMatrix<double> y(m);
+        Matrix y_exp{{-12., -12.}, {9., 12.}, {15., 18.}};
+        auto x_device_view = x.device_view();
+        auto y_device_view = y.device_view();
+        sparse_matrix_multiply_add(A, x_device_view, y_device_view, sparse::backend::hicsparse());
+        y.setHostNeedsUpdate(true);
+        auto y_view = y.view();
+        expect_equal(y_view, ArrayMatrix<double>(y_exp).view());
+    }
+}
+
+//----------------------------------------------------------------------------------------------------------------------
+
+}  // namespace test
+}  // namespace atlas
+
+int main(int argc, char** argv) {
+    return atlas::test::run(argc, argv);
+}
\ No newline at end of file