Skip to content

Commit

Permalink
Further refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
odlomax committed Jan 3, 2024
1 parent db8ea97 commit ee5a6b3
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 138 deletions.
10 changes: 6 additions & 4 deletions src/atlas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ functionspace/detail/PointCloudInterface.cc
functionspace/detail/CubedSphereStructure.h
functionspace/detail/CubedSphereStructure.cc

# for cubedsphere matching mesh partitioner
# for cubedsphere matching mesh partitioner
interpolation/method/cubedsphere/CellFinder.cc
interpolation/method/cubedsphere/CellFinder.h
interpolation/Vector2D.cc
Expand All @@ -541,8 +541,8 @@ interpolation/element/Triag3D.cc
interpolation/element/Triag3D.h
interpolation/method/Intersect.cc
interpolation/method/Intersect.h
interpolation/method/Ray.cc # For testing Quad
interpolation/method/Ray.h # For testing Quad
interpolation/method/Ray.cc # For testing Quad
interpolation/method/Ray.h # For testing Quad

# for BuildConvexHull3D

Expand Down Expand Up @@ -635,8 +635,10 @@ interpolation/method/knn/KNearestNeighboursBase.h
interpolation/method/knn/NearestNeighbour.cc
interpolation/method/knn/NearestNeighbour.h
interpolation/method/sphericalvector/ComplexMatrixMultiply.h
interpolation/method/sphericalvector/SparseMatrix.h
interpolation/method/sphericalvector/SphericalVector.cc
interpolation/method/sphericalvector/SphericalVector.h
interpolation/method/sphericalvector/Types.h
interpolation/method/structured/Cubic2D.cc
interpolation/method/structured/Cubic2D.h
interpolation/method/structured/Cubic3D.cc
Expand Down Expand Up @@ -869,7 +871,7 @@ if( NOT atlas_HAVE_ATLAS_FUNCTIONSPACE )
unset( atlas_parallel_srcs )
unset( atlas_output_srcs )
unset( atlas_redistribution_srcs )
unset( atlas_linalg_srcs ) # only depends on array
unset( atlas_linalg_srcs ) # only depends on array
endif()

if( NOT atlas_HAVE_ATLAS_INTERPOLATION )
Expand Down
167 changes: 98 additions & 69 deletions src/atlas/interpolation/method/sphericalvector/ComplexMatrixMultiply.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
/*
* (C) Crown Copyright 2023 Met Office
*
* This software is licensed under the terms of the Apache Licence Version 2.0
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
*/

#pragma once

#include <array>
Expand All @@ -8,17 +15,16 @@
#include "atlas/array/ArrayView.h"
#include "atlas/array/Range.h"
#include "atlas/array/helpers/ArrayForEach.h"
#include "atlas/interpolation/method/sphericalvector/SphericalVector.h"
#include "atlas/interpolation/method/sphericalvector/Types.h"
#include "atlas/parallel/omp/omp.h"

