Skip to content

Commit

Permalink
Update interpolation to work with hicsparse sparse matrix multiply ba…
Browse files Browse the repository at this point in the history
…ckend.
  • Loading branch information
l90lpa committed Nov 21, 2024
1 parent 9a1ea76 commit 4882a3c
Show file tree
Hide file tree
Showing 4 changed files with 335 additions and 52 deletions.
180 changes: 128 additions & 52 deletions src/atlas/interpolation/method/Method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,78 @@ void set_missing_values(Field& tgt, const std::vector<idx_t>& missing) {
}
}

enum class MemorySpace { Host, Device };

MemorySpace getSparseBackendMemorySpace(const sparse::Backend& backend) {
if (backend.type() == "eckit_linalg") {
return MemorySpace::Host;
} else if (backend.type() == "openmp") {
return MemorySpace::Host;
} else if (backend.type() == "hicsparse") {
return MemorySpace::Device;
} else {
ATLAS_NOTIMPLEMENTED;
}
}

void fetch(const atlas::Field& f, MemorySpace memorySpace) {
if(memorySpace == MemorySpace::Host) {
if (f.hostNeedsUpdate()) {
f.updateHost();
}
}
else {
ATLAS_ASSERT(memorySpace == MemorySpace::Device);
if (f.deviceNeedsUpdate()) {
f.updateDevice();
}
}
}

void setOtherMemorySpaceNeedsUpdate(const atlas::Field& f, MemorySpace memorySpace) {
if(memorySpace == MemorySpace::Host) {
f.setDeviceNeedsUpdate(true);
}
else {
ATLAS_ASSERT(memorySpace == MemorySpace::Device);
f.setHostNeedsUpdate(true);
}
}

template<typename Value, int Rank>
atlas::array::ArrayView<Value, Rank> make_device_view_fetched(atlas::Field& f) {
fetch(f, MemorySpace::Device);
return atlas::array::make_device_view<Value, Rank>(f);
}

template<typename Value, int Rank>
atlas::array::ArrayView<const Value, Rank> make_device_view_fetched(const atlas::Field& f) {
fetch(f, MemorySpace::Device);
return atlas::array::make_device_view<const Value, Rank>(f);
}

template<typename Value, int Rank>
atlas::array::ArrayView<Value, Rank> make_host_view_fetched(atlas::Field& f) {
fetch(f, MemorySpace::Host);
return atlas::array::make_host_view<Value, Rank>(f);
}

template<typename Value, int Rank>
atlas::array::ArrayView<const Value, Rank> make_host_view_fetched(const atlas::Field& f) {
fetch(f, MemorySpace::Host);
return atlas::array::make_host_view<const Value, Rank>(f);
}

} // anonymous namespace


