Skip to content

Commit

Permalink
Update add interpolation offload test to exercise passing in updated …
Browse files Browse the repository at this point in the history
…device memory.
  • Loading branch information
l90lpa committed Nov 21, 2024
1 parent 41af0ef commit f0391f8
Showing 1 changed file with 35 additions and 5 deletions.
40 changes: 35 additions & 5 deletions src/tests/interpolation/test_interpolation_structured2D_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ template <>
const float AdjointTolerance<float>::value = 2.e-5;


void test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend() {
void test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend(const bool start_with_data_on_device) {
Grid input_grid(input_gridname("O32"));
Grid output_grid(output_gridname("O64"));

Expand All @@ -83,7 +83,7 @@ void test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend
for (idx_t f = 0; f < 3; ++f) {
auto field_source = fields_source.add(input_fs.createField<double>(option::name("field " + std::to_string(f))));
fields_target.add(output_fs.createField<double>(option::name("field " + std::to_string(f))));

auto source = array::make_view<double, 2>(field_source);
for (idx_t n = 0; n < input_fs.size(); ++n) {
for (idx_t k = 0; k < 3; ++k) {
Expand All @@ -105,6 +105,9 @@ void test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend
Config("coordinates", "xy"));
gmsh.write(output_mesh);
output_fs.haloExchange(fields_target);
for (auto& field : fields_target) {
field.setDeviceNeedsUpdate(true);
}
gmsh.write(fields_target);
}
}
Expand All @@ -116,7 +119,7 @@ void test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend
std::vector<double> xAtAx(fields_source.field_names().size(), 0.);

FieldSet fields_source_reference;
for (atlas::Field& field : fields_source) {
for (const atlas::Field& field : fields_source) {
Field temp_field(field.name(), field.datatype().kind(), field.shape());
temp_field.set_levels(field.levels());

Expand All @@ -131,6 +134,16 @@ void test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend
fields_source_reference.add(temp_field);
}


if (start_with_data_on_device) {
for (auto& field : fields_source) {
field.updateDevice();
}
for (auto& field : fields_target) {
field.updateDevice();
}
}

interpolation.execute(fields_source, fields_target);
for (auto& field : fields_target) {
field.updateHost();
Expand All @@ -156,6 +169,19 @@ void test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend
fIndx += 1;
}

if (!start_with_data_on_device) {
for (auto& field : fields_source) {
field.syncHostDevice();
field.deallocateDevice();
field.setDeviceNeedsUpdate(true);
}
for (auto& field : fields_target) {
field.syncHostDevice();
field.deallocateDevice();
field.setDeviceNeedsUpdate(true);
}
}

interpolation.execute_adjoint(fields_source, fields_target);
for (auto& field : fields_source) {
field.updateHost();
Expand Down Expand Up @@ -184,8 +210,12 @@ void test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend
}
}

CASE("test_interpolation_structured using fs API for fieldset with hicsparse backend") {
test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend();
CASE("test_interpolation_structured using fs API for fieldset with hicsparse backend (start with data on host)") {
test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend(false);
}

CASE("test_interpolation_structured using fs API for fieldset with hicsparse backend (start with data on device)") {
test_interpolation_structured_using_fs_API_for_fieldset_w_hicsparse_backend(true);
}

} // namespace test
Expand Down

0 comments on commit f0391f8

Please sign in to comment.