namespace atlas {
namespace interpolation {
namespace method {
namespace detail {

using Complex = SphericalVector::Complex;
using ComplexMatPtr = SphericalVector::ComplexMatPtr;
using RealMatPtr = SphericalVector::RealMatPtr;
using ComplexMatPtr = std::shared_ptr<ComplexMatrix>;
using RealMatPtr = std::shared_ptr<RealMatrix>;

struct TwoVectorTag {};
struct ThreeVectorTag {};
Expand All @@ -29,89 +35,112 @@ template <typename VectorTag>
using IsVectorTag = std::enable_if_t<std::is_same_v<VectorTag, TwoVectorTag> ||
std::is_same_v<VectorTag, ThreeVectorTag>>;

/// @brief Helper class to perform complex matrix multiplications
///
/// @details Performs matrix multiplication between fields of 2-vectors or
/// 3-vectors. Fields must have Rank >= 2. Here, the assumption is
/// that Dim = 0 is the horizontal dimension, and Dim = (Rank - 1) is
/// the vector element dimension.
template <bool InitialiseTarget>
class ComplexMatrixMultiply {
public:
ComplexMatrixMultiply() = default;

/// @brief Construct object from sparse matrices.
///
/// @details complexWeights is a SparseMatrix of weights. realWeights is a
/// SparseMatrix containing the magnitudes of the elements of
/// complexWeights.
ComplexMatrixMultiply(const ComplexMatPtr& complexWeights,
const RealMatPtr& realWeights);
const RealMatPtr& realWeights)
: complexWeights_{complexWeights}, realWeights_{realWeights} {

if constexpr (ATLAS_BUILD_TYPE_DEBUG) {
ATLAS_ASSERT(complexWeights->rows() == realWeights->rows());
ATLAS_ASSERT(complexWeights->cols() == realWeights->cols());
ATLAS_ASSERT(complexWeights->nonZeros() == realWeights->nonZeros());

for (auto i = Size{0}; i < complexWeights->rows() + 1; ++i) {
ATLAS_ASSERT(complexWeights_->outer()[i] == realWeights_->outer()[i]);
}

for (auto i = Size{0}; i < complexWeights->nonZeros(); ++i) {
ATLAS_ASSERT(complexWeights_->inner()[i] == realWeights_->inner()[i]);
}
}
}

/// @brief Apply complex matrix vector multiplication.
///
/// @details Multiply weights by the elements in sourceView to give
/// elements in targetView. If VectorType == TwoVectorTag,
/// complexWeights are applied to the horizontal elements of
/// sourceView. If VectorType == ThreeVectorTag, then realWeights
/// are additionally applied to the vertical elements of sourceView.
template <typename Value, int Rank, typename VectorType,
typename = IsVectorTag<VectorType>>
void apply(const array::ArrayView<const Value, Rank>& sourceView,
array::ArrayView<Value, Rank>& targetView, VectorType) const;
array::ArrayView<Value, Rank>& targetView, VectorType) const {

const auto* outerIndices = complexWeights_->outer();
const auto* innerIndices = complexWeights_->inner();
const auto* complexWeightValues = complexWeights_->data();
const auto* realWeightValues = realWeights_->data();
const auto nRows = complexWeights_->rows();

using Index = std::decay_t<decltype(*innerIndices)>;

atlas_omp_parallel_for(auto rowIndex = Index{0}; rowIndex < nRows;
++rowIndex) {

auto targetSlice = sliceColumn(targetView, rowIndex);
if constexpr (InitialiseTarget) { targetSlice.assign(0.); }
for (auto dataIndex = outerIndices[rowIndex];
dataIndex < outerIndices[rowIndex + 1]; ++dataIndex) {

const auto colIndex = innerIndices[dataIndex];
const auto sourceSlice = sliceColumn(sourceView, colIndex);

array::helpers::arrayForEachDim(
slicedColumnDims<Rank>{}, std::tie(sourceSlice, targetSlice),
[&](auto&& sourceElem, auto&& targetElem) {
const auto targetVector = complexWeightValues[dataIndex] *
Complex(sourceElem(0), sourceElem(1));
targetElem(0) += targetVector.real();
targetElem(1) += targetVector.imag();

if constexpr (std::is_same_v<VectorType, ThreeVectorTag>) {
targetElem(2) += realWeightValues[dataIndex] * sourceElem(2);
}
});
}
}
}

private:
template <typename View, typename Index>
static auto sliceColumn(View& arrayView, Index index);

template <int Rank>
using slicedColumnDims = std::make_integer_sequence<int, Rank - 2>;

const ComplexMatPtr complexWeights_{};
const RealMatPtr realWeights_{};
};
/// @brief Makes the slice arrayView.slice(index, Range::all()...).
static auto sliceColumn(View& arrayView, Index index) {
constexpr auto Rank = std::decay_t<View>::rank();
using RangeAll = decltype(array::Range::all());

template <bool InitialiseTarget>
inline ComplexMatrixMultiply<InitialiseTarget>::ComplexMatrixMultiply(
const ComplexMatPtr& complexWeights, const RealMatPtr& realWeights)
: complexWeights_{complexWeights}, realWeights_{realWeights} {}
const auto slicerArgs = std::tuple_cat(std::make_tuple(index),
std::array<RangeAll, Rank - 1>{});
const auto slicer = [&](auto&&... args) {
return arrayView.slice(args...);
};

template <bool InitialiseTarget>
template <typename Value, int Rank, typename VectorType, typename>
void ComplexMatrixMultiply<InitialiseTarget>::apply(
const array::ArrayView<const Value, Rank>& sourceView,
array::ArrayView<Value, Rank>& targetView, VectorType) const {

const auto* outerIndices = complexWeights_->outerIndexPtr();
const auto* innerIndices = complexWeights_->innerIndexPtr();
const auto* complexWeightValues = complexWeights_->valuePtr();
const auto* realWeightValues = realWeights_->valuePtr();
const auto nRows = complexWeights_->outerSize();

using Index = std::decay_t<decltype(*innerIndices)>;

atlas_omp_parallel_for(auto rowIndex = Index{0}; rowIndex < nRows;
++rowIndex) {

auto targetSlice = sliceColumn(targetView, rowIndex);
if constexpr (InitialiseTarget) { targetSlice.assign(0.); }
for (auto dataIndex = outerIndices[rowIndex];
dataIndex < outerIndices[rowIndex + 1]; ++dataIndex) {

const auto colIndex = innerIndices[dataIndex];
const auto sourceSlice = sliceColumn(sourceView, colIndex);

array::helpers::arrayForEachDim(
slicedColumnDims<Rank>{}, std::tie(sourceSlice, targetSlice),
[&](auto&& sourceElem, auto&& targetElem) {
const auto targetVector = complexWeightValues[dataIndex] *
Complex(sourceElem(0), sourceElem(1));
targetElem(0) += targetVector.real();
targetElem(1) += targetVector.imag();

if constexpr (std::is_same_v<VectorType, ThreeVectorTag>) {
targetElem(2) += realWeightValues[dataIndex] * sourceElem(2);
}
});
}
return std::apply(slicer, slicerArgs);
}
}

template <bool InitialiseTarget>
template <typename View, typename Index>
auto ComplexMatrixMultiply<InitialiseTarget>::sliceColumn(View& arrayView,
Index index) {

constexpr auto Rank = std::decay_t<View>::rank();
using RangeAll = decltype(array::Range::all());

const auto slicerArgs =
std::tuple_cat(std::make_tuple(index), std::array<RangeAll, Rank - 1>{});
const auto slicer = [&](auto&&... args) { return arrayView.slice(args...); };
/// @brief Creates a sequence of Iteration Dims for a sliced column.
template <int Rank>
using slicedColumnDims = std::make_integer_sequence<int, Rank - 2>;

return std::apply(slicer, slicerArgs);
}
ComplexMatPtr complexWeights_{};
RealMatPtr realWeights_{};
};

} // namespace detail
} // namespace method
Expand Down
85 changes: 85 additions & 0 deletions src/atlas/interpolation/method/sphericalvector/SparseMatrix.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* (C) Crown Copyright 2023 Met Office
*
* This software is licensed under the terms of the Apache Licence Version 2.0
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
*/