template <typename Value>
void Method::interpolate_field_rank1(const Field& src, Field& tgt, const Matrix& W) const {
auto backend = std::is_same<Value, float>::value ? sparse::backend::openmp() : sparse::Backend{linalg_backend_};
auto src_v = array::make_view<Value, 1>(src);
auto tgt_v = array::make_view<Value, 1>(tgt);
const auto memorySpace = getSparseBackendMemorySpace(backend);

auto src_v = (memorySpace == MemorySpace::Host) ? make_host_view_fetched<Value, 1>(src) : make_device_view_fetched<Value, 1>(src);
auto tgt_v = (memorySpace == MemorySpace::Host) ? make_host_view_fetched<Value, 1>(tgt) : make_device_view_fetched<Value, 1>(tgt);

if (nonLinear_(src)) {
Matrix W_nl(W); // copy (a big penalty -- copy-on-write would definitely be better)
Expand All @@ -112,17 +176,25 @@ void Method::interpolate_field_rank1(const Field& src, Field& tgt, const Matrix&
else {
sparse_matrix_multiply(W, src_v, tgt_v, backend);
}

setOtherMemorySpaceNeedsUpdate(tgt, memorySpace);
}


template <typename Value>
void Method::interpolate_field_rank2(const Field& src, Field& tgt, const Matrix& W) const {
sparse::Backend backend{linalg_backend_};
auto src_v = array::make_view<Value, 2>(src);
auto tgt_v = array::make_view<Value, 2>(tgt);
auto backend = std::is_same<Value, float>::value ? sparse::backend::openmp() : sparse::Backend{linalg_backend_};

// To match previous logic of only using OpenMP (probably because eckit_linalg backend doesn't support "layout_left" indexing)
if (backend.type() == "eckit_linalg") {
backend = sparse::backend::openmp();
}

if (nonLinear_(src)) {
// We cannot apply the same matrix to full columns as e.g. missing values could be present in only certain parts.

auto src_v = array::make_view<Value, 2>(src);
auto tgt_v = array::make_view<Value, 2>(tgt);

// Allocate temporary rank-1 fields corresponding to one horizontal level
auto src_slice = Field("s", array::make_datatype<Value>(), {src.shape(0)});
Expand All @@ -144,89 +216,77 @@ void Method::interpolate_field_rank2(const Field& src, Field& tgt, const Matrix&
interpolate_field_rank1<Value>(src_slice, tgt_slice, W);

// Copy rank-1 field to this level in the rank-2 field
fetch(tgt_slice, MemorySpace::Host);
for (idx_t i = 0; i < tgt.shape(0); ++i) {
tgt_v(i, lev) = tgt_slice_v(i);
}
setOtherMemorySpaceNeedsUpdate(tgt, MemorySpace::Host);
}
}
else {
sparse_matrix_multiply(W, src_v, tgt_v, sparse::backend::openmp());
const auto memorySpace = getSparseBackendMemorySpace(backend);
auto src_dv = (memorySpace == MemorySpace::Host) ? make_host_view_fetched<Value, 2>(src) : make_device_view_fetched<Value, 2>(src);
auto tgt_dv = (memorySpace == MemorySpace::Host) ? make_host_view_fetched<Value, 2>(tgt) : make_device_view_fetched<Value, 2>(tgt);
sparse_matrix_multiply(W, src_dv, tgt_dv, backend);
setOtherMemorySpaceNeedsUpdate(tgt, memorySpace);
}
}


template <typename Value>
void Method::interpolate_field_rank3(const Field& src, Field& tgt, const Matrix& W) const {
sparse::Backend backend{linalg_backend_};
auto src_v = array::make_view<Value, 3>(src);
auto tgt_v = array::make_view<Value, 3>(tgt);
auto src_v = make_host_view_fetched<Value, 3>(src);
auto tgt_v = make_host_view_fetched<Value, 3>(tgt);
if (not W.empty() && nonLinear_(src)) {
ATLAS_ASSERT(false, "nonLinear interpolation not supported for rank-3 fields.");
}
sparse_matrix_multiply(W, src_v, tgt_v, sparse::backend::openmp());
setOtherMemorySpaceNeedsUpdate(tgt, MemorySpace::Host);
}

template <typename Value>
void Method::adjoint_interpolate_field_rank1(Field& src, const Field& tgt, const Matrix& W) const {
array::ArrayT<Value> tmp(src.shape());
auto backend = std::is_same<Value, float>::value ? sparse::backend::openmp() : sparse::Backend{linalg_backend_};
const auto memorySpace = getSparseBackendMemorySpace(backend);

auto tmp_v = array::make_view<Value, 1>(tmp);
auto src_v = array::make_view<Value, 1>(src);
auto tgt_v = array::make_view<Value, 1>(tgt);
auto src_v = (memorySpace == MemorySpace::Host) ? make_host_view_fetched<Value, 1>(src) : make_device_view_fetched<Value, 1>(src);
auto tgt_v = (memorySpace == MemorySpace::Host) ? make_host_view_fetched<Value, 1>(tgt) : make_device_view_fetched<Value, 1>(tgt);

tmp_v.assign(0.);
sparse_matrix_multiply_add(W, tgt_v, src_v, backend);

if (std::is_same<Value, float>::value) {
sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::backend::openmp());
}
else {
sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::Backend{linalg_backend_});
}


for (idx_t t = 0; t < tmp.shape(0); ++t) {
src_v(t) += tmp_v(t);
}
setOtherMemorySpaceNeedsUpdate(src, memorySpace);
}

template <typename Value>
void Method::adjoint_interpolate_field_rank2(Field& src, const Field& tgt, const Matrix& W) const {
array::ArrayT<Value> tmp(src.shape());
auto backend = std::is_same<Value, float>::value ? sparse::backend::openmp() : sparse::Backend{linalg_backend_};

// To match previous logic of only using OpenMP (probably because eckit_linalg backend doesn't support "layout_left" indexing)
if (backend.type() == "eckit_linalg") {
backend = sparse::backend::openmp();
}

auto tmp_v = array::make_view<Value, 2>(tmp);
auto src_v = array::make_view<Value, 2>(src);
auto tgt_v = array::make_view<Value, 2>(tgt);
const auto memorySpace = getSparseBackendMemorySpace(backend);

tmp_v.assign(0.);
auto src_v = (memorySpace == MemorySpace::Host) ? make_host_view_fetched<Value, 2>(src) : make_device_view_fetched<Value, 2>(src);
auto tgt_v = (memorySpace == MemorySpace::Host) ? make_host_view_fetched<Value, 2>(tgt) : make_device_view_fetched<Value, 2>(tgt);

sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::backend::openmp());
sparse_matrix_multiply_add(W, tgt_v, src_v, backend);

