Skip to content

Commit

Permalink
Tided up macros.
Browse files Browse the repository at this point in the history
  • Loading branch information
odlomax committed Dec 8, 2023
1 parent 636c6d8 commit 3d16b6b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 27 deletions.
36 changes: 14 additions & 22 deletions src/atlas/interpolation/method/sphericalvector/SphericalVector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,24 @@ namespace atlas {
namespace interpolation {
namespace method {

using Complex = SphericalVector::Complex;
namespace {
MethodBuilder<SphericalVector> __builder("spherical-vector");
}

#if ATLAS_HAVE_EIGEN

using Complex = SphericalVector::Complex;

template <typename Value>
using SparseMatrix = SphericalVector::SparseMatrix<Value>;
using RealMatrixMap = Eigen::Map<const SparseMatrix<double>>;
using ComplexTriplets = std::vector<Eigen::Triplet<Complex>>;
using RealTriplets = std::vector<Eigen::Triplet<double>>;
#endif

using EckitMatrix = eckit::linalg::SparseMatrix;

namespace {

MethodBuilder<SphericalVector> __builder("spherical-vector");

#if ATLAS_HAVE_EIGEN
RealMatrixMap makeMatrixMap(const EckitMatrix& baseMatrix) {
return RealMatrixMap(baseMatrix.rows(), baseMatrix.cols(),
baseMatrix.nonZeros(), baseMatrix.outer(),
Expand All @@ -64,7 +65,6 @@ auto getInnerIt(const Matrix& matrix, typename Matrix::Index k) {

template <typename Functor, typename Matrix>
void sparseMatrixForEach(const Functor& functor, const Matrix& matrix) {

using Index = typename Matrix::Index;
atlas_omp_parallel_for (auto k = Index{}; k < matrix.outerSize(); ++k) {
for (auto it = getInnerIt(matrix, k); it; ++it) {
Expand Down Expand Up @@ -115,7 +115,6 @@ void matrixMultiply(const SourceView& sourceView, TargetView& targetView,

sparseMatrixForEach(multiplyColumn, matrices...);
}
#endif

} // namespace

Expand All @@ -134,8 +133,6 @@ void SphericalVector::do_setup(const FunctionSpace& source,
return;
}

#if ATLAS_HAVE_EIGEN

setMatrix(Interpolation(interpolationScheme_, source_, target_));

// Get matrix data.
Expand All @@ -154,7 +151,7 @@ void SphericalVector::do_setup(const FunctionSpace& source,
const auto sourceLonLats = array::make_view<double, 2>(source_.lonlat());
const auto targetLonLats = array::make_view<double, 2>(target_.lonlat());

const auto setComplexWeights = [&](auto i, auto j, const auto& weight) {
const auto setWeights = [&](auto i, auto j, const auto& baseWeight) {
const auto sourceLonLat =
PointLonLat(sourceLonLats(j, 0), sourceLonLats(j, 1));
const auto targetLonLat =
Expand All @@ -165,22 +162,19 @@ void SphericalVector::do_setup(const FunctionSpace& source,
const auto deltaAlpha =
(alpha.first - alpha.second) * util::Constants::degreesToRadians();

const auto idx = std::distance(baseWeights.valuePtr(), &weight);
const auto idx = std::distance(baseWeights.valuePtr(), &baseWeight);

complexTriplets[idx] = {int(i), int(j), std::polar(weight, deltaAlpha)};
realTriplets[idx] = {int(i), int(j), weight};
complexTriplets[idx] = {int(i), int(j), std::polar(baseWeight, deltaAlpha)};
realTriplets[idx] = {int(i), int(j), baseWeight};
};

sparseMatrixForEach(setComplexWeights, baseWeights);
sparseMatrixForEach(setWeights, baseWeights);
complexWeights_->setFromTriplets(complexTriplets.begin(),
complexTriplets.end());
realWeights_->setFromTriplets(realTriplets.begin(), realTriplets.end());

ATLAS_ASSERT(complexWeights_->nonZeros() == matrix().nonZeros());

#else
ATLAS_THROW_EXCEPTION("atlas has been compiled without Eigen");
#endif
ATLAS_ASSERT(realWeights_->nonZeros() == matrix().nonZeros());
}

void SphericalVector::print(std::ostream&) const { ATLAS_NOTIMPLEMENTED; }
Expand Down Expand Up @@ -253,7 +247,6 @@ void SphericalVector::interpolate_vector_field(const Field& sourceField,
auto targetView = array::make_view<Value, Rank>(targetField);
targetView.assign(0.);

#if ATLAS_HAVE_EIGEN
const auto horizontalComponent = [](const auto& sourceVars, auto& targetVars,
const auto& complexWeight) {
const auto sourceVector = Complex(sourceVars(0), sourceVars(1));
Expand All @@ -280,13 +273,12 @@ void SphericalVector::interpolate_vector_field(const Field& sourceField,

return;
}
#else
ATLAS_THROW_EXCEPTION("atlas has been compiled without Eigen");
#endif

ATLAS_NOTIMPLEMENTED;
}

#endif

} // namespace method
} // namespace interpolation
} // namespace atlas
33 changes: 28 additions & 5 deletions src/atlas/interpolation/method/sphericalvector/SphericalVector.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ namespace atlas {
namespace interpolation {
namespace method {

#if ATLAS_HAVE_EIGEN
class SphericalVector : public Method {
public:
using Complex = std::complex<double>;

#if ATLAS_HAVE_EIGEN
template <typename Value>
using SparseMatrix = Eigen::SparseMatrix<Value, Eigen::RowMajor>;
using ComplexMatrix = SparseMatrix<Complex>;
using RealMatrix = SparseMatrix<double>;
#endif


/// @brief Interpolation post-processor for vector field interpolation
///
Expand All @@ -55,7 +55,7 @@ class SphericalVector : public Method {
const auto& conf = dynamic_cast<const eckit::LocalConfiguration&>(config);
interpolationScheme_ = conf.getSubConfiguration("scheme");
}
virtual ~SphericalVector() override {}
~SphericalVector() override {}

void print(std::ostream&) const override;
const FunctionSpace& source() const override { return source_; }
Expand Down Expand Up @@ -93,11 +93,34 @@ class SphericalVector : public Method {
FunctionSpace source_;
FunctionSpace target_;

#if ATLAS_HAVE_EIGEN
std::shared_ptr<ComplexMatrix> complexWeights_;
std::shared_ptr<RealMatrix> realWeights_;
#endif

};
#else
class SphericalVector : public Method {
public:
SphericalVector(const Config& config) : Method(config) {
ATLAS_THROW_EXCEPTION("atlas has been compiled without Eigen");
}

~SphericalVector() override {}

void print(std::ostream&) const override {}
const FunctionSpace& source() const override {ATLAS_NOTIMPLEMENTED;}
const FunctionSpace& target() const override {ATLAS_NOTIMPLEMENTED;}

void do_execute(const FieldSet& sourceFieldSet, FieldSet& targetFieldSet,
Metadata& metadata) const override {}
void do_execute(const Field& sourceField, Field& targetField,
Metadata& metadata) const override {}
private:
void do_setup(const FunctionSpace& source,
const FunctionSpace& target) override {}
void do_setup(const Grid& source, const Grid& target, const Cache&) override {}
};
#endif


} // namespace method
} // namespace interpolation
Expand Down

0 comments on commit 3d16b6b

Please sign in to comment.