diff --git a/src/atlas/array/helpers/ArrayForEach.h b/src/atlas/array/helpers/ArrayForEach.h index 38d4d6441..26665a1d5 100644 --- a/src/atlas/array/helpers/ArrayForEach.h +++ b/src/atlas/array/helpers/ArrayForEach.h @@ -10,6 +10,7 @@ #include #include #include +#include #include "atlas/array/ArrayView.h" #include "atlas/array/Range.h" @@ -370,6 +371,15 @@ struct ArrayForEach { }; +/// brief Construct ArrayForEach and call apply +/// +/// details Construct an ArrayForEach using std::integer_sequence +/// . Remaining arguments are forwarded to apply method. +template +void arrayForEachDim(std::integer_sequence, Args&&... args) { + ArrayForEach::apply(std::forward(args)...); +} + } // namespace helpers } // namespace array } // namespace atlas diff --git a/src/tests/array/test_array_foreach.cc b/src/tests/array/test_array_foreach.cc index f8c5b0b4a..05849d747 100644 --- a/src/tests/array/test_array_foreach.cc +++ b/src/tests/array/test_array_foreach.cc @@ -7,6 +7,7 @@ #include #include +#include #include "atlas/array.h" #include "atlas/array/MakeView.h" @@ -207,6 +208,61 @@ CASE("test_array_foreach_3_views") { EXPECT_EQ(count, 60); } +CASE("test_array_foreach_integer_sequence") { + + const auto arr1 = ArrayT(2, 3); + const auto view1 = make_view(arr1); + + const auto arr2 = ArrayT(2, 3, 4); + const auto view2 = make_view(arr2); + + const auto arr3 = ArrayT(2, 3, 4, 5); + const auto view3 = make_view(arr3); + + const auto zero = std::integer_sequence{}; + const auto one = std::integer_sequence{}; + const auto zeroOneTwoThree = std::make_integer_sequence{}; + + + // Test slice shapes. + + const auto loopFunctorDim0 = [](auto&& slice1, auto&& slice2, auto&& slice3) { + EXPECT_EQ(slice1.rank(), 1); + EXPECT_EQ(slice1.shape(0), 3); + + EXPECT_EQ(slice2.rank(), 2); + EXPECT_EQ(slice2.shape(0), 3); + EXPECT_EQ(slice2.shape(1), 4); + + EXPECT_EQ(slice3.rank(), 3); + EXPECT_EQ(slice3.shape(0), 3); + EXPECT_EQ(slice3.shape(1), 4); + EXPECT_EQ(slice3.shape(2), 5); + }; + arrayForEachDim(zero, std::tie(view1, view2, view3), loopFunctorDim0); + + const auto loopFunctorDim1 = [](auto&& slice1, auto&& slice2, auto&& slice3) { + EXPECT_EQ(slice1.rank(), 1); + EXPECT_EQ(slice1.shape(0), 2); + + EXPECT_EQ(slice2.rank(), 2); + EXPECT_EQ(slice2.shape(0), 2); + EXPECT_EQ(slice2.shape(1), 4); + + EXPECT_EQ(slice3.rank(), 3); + EXPECT_EQ(slice3.shape(0), 2); + EXPECT_EQ(slice3.shape(1), 4); + EXPECT_EQ(slice3.shape(2), 5); + }; + arrayForEachDim(one, std::tie(view1, view2, view3), loopFunctorDim1); + + // Test that slice resolves to double. + + const auto loopFunctorDimAll = [](auto&& slice3) { + static_assert(std::is_convertible_v); + }; + arrayForEachDim(zeroOneTwoThree, std::tie(view3), loopFunctorDimAll); +} CASE("test_array_foreach_forwarding") {