#pragma once

#include <vector>

#include "atlas/library/defines.h"
#if ATLAS_HAVE_EIGEN
#include <Eigen/Sparse>
#endif

#include "atlas/runtime/Exception.h"

namespace atlas {
namespace interpolation {
namespace method {
namespace detail {

#if ATLAS_HAVE_EIGEN
/// @brief Wrapper class for Eigen sparse matrix
///
/// @details Adapts the Eigen sparse matrix interface to be more in line with
/// eckit::linalg::SparseMatrix. Also allows preprocessor disabling of
/// class is Eigen library is not present.
template <typename Value>
class SparseMatrix {
using EigenMatrix = Eigen::SparseMatrix<Value, Eigen::RowMajor>;

public:
using Index = typename EigenMatrix::StorageIndex;
using Size = typename EigenMatrix::Index;
using Triplets = std::vector<Eigen::Triplet<Value>>;

SparseMatrix(Index nRows, Index nCols, const Triplets& triplets)
: eigenMatrix_(nRows, nCols) {
eigenMatrix_.setFromTriplets(triplets.begin(), triplets.end());
}

Size nonZeros() const { return eigenMatrix_.nonZeros(); }
Size rows() const { return eigenMatrix_.rows(); }
Size cols() const { return eigenMatrix_.cols(); }
const Index* outer() { return eigenMatrix_.outerIndexPtr(); }
const Index* inner() { return eigenMatrix_.innerIndexPtr(); }
const Value* data() { return eigenMatrix_.valuePtr(); }

private:
EigenMatrix eigenMatrix_{};
};
#else

template <typename Value>
class SparseMatrix {
public:
class Triplet {
public:
template<typename... Args>
Triplet(const Args&... args) {}
};
using Index = int;
using Size = long int;
using Triplets = std::vector<Triplet>;

template<typename... Args>
SparseMatrix(const Args&... args) {
ATLAS_THROW_EXCEPTION("Atlas has been compiled without Eigen");
}
constexpr Size nonZeros() const { return 0; }
constexpr Size rows() const { return 0; }
constexpr Size cols() const { return 0; }
constexpr const Index* outer() { return nullptr; }
constexpr const Index* inner() { return nullptr; }
constexpr const Value* data() { return nullptr; }

};
#endif

} // namespace detail
} // namespace method
} // namespace interpolation
} // namespace atlas
Loading

0 comments on commit ee5a6b3

Please sign in to comment.