From ee5a6b3c1911f514a000e676a22fb398b7c29a7e Mon Sep 17 00:00:00 2001 From: odlomax Date: Wed, 3 Jan 2024 18:05:03 +0000 Subject: [PATCH] Further refactoring. --- src/atlas/CMakeLists.txt | 10 +- .../sphericalvector/ComplexMatrixMultiply.h | 167 ++++++++++-------- .../method/sphericalvector/SparseMatrix.h | 85 +++++++++ .../method/sphericalvector/SphericalVector.cc | 30 ++-- .../method/sphericalvector/SphericalVector.h | 47 +---- .../method/sphericalvector/Types.h | 29 +++ 6 files changed, 230 insertions(+), 138 deletions(-) create mode 100644 src/atlas/interpolation/method/sphericalvector/SparseMatrix.h create mode 100644 src/atlas/interpolation/method/sphericalvector/Types.h diff --git a/src/atlas/CMakeLists.txt b/src/atlas/CMakeLists.txt index 3ad7805b9..634a22dce 100644 --- a/src/atlas/CMakeLists.txt +++ b/src/atlas/CMakeLists.txt @@ -524,7 +524,7 @@ functionspace/detail/PointCloudInterface.cc functionspace/detail/CubedSphereStructure.h functionspace/detail/CubedSphereStructure.cc -# for cubedsphere matching mesh partitioner +# for cubedsphere matching mesh partitioner interpolation/method/cubedsphere/CellFinder.cc interpolation/method/cubedsphere/CellFinder.h interpolation/Vector2D.cc @@ -541,8 +541,8 @@ interpolation/element/Triag3D.cc interpolation/element/Triag3D.h interpolation/method/Intersect.cc interpolation/method/Intersect.h -interpolation/method/Ray.cc # For testing Quad -interpolation/method/Ray.h # For testing Quad +interpolation/method/Ray.cc # For testing Quad +interpolation/method/Ray.h # For testing Quad # for BuildConvexHull3D @@ -635,8 +635,10 @@ interpolation/method/knn/KNearestNeighboursBase.h interpolation/method/knn/NearestNeighbour.cc interpolation/method/knn/NearestNeighbour.h interpolation/method/sphericalvector/ComplexMatrixMultiply.h +interpolation/method/sphericalvector/SparseMatrix.h interpolation/method/sphericalvector/SphericalVector.cc interpolation/method/sphericalvector/SphericalVector.h +interpolation/method/sphericalvector/Types.h interpolation/method/structured/Cubic2D.cc interpolation/method/structured/Cubic2D.h interpolation/method/structured/Cubic3D.cc @@ -869,7 +871,7 @@ if( NOT atlas_HAVE_ATLAS_FUNCTIONSPACE ) unset( atlas_parallel_srcs ) unset( atlas_output_srcs ) unset( atlas_redistribution_srcs ) - unset( atlas_linalg_srcs ) # only depends on array + unset( atlas_linalg_srcs ) # only depends on array endif() if( NOT atlas_HAVE_ATLAS_INTERPOLATION ) diff --git a/src/atlas/interpolation/method/sphericalvector/ComplexMatrixMultiply.h b/src/atlas/interpolation/method/sphericalvector/ComplexMatrixMultiply.h index c3b49a77b..be7a966bf 100644 --- a/src/atlas/interpolation/method/sphericalvector/ComplexMatrixMultiply.h +++ b/src/atlas/interpolation/method/sphericalvector/ComplexMatrixMultiply.h @@ -1,3 +1,10 @@ +/* + * (C) Crown Copyright 2023 Met Office + * + * This software is licensed under the terms of the Apache Licence Version 2.0 + * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + */ + #pragma once #include @@ -8,7 +15,7 @@ #include "atlas/array/ArrayView.h" #include "atlas/array/Range.h" #include "atlas/array/helpers/ArrayForEach.h" -#include "atlas/interpolation/method/sphericalvector/SphericalVector.h" +#include "atlas/interpolation/method/sphericalvector/Types.h" #include "atlas/parallel/omp/omp.h" namespace atlas { @@ -16,9 +23,8 @@ namespace interpolation { namespace method { namespace detail { -using Complex = SphericalVector::Complex; -using ComplexMatPtr = SphericalVector::ComplexMatPtr; -using RealMatPtr = SphericalVector::RealMatPtr; +using ComplexMatPtr = std::shared_ptr; +using RealMatPtr = std::shared_ptr; struct TwoVectorTag {}; struct ThreeVectorTag {}; @@ -29,89 +35,112 @@ template using IsVectorTag = std::enable_if_t || std::is_same_v>; +/// @brief Helper class to perform complex matrix multiplications +/// +/// @details Performs matrix multiplication between fields of 2-vectors or +/// 3-vectors. Fields must have Rank >= 2. Here, the assumption is +/// that Dim = 0 is the horizontal dimension, and Dim = (Rank - 1) is +/// the vector element dimension. template class ComplexMatrixMultiply { public: ComplexMatrixMultiply() = default; + + /// @brief Construct object from sparse matrices. + /// + /// @details complexWeights is a SparseMatrix of weights. realWeights is a + /// SparseMatrix containing the magnitudes of the elements of + /// complexWeights. ComplexMatrixMultiply(const ComplexMatPtr& complexWeights, - const RealMatPtr& realWeights); + const RealMatPtr& realWeights) + : complexWeights_{complexWeights}, realWeights_{realWeights} { + + if constexpr (ATLAS_BUILD_TYPE_DEBUG) { + ATLAS_ASSERT(complexWeights->rows() == realWeights->rows()); + ATLAS_ASSERT(complexWeights->cols() == realWeights->cols()); + ATLAS_ASSERT(complexWeights->nonZeros() == realWeights->nonZeros()); + for (auto i = Size{0}; i < complexWeights->rows() + 1; ++i) { + ATLAS_ASSERT(complexWeights_->outer()[i] == realWeights_->outer()[i]); + } + + for (auto i = Size{0}; i < complexWeights->nonZeros(); ++i) { + ATLAS_ASSERT(complexWeights_->inner()[i] == realWeights_->inner()[i]); + } + } + } + + /// @brief Apply complex matrix vector multiplication. + /// + /// @details Multiply weights by the elements in sourceView to give + /// elements in targetView. If VectorType == TwoVectorTag, + /// complexWeights are applied to the horizontal elements of + /// sourceView. If VectorType == ThreeVectorTag, then realWeights + /// are additionally applied to the vertical elements of sourceView. template > void apply(const array::ArrayView& sourceView, - array::ArrayView& targetView, VectorType) const; + array::ArrayView& targetView, VectorType) const { + + const auto* outerIndices = complexWeights_->outer(); + const auto* innerIndices = complexWeights_->inner(); + const auto* complexWeightValues = complexWeights_->data(); + const auto* realWeightValues = realWeights_->data(); + const auto nRows = complexWeights_->rows(); + + using Index = std::decay_t; + + atlas_omp_parallel_for(auto rowIndex = Index{0}; rowIndex < nRows; + ++rowIndex) { + + auto targetSlice = sliceColumn(targetView, rowIndex); + if constexpr (InitialiseTarget) { targetSlice.assign(0.); } + for (auto dataIndex = outerIndices[rowIndex]; + dataIndex < outerIndices[rowIndex + 1]; ++dataIndex) { + + const auto colIndex = innerIndices[dataIndex]; + const auto sourceSlice = sliceColumn(sourceView, colIndex); + + array::helpers::arrayForEachDim( + slicedColumnDims{}, std::tie(sourceSlice, targetSlice), + [&](auto&& sourceElem, auto&& targetElem) { + const auto targetVector = complexWeightValues[dataIndex] * + Complex(sourceElem(0), sourceElem(1)); + targetElem(0) += targetVector.real(); + targetElem(1) += targetVector.imag(); + + if constexpr (std::is_same_v) { + targetElem(2) += realWeightValues[dataIndex] * sourceElem(2); + } + }); + } + } + } private: template - static auto sliceColumn(View& arrayView, Index index); - - template - using slicedColumnDims = std::make_integer_sequence; - const ComplexMatPtr complexWeights_{}; - const RealMatPtr realWeights_{}; -}; + /// @brief Makes the slice arrayView.slice(index, Range::all()...). + static auto sliceColumn(View& arrayView, Index index) { + constexpr auto Rank = std::decay_t::rank(); + using RangeAll = decltype(array::Range::all()); -template -inline ComplexMatrixMultiply::ComplexMatrixMultiply( - const ComplexMatPtr& complexWeights, const RealMatPtr& realWeights) - : complexWeights_{complexWeights}, realWeights_{realWeights} {} + const auto slicerArgs = std::tuple_cat(std::make_tuple(index), + std::array{}); + const auto slicer = [&](auto&&... args) { + return arrayView.slice(args...); + }; -template -template -void ComplexMatrixMultiply::apply( - const array::ArrayView& sourceView, - array::ArrayView& targetView, VectorType) const { - - const auto* outerIndices = complexWeights_->outerIndexPtr(); - const auto* innerIndices = complexWeights_->innerIndexPtr(); - const auto* complexWeightValues = complexWeights_->valuePtr(); - const auto* realWeightValues = realWeights_->valuePtr(); - const auto nRows = complexWeights_->outerSize(); - - using Index = std::decay_t; - - atlas_omp_parallel_for(auto rowIndex = Index{0}; rowIndex < nRows; - ++rowIndex) { - - auto targetSlice = sliceColumn(targetView, rowIndex); - if constexpr (InitialiseTarget) { targetSlice.assign(0.); } - for (auto dataIndex = outerIndices[rowIndex]; - dataIndex < outerIndices[rowIndex + 1]; ++dataIndex) { - - const auto colIndex = innerIndices[dataIndex]; - const auto sourceSlice = sliceColumn(sourceView, colIndex); - - array::helpers::arrayForEachDim( - slicedColumnDims{}, std::tie(sourceSlice, targetSlice), - [&](auto&& sourceElem, auto&& targetElem) { - const auto targetVector = complexWeightValues[dataIndex] * - Complex(sourceElem(0), sourceElem(1)); - targetElem(0) += targetVector.real(); - targetElem(1) += targetVector.imag(); - - if constexpr (std::is_same_v) { - targetElem(2) += realWeightValues[dataIndex] * sourceElem(2); - } - }); - } + return std::apply(slicer, slicerArgs); } -} -template -template -auto ComplexMatrixMultiply::sliceColumn(View& arrayView, - Index index) { - - constexpr auto Rank = std::decay_t::rank(); - using RangeAll = decltype(array::Range::all()); - - const auto slicerArgs = - std::tuple_cat(std::make_tuple(index), std::array{}); - const auto slicer = [&](auto&&... args) { return arrayView.slice(args...); }; + /// @brief Creates a sequence of Iteration Dims for a sliced column. + template + using slicedColumnDims = std::make_integer_sequence; - return std::apply(slicer, slicerArgs); -} + ComplexMatPtr complexWeights_{}; + RealMatPtr realWeights_{}; +}; } // namespace detail } // namespace method diff --git a/src/atlas/interpolation/method/sphericalvector/SparseMatrix.h b/src/atlas/interpolation/method/sphericalvector/SparseMatrix.h new file mode 100644 index 000000000..7a2ac9d04 --- /dev/null +++ b/src/atlas/interpolation/method/sphericalvector/SparseMatrix.h @@ -0,0 +1,85 @@ +/* + * (C) Crown Copyright 2023 Met Office + * + * This software is licensed under the terms of the Apache Licence Version 2.0 + * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + */ + +#pragma once + +#include + +#include "atlas/library/defines.h" +#if ATLAS_HAVE_EIGEN +#include +#endif + +#include "atlas/runtime/Exception.h" + +namespace atlas { +namespace interpolation { +namespace method { +namespace detail { + +#if ATLAS_HAVE_EIGEN +/// @brief Wrapper class for Eigen sparse matrix +/// +/// @details Adapts the Eigen sparse matrix interface to be more in line with +/// eckit::linalg::SparseMatrix. Also allows preprocessor disabling of +/// class is Eigen library is not present. +template +class SparseMatrix { + using EigenMatrix = Eigen::SparseMatrix; + + public: + using Index = typename EigenMatrix::StorageIndex; + using Size = typename EigenMatrix::Index; + using Triplets = std::vector>; + + SparseMatrix(Index nRows, Index nCols, const Triplets& triplets) + : eigenMatrix_(nRows, nCols) { + eigenMatrix_.setFromTriplets(triplets.begin(), triplets.end()); + } + + Size nonZeros() const { return eigenMatrix_.nonZeros(); } + Size rows() const { return eigenMatrix_.rows(); } + Size cols() const { return eigenMatrix_.cols(); } + const Index* outer() { return eigenMatrix_.outerIndexPtr(); } + const Index* inner() { return eigenMatrix_.innerIndexPtr(); } + const Value* data() { return eigenMatrix_.valuePtr(); } + + private: + EigenMatrix eigenMatrix_{}; +}; +#else + +template +class SparseMatrix { + public: + class Triplet { + public: + template + Triplet(const Args&... args) {} + }; + using Index = int; + using Size = long int; + using Triplets = std::vector; + + template + SparseMatrix(const Args&... args) { + ATLAS_THROW_EXCEPTION("Atlas has been compiled without Eigen"); + } + constexpr Size nonZeros() const { return 0; } + constexpr Size rows() const { return 0; } + constexpr Size cols() const { return 0; } + constexpr const Index* outer() { return nullptr; } + constexpr const Index* inner() { return nullptr; } + constexpr const Value* data() { return nullptr; } + +}; +#endif + +} // namespace detail +} // namespace method +} // namespace interpolation +} // namespace atlas diff --git a/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc b/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc index 43c1278b2..9fc9a2a20 100644 --- a/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc +++ b/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc @@ -13,8 +13,7 @@ #include "atlas/interpolation/Cache.h" #include "atlas/interpolation/Interpolation.h" #include "atlas/interpolation/method/MethodFactory.h" -#include "atlas/interpolation/method/sphericalvector/ComplexMatrixMultiply.h" -#include "atlas/library/defines.h" +#include "atlas/interpolation/method/sphericalvector/Types.h" #include "atlas/parallel/omp/omp.h" #include "atlas/runtime/Exception.h" #include "atlas/runtime/Trace.h" @@ -30,12 +29,8 @@ namespace { MethodBuilder __builder("spherical-vector"); } -#if ATLAS_HAVE_EIGEN - -using Complex = SphericalVector::Complex; -using Real = SphericalVector::Real; -using ComplexTriplets = std::vector>; -using RealTriplets = std::vector>; +using ComplexTriplets = detail::ComplexMatrix::Triplets; +using RealTriplets = detail::RealMatrix::Triplets; void SphericalVector::do_setup(const Grid& source, const Grid& target, const Cache&) { @@ -66,8 +61,6 @@ void SphericalVector::do_setup(const FunctionSpace& source, // Note: need to store copy of weights as Eigen3 sorts compressed rows by j // whereas eckit does not. - auto complexWeights = std::make_shared(nRows, nCols); - auto realWeights = std::make_shared(nRows, nCols); auto complexTriplets = ComplexTriplets(nNonZeros); auto realTriplets = RealTriplets(nNonZeros); @@ -100,14 +93,13 @@ void SphericalVector::do_setup(const FunctionSpace& source, } } - complexWeights->setFromTriplets(complexTriplets.begin(), - complexTriplets.end()); - realWeights->setFromTriplets(realTriplets.begin(), realTriplets.end()); + const auto complexWeights = + std::make_shared(nRows, nCols, complexTriplets); - ATLAS_ASSERT(complexWeights->nonZeros() == matrix().nonZeros()); - ATLAS_ASSERT(realWeights->nonZeros() == matrix().nonZeros()); + const auto realWeights = + std::make_shared(nRows, nCols, realTriplets); - weightsMatMul_= std::make_shared(complexWeights, realWeights); + weightsMatMul_= WeightsMatMul(complexWeights, realWeights); } @@ -201,20 +193,18 @@ void SphericalVector::interpolate_vector_field(const Field& sourceField, auto targetView = array::make_view(targetField); if (sourceField.variables() == 2) { - weightsMatMul_->apply(sourceView, targetView, detail::twoVector); + weightsMatMul_.apply(sourceView, targetView, detail::twoVector); return; } if (sourceField.variables() == 3) { - weightsMatMul_->apply(sourceView, targetView, detail::threeVector); + weightsMatMul_.apply(sourceView, targetView, detail::threeVector); return; } ATLAS_NOTIMPLEMENTED; } -#endif - } // namespace method } // namespace interpolation } // namespace atlas diff --git a/src/atlas/interpolation/method/sphericalvector/SphericalVector.h b/src/atlas/interpolation/method/sphericalvector/SphericalVector.h index 273e7de94..6563d98ec 100644 --- a/src/atlas/interpolation/method/sphericalvector/SphericalVector.h +++ b/src/atlas/interpolation/method/sphericalvector/SphericalVector.h @@ -12,37 +12,17 @@ #include #include -#if ATLAS_HAVE_EIGEN -#include -#endif - #include "atlas/functionspace/FunctionSpace.h" #include "atlas/interpolation/method/Method.h" +#include "atlas/interpolation/method/sphericalvector/ComplexMatrixMultiply.h" namespace atlas { namespace interpolation { namespace method { -namespace detail { -template -class ComplexMatrixMultiply; -} // namespace detail - - -#if ATLAS_HAVE_EIGEN class SphericalVector : public Method { public: - using Real = double; - using Complex = std::complex; - - template - using SparseMatrix = Eigen::SparseMatrix; - using ComplexMatrix = SparseMatrix; - using RealMatrix = SparseMatrix; - using ComplexMatPtr = std::shared_ptr; - using RealMatPtr = std::shared_ptr; - using WeightsMatMul = detail::ComplexMatrixMultiply; /// @brief Interpolation post-processor for vector field interpolation @@ -95,32 +75,9 @@ class SphericalVector : public Method { FunctionSpace source_; FunctionSpace target_; - std::shared_ptr weightsMatMul_{}; + WeightsMatMul weightsMatMul_{}; }; -#else - class SphericalVector : public Method { - public: - SphericalVector(const Config& config) : Method(config) { - ATLAS_THROW_EXCEPTION("atlas has been compiled without Eigen"); - } - - ~SphericalVector() override {} - - void print(std::ostream&) const override {} - const FunctionSpace& source() const override {ATLAS_NOTIMPLEMENTED;} - const FunctionSpace& target() const override {ATLAS_NOTIMPLEMENTED;} - - void do_execute(const FieldSet& sourceFieldSet, FieldSet& targetFieldSet, - Metadata& metadata) const override {} - void do_execute(const Field& sourceField, Field& targetField, - Metadata& metadata) const override {} - private: - void do_setup(const FunctionSpace& source, - const FunctionSpace& target) override {} - void do_setup(const Grid& source, const Grid& target, const Cache&) override {} - }; -#endif } // namespace method diff --git a/src/atlas/interpolation/method/sphericalvector/Types.h b/src/atlas/interpolation/method/sphericalvector/Types.h new file mode 100644 index 000000000..bf91c1afe --- /dev/null +++ b/src/atlas/interpolation/method/sphericalvector/Types.h @@ -0,0 +1,29 @@ +/* + * (C) Crown Copyright 2023 Met Office + * + * This software is licensed under the terms of the Apache Licence Version 2.0 + * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + */ + +#pragma once + +#include + +#include "atlas/interpolation/method/sphericalvector/SparseMatrix.h" + +namespace atlas { +namespace interpolation { +namespace method { +namespace detail { + +using Real = double; +using Complex = std::complex; +using ComplexMatrix = SparseMatrix; +using RealMatrix = SparseMatrix; +using Index = ComplexMatrix::Index; +using Size = ComplexMatrix::Size; + +} // detail +} // method +} // interpolation +} // atlas