From ea63d3d600e68a6c74d0fa76d09111be7991e530 Mon Sep 17 00:00:00 2001 From: Liam Adams Date: Thu, 21 Nov 2024 19:51:29 +0000 Subject: [PATCH] Update add interpolation offload test to exercise passing in updated device memory. --- .../test_interpolation_structured2D_gpu.cc | 40 ++++++++++++++++--- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/src/tests/interpolation/test_interpolation_structured2D_gpu.cc b/src/tests/interpolation/test_interpolation_structured2D_gpu.cc index 1e42dc152..4892a336d 100644 --- a/src/tests/interpolation/test_interpolation_structured2D_gpu.cc +++ b/src/tests/interpolation/test_interpolation_structured2D_gpu.cc @@ -65,7 +65,7 @@ template <> const float AdjointTolerance::value = 2.e-5; -void test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend() { +void test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend(const bool start_with_data_on_device) { Grid input_grid(input_gridname("O32")); Grid output_grid(output_gridname("O64")); @@ -83,7 +83,7 @@ void test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend for (idx_t f = 0; f < 3; ++f) { auto field_source = fields_source.add(input_fs.createField(option::name("field " + std::to_string(f)))); fields_target.add(output_fs.createField(option::name("field " + std::to_string(f)))); - + auto source = array::make_view(field_source); for (idx_t n = 0; n < input_fs.size(); ++n) { for (idx_t k = 0; k < 3; ++k) { @@ -105,6 +105,9 @@ void test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend Config("coordinates", "xy")); gmsh.write(output_mesh); output_fs.haloExchange(fields_target); + for (auto& field : fields_target) { + field.setDeviceNeedsUpdate(true); + } gmsh.write(fields_target); } } @@ -116,7 +119,7 @@ void test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend std::vector xAtAx(fields_source.field_names().size(), 0.); FieldSet fields_source_reference; - for (atlas::Field& field : fields_source) { + for (const atlas::Field& field : fields_source) { Field temp_field(field.name(), field.datatype().kind(), field.shape()); temp_field.set_levels(field.levels()); @@ -131,6 +134,16 @@ void test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend fields_source_reference.add(temp_field); } + + if (start_with_data_on_device) { + for (auto& field : fields_source) { + field.updateDevice(); + } + for (auto& field : fields_target) { + field.updateDevice(); + } + } + interpolation.execute(fields_source, fields_target); for (auto& field : fields_target) { field.updateHost(); @@ -156,6 +169,19 @@ void test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend fIndx += 1; } + if (!start_with_data_on_device) { + for (auto& field : fields_source) { + field.syncHostDevice(); + field.deallocateDevice(); + field.setDeviceNeedsUpdate(true); + } + for (auto& field : fields_target) { + field.syncHostDevice(); + field.deallocateDevice(); + field.setDeviceNeedsUpdate(true); + } + } + interpolation.execute_adjoint(fields_source, fields_target); for (auto& field : fields_source) { field.updateHost(); @@ -184,8 +210,12 @@ void test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend } } -CASE("test_interpolation_structured using fs API for fieldset with hicsparse backend") { - test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend(); +CASE("test_interpolation_structured using fs API for fieldset with hicsparse backend (start with data on host)") { + test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend(false); +} + +CASE("test_interpolation_structured using fs API for fieldset with hicsparse backend (start with data on device)") { + test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend(true); } } // namespace test