Skip to content

Commit

Permalink
Added arrayForEachDim wrapper function.
Browse files Browse the repository at this point in the history
  • Loading branch information
odlomax committed Dec 21, 2023
1 parent edd96a5 commit 0958cfb
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/atlas/array/helpers/ArrayForEach.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <tuple>
#include <type_traits>
#include <string_view>
#include <utility>

#include "atlas/array/ArrayView.h"
#include "atlas/array/Range.h"
Expand Down Expand Up @@ -370,6 +371,15 @@ struct ArrayForEach {

};

/// brief Construct ArrayForEach and call apply
///
/// details Construct an ArrayForEach<ItrDims...> using std::integer_sequence
/// <int, ItrDims...>. Remaining arguments are forwarded to apply method.
template <int... ItrDims, typename... Args>
void arrayForEachDim(std::integer_sequence<int, ItrDims...>, Args&&... args) {
ArrayForEach<ItrDims...>::apply(std::forward<Args>(args)...);
}

} // namespace helpers
} // namespace array
} // namespace atlas
56 changes: 56 additions & 0 deletions src/tests/array/test_array_foreach.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <chrono>
#include <type_traits>
#include <utility>

#include "atlas/array.h"
#include "atlas/array/MakeView.h"
Expand Down Expand Up @@ -207,6 +208,61 @@ CASE("test_array_foreach_3_views") {
EXPECT_EQ(count, 60);
}

CASE("test_array_foreach_integer_sequence") {

const auto arr1 = ArrayT<double>(2, 3);
const auto view1 = make_view<double, 2>(arr1);

const auto arr2 = ArrayT<double>(2, 3, 4);
const auto view2 = make_view<double, 3>(arr2);

const auto arr3 = ArrayT<double>(2, 3, 4, 5);
const auto view3 = make_view<double, 4>(arr3);

const auto zero = std::integer_sequence<int, 0>{};
const auto one = std::integer_sequence<int, 1>{};
const auto zeroOneTwoThree = std::make_integer_sequence<int, 4>{};


// Test slice shapes.

const auto loopFunctorDim0 = [](auto&& slice1, auto&& slice2, auto&& slice3) {
EXPECT_EQ(slice1.rank(), 1);
EXPECT_EQ(slice1.shape(0), 3);

EXPECT_EQ(slice2.rank(), 2);
EXPECT_EQ(slice2.shape(0), 3);
EXPECT_EQ(slice2.shape(1), 4);

EXPECT_EQ(slice3.rank(), 3);
EXPECT_EQ(slice3.shape(0), 3);
EXPECT_EQ(slice3.shape(1), 4);
EXPECT_EQ(slice3.shape(2), 5);
};
arrayForEachDim(zero, std::tie(view1, view2, view3), loopFunctorDim0);

const auto loopFunctorDim1 = [](auto&& slice1, auto&& slice2, auto&& slice3) {
EXPECT_EQ(slice1.rank(), 1);
EXPECT_EQ(slice1.shape(0), 2);

EXPECT_EQ(slice2.rank(), 2);
EXPECT_EQ(slice2.shape(0), 2);
EXPECT_EQ(slice2.shape(1), 4);

EXPECT_EQ(slice3.rank(), 3);
EXPECT_EQ(slice3.shape(0), 2);
EXPECT_EQ(slice3.shape(1), 4);
EXPECT_EQ(slice3.shape(2), 5);
};
arrayForEachDim(one, std::tie(view1, view2, view3), loopFunctorDim1);

// Test that slice resolves to double.

const auto loopFunctorDimAll = [](auto&& slice3) {
static_assert(std::is_convertible_v<decltype(slice3), const double&>);
};
arrayForEachDim(zeroOneTwoThree, std::tie(view3), loopFunctorDimAll);
}

CASE("test_array_foreach_forwarding") {

Expand Down

0 comments on commit 0958cfb

Please sign in to comment.