Skip to content

Commit

Permalink
Enable on-device halo exchange in interpolation.
Browse files Browse the repository at this point in the history
  • Loading branch information
l90lpa committed Nov 21, 2024
1 parent 4882a3c commit 4ea6544
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 22 deletions.
70 changes: 52 additions & 18 deletions src/atlas/interpolation/method/Method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -441,12 +441,29 @@ void Method::do_execute(const FieldSet& fieldsSource, FieldSet& fieldsTarget, Me
void Method::do_execute(const Field& src, Field& tgt, Metadata&) const {
ATLAS_TRACE("atlas::interpolation::method::Method::do_execute()");

// todo: dispatch to gpu-aware mpi if available
if (src.hostNeedsUpdate()) {
src.updateHost();
if (src.hostNeedsUpdate() && src.deviceNeedsUpdate()) {
throw_AssertionFailed("Inconsistent memory state flags - we will not be able to "
"determine which memory space to perform the halo exchange on",
Here());
}

sparse::Backend backend{linalg_backend_};
const auto memorySpace = getSparseBackendMemorySpace(backend);

bool on_device = false;
if (!src.hostNeedsUpdate() && !src.deviceNeedsUpdate()) {
on_device = memorySpace == MemorySpace::Device;
} else {
on_device = !src.deviceNeedsUpdate() ? true : false;
}

haloExchange(src, on_device);

if (on_device) {
src.setHostNeedsUpdate(true);
} else {
src.setDeviceNeedsUpdate(true);
}
haloExchange(src);
src.setDeviceNeedsUpdate(true);

if( matrix_ ) { // (matrix == nullptr) when a partition is empty
if (src.datatype().kind() == array::DataType::KIND_REAL64) {
Expand Down Expand Up @@ -501,6 +518,15 @@ void Method::do_execute_adjoint(FieldSet& fieldsSource, const FieldSet& fieldsTa
void Method::do_execute_adjoint(Field& src, const Field& tgt, Metadata&) const {
ATLAS_TRACE("atlas::interpolation::method::Method::do_execute_adjoint()");

if (src.hostNeedsUpdate() && src.deviceNeedsUpdate()) {
throw_AssertionFailed("Inconsistent memory state flags - we will not be able to "
"determine which memory space to perform the adjoint halo exchange on",
Here());
}

sparse::Backend backend{linalg_backend_};
const auto memorySpace = getSparseBackendMemorySpace(backend);

if (nonLinear_(src)) {
throw_NotImplemented("Adjoint interpolation only works for interpolation schemes that are linear", Here());
}
Expand All @@ -525,12 +551,20 @@ void Method::do_execute_adjoint(Field& src, const Field& tgt, Metadata&) const {

src.set_dirty();

// todo: dispatch to gpu-aware mpi if available
if (src.hostNeedsUpdate()) {
src.updateHost();
bool on_device = false;
if (!src.hostNeedsUpdate() && !src.deviceNeedsUpdate()) {
on_device = memorySpace == MemorySpace::Device;
} else {
on_device = !src.deviceNeedsUpdate() ? true : false;
}

adjointHaloExchange(src, on_device);

if (on_device) {
src.setHostNeedsUpdate(true);
} else {
src.setDeviceNeedsUpdate(true);
}
adjointHaloExchange(src);
src.setDeviceNeedsUpdate(true);
}


Expand All @@ -549,25 +583,25 @@ void Method::normalise(Triplets& triplets) {
}
}

void Method::haloExchange(const FieldSet& fields) const {
void Method::haloExchange(const FieldSet& fields, bool on_device) const {
for (auto& field : fields) {
haloExchange(field);
haloExchange(field, on_device);
}
}
void Method::haloExchange(const Field& field) const {
void Method::haloExchange(const Field& field, bool on_device) const {
if (field.dirty() && allow_halo_exchange_) {
source().haloExchange(field);
source().haloExchange(field, on_device);
}
}

void Method::adjointHaloExchange(const FieldSet& fields) const {
void Method::adjointHaloExchange(const FieldSet& fields, bool on_device) const {
for (auto& field : fields) {
adjointHaloExchange(field);
adjointHaloExchange(field, on_device);
}
}
void Method::adjointHaloExchange(const Field& field) const {
void Method::adjointHaloExchange(const Field& field, bool on_device) const {
if (field.dirty() && allow_halo_exchange_) {
source().adjointHaloExchange(field);
source().adjointHaloExchange(field, on_device);
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/atlas/interpolation/method/Method.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,11 @@ class Method : public util::Object {

static void normalise(Triplets& triplets);

void haloExchange(const FieldSet&) const;
void haloExchange(const Field&) const;
void haloExchange(const FieldSet&, bool on_device = false) const;
void haloExchange(const Field&, bool on_device = false) const;

void adjointHaloExchange(const FieldSet&) const;
void adjointHaloExchange(const Field&) const;
void adjointHaloExchange(const FieldSet&, bool on_device = false) const;
void adjointHaloExchange(const Field&, bool on_device = false) const;

// NOTE : Matrix-free or non-linear interpolation operators do not have matrices, so do not expose here
friend class atlas::test::Access;
Expand Down
8 changes: 8 additions & 0 deletions src/tests/interpolation/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,14 @@ ecbuild_add_test( TARGET atlas_test_interpolation_structured2D_to_points
ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT}
)

ecbuild_add_test( TARGET atlas_test_interpolation_structured2D_to_points_gpu
SOURCES test_interpolation_structured2D_to_points_gpu.cc
LIBS atlas
MPI 4
CONDITION eckit_HAVE_MPI AND atlas_HAVE_GPU
ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT}
)

ecbuild_add_test( TARGET atlas_test_interpolation_cubedsphere
SOURCES test_interpolation_cubedsphere.cc
LIBS atlas
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* (C) Copyright 2013 ECMWF.
*
* This software is licensed under the terms of the Apache Licence Version 2.0
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
* In applying this licence, ECMWF does not waive the privileges and immunities
* granted to it by virtue of its status as an intergovernmental organisation
* nor does it submit to any jurisdiction.
*/

#include "atlas/array.h"
#include "atlas/field/Field.h"
#include "atlas/field/FieldSet.h"
#include "atlas/functionspace/PointCloud.h"
#include "atlas/functionspace/StructuredColumns.h"
#include "atlas/grid/Grid.h"
#include "atlas/grid/Iterator.h"
#include "atlas/interpolation.h"
#include "atlas/mesh/Mesh.h"
#include "atlas/meshgenerator.h"
#include "atlas/output/Gmsh.h"
#include "atlas/util/CoordinateEnums.h"
#include "atlas/util/function/VortexRollup.h"
#include "atlas/util/PolygonXY.h"

#include "tests/AtlasTestEnvironment.h"

using atlas::functionspace::PointCloud;
using atlas::functionspace::StructuredColumns;
using atlas::util::Config;

namespace atlas {
namespace test {

//-----------------------------------------------------------------------------

std::string input_gridname(const std::string& default_grid) {
return eckit::Resource<std::string>("--input-grid", default_grid);
}

FunctionSpace output_functionspace( const FunctionSpace& input_functionspace, bool expect_fail ) {

std::vector<PointXY> all_points {
{360., 90.},
{360., 0.},
{360., -90.},
{0.,0.},
};

// Only keep points that match the input partitioning.
// Note that with following implementation it could be that some points
// are present in two partitions, but it is not a problem for this test purpose.
std::vector<PointXY> points;
auto polygon = util::PolygonXY{input_functionspace.polygon()};
for (const auto& p : all_points ) {
if (polygon.contains(p)) {
points.emplace_back(p);
}
}
if( expect_fail && mpi::rank() == mpi::size() - 1 ) {
points.emplace_back(720,0.);
}
return PointCloud(points);
}


FieldSet create_source_fields(StructuredColumns& fs, idx_t nb_fields, idx_t nb_levels) {
using Value = double;
FieldSet fields_source;
auto lonlat = array::make_view<double, 2>(fs.xy());
for (idx_t f = 0; f < nb_fields; ++f) {
auto field_source = fields_source.add(fs.createField<Value>());
auto source = array::make_view<Value, 2>(field_source);
for (idx_t n = 0; n < fs.size(); ++n) {
for (idx_t k = 0; k < nb_levels; ++k) {
source(n, k) = util::function::vortex_rollup(lonlat(n, LON), lonlat(n, LAT), 0.5 + double(k) / 2);
}
};
field_source.updateDevice();
}
return fields_source;
}
FieldSet create_target_fields(FunctionSpace& fs, idx_t nb_fields, idx_t nb_levels) {
using Value = double;
FieldSet fields_target;
for (idx_t f = 0; f < nb_fields; ++f) {
auto field_target = fields_target.add(fs.createField<Value>(option::levels(nb_levels)));
field_target.updateDevice();
}
return fields_target;
}

void do_test( std::string type, int input_halo, bool matrix_free, bool expect_failure ) {
idx_t nb_fields = 2;
idx_t nb_levels = 3;

Grid input_grid(input_gridname("O32"));
StructuredColumns input_fs(input_grid, option::levels(nb_levels) |
option::halo(input_halo));

FunctionSpace output_fs = output_functionspace(input_fs, expect_failure);

Interpolation interpolation(option::type(type) |
util::Config("matrix_free",matrix_free) |
util::Config("sparse_matrix_multiply", "hicsparse") |
util::Config("verbose",eckit::Resource<bool>("--verbose",false)),
input_fs, output_fs);

FieldSet fields_source = create_source_fields(input_fs, nb_fields, nb_levels);
FieldSet fields_target = create_target_fields(output_fs, nb_fields, nb_levels);

interpolation.execute(fields_source, fields_target);
}

CASE("test structured-bilinear, halo 2, with matrix") {
EXPECT_NO_THROW( do_test("structured-bilinear",2,false,false) );
}

CASE("test structured-bilinear, halo 2, with matrix, expected failure") {
EXPECT_THROWS_AS( do_test("structured-bilinear",2,false,true), eckit::Exception );
}

CASE("test structured-bilinear, halo 2, without matrix, expected failure") {
EXPECT_THROWS_AS( do_test("structured-bilinear",2,false,true), eckit::Exception );
}

CASE("test structured-bilinear, halo 1, with matrix, expected failure") {
EXPECT_THROWS_AS( do_test("structured-bilinear",1,false,false), eckit::Exception );
}

CASE("test structured-bicubic, halo 3, with matrix") {
EXPECT_NO_THROW( do_test("structured-bicubic",3,false,false) );
}

CASE("test structured-bicubic, halo 2, with matrix") {
EXPECT_THROWS_AS( do_test("structured-bicubic",2,false,false), eckit::Exception );
}


} // namespace test
} // namespace atlas

int main(int argc, char** argv) {
return atlas::test::run(argc, argv);
}

0 comments on commit 4ea6544

Please sign in to comment.