for (idx_t t = 0; t < tmp.shape(0); ++t) {
for (idx_t k = 0; k < tmp.shape(1); ++k) {
src_v(t, k) += tmp_v(t, k);
}
}
setOtherMemorySpaceNeedsUpdate(src, memorySpace);
}

template <typename Value>
void Method::adjoint_interpolate_field_rank3(Field& src, const Field& tgt, const Matrix& W) const {
array::ArrayT<Value> tmp(src.shape());

auto tmp_v = array::make_view<Value, 3>(tmp);
auto src_v = array::make_view<Value, 3>(src);
auto tgt_v = array::make_view<Value, 3>(tgt);
sparse::Backend backend{linalg_backend_};

tmp_v.assign(0.);
auto src_v = make_host_view_fetched<Value, 3>(src);
auto tgt_v = make_host_view_fetched<Value, 3>(tgt);

sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::backend::openmp());
sparse_matrix_multiply_add(W, tgt_v, src_v, sparse::backend::openmp());

for (idx_t t = 0; t < tmp.shape(0); ++t) {
for (idx_t j = 0; j < tmp.shape(1); ++j) {
for (idx_t k = 0; k < tmp.shape(2); ++k) {
src_v(t, j, k) += tmp_v(t, j, k);
}
}
}
setOtherMemorySpaceNeedsUpdate(src, MemorySpace::Host);
}

void Method::check_compatibility(const Field& src, const Field& tgt, const Matrix& W) const {
Expand Down Expand Up @@ -381,8 +441,13 @@ void Method::do_execute(const FieldSet& fieldsSource, FieldSet& fieldsTarget, Me
void Method::do_execute(const Field& src, Field& tgt, Metadata&) const {
ATLAS_TRACE("atlas::interpolation::method::Method::do_execute()");

// todo: dispatch to gpu-aware mpi if available
if (src.hostNeedsUpdate()) {
src.updateHost();
}
haloExchange(src);

src.setDeviceNeedsUpdate(true);

if( matrix_ ) { // (matrix == nullptr) when a partition is empty
if (src.datatype().kind() == array::DataType::KIND_REAL64) {
interpolate_field<double>(src, tgt, *matrix_);
Expand Down Expand Up @@ -411,7 +476,13 @@ void Method::do_execute(const Field& src, Field& tgt, Metadata&) const {
}

// set missing values
set_missing_values(tgt, missing_);
if (not missing_.empty()) {
if (tgt.hostNeedsUpdate()) {
tgt.updateHost();
}
set_missing_values(tgt, missing_);
tgt.setDeviceNeedsUpdate(true);
}

tgt.set_dirty();
}
Expand Down Expand Up @@ -454,7 +525,12 @@ void Method::do_execute_adjoint(Field& src, const Field& tgt, Metadata&) const {

src.set_dirty();

// todo: dispatch to gpu-aware mpi if available
if (src.hostNeedsUpdate()) {
src.updateHost();
}
adjointHaloExchange(src);
src.setDeviceNeedsUpdate(true);
}


Expand Down Expand Up @@ -501,4 +577,4 @@ interpolation::Cache Method::createCache() const {


} // namespace interpolation
} // namespace atlas
} // namespace atlas
Original file line number Diff line number Diff line change
Expand Up @@ -1683,7 +1683,11 @@ void ConservativeSphericalPolygonInterpolation::do_execute(const Field& src_fiel
stopwatch.stop();
{
ATLAS_TRACE("halo exchange target");
if (tgt_field.hostNeedsUpdate()) {
tgt_field.updateHost();
}
tgt_field.haloExchange();
tgt_field.setDeviceNeedsUpdate(true);
}

auto remap_stat = remap_stat_;
Expand Down
7 changes: 7 additions & 0 deletions src/tests/interpolation/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ ecbuild_add_test( TARGET atlas_test_interpolation_biquasicubic
ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT}
)

ecbuild_add_test( TARGET atlas_test_interpolation_bicubic_gpu
SOURCES test_interpolation_structured2D_gpu.cc
LIBS atlas
CONDITION atlas_HAVE_CUDA OR atlas_HAVE_HIP
ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT}
)

ecbuild_add_test( TARGET atlas_test_interpolation_structured2D_to_unstructured
SOURCES test_interpolation_structured2D_to_unstructured.cc
LIBS atlas
Expand Down
Loading

0 comments on commit 4882a3c

Please sign in to comment.