Skip to content

Commit

Permalink
Add some convenience functions/overloads to StateDescriptor and DataC…
Browse files Browse the repository at this point in the history
…ollection
  • Loading branch information
Ben Prather committed Aug 16, 2023
1 parent 86f750e commit 94984ec
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 7 deletions.
47 changes: 46 additions & 1 deletion src/interface/data_collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,55 @@ std::shared_ptr<T> &DataCollection<T>::AddShallow(const std::string &label,
const std::vector<std::string> &flags) {
return Add(label, src, flags, true);
}
template <typename T>
std::shared_ptr<T> &
DataCollection<T>::Add(const std::string &name, const std::string src_name,
const std::vector<std::string> &field_names, const bool shallow) {
return Add(name, containers_[src_name], field_names, false);
}
template <typename T>
std::shared_ptr<T> &DataCollection<T>::Add(const std::string &name,
const std::string src_name,
const std::vector<std::string> &field_names) {
return Add(name, containers_[src_name], field_names, false);
}
template <typename T>
std::shared_ptr<T> &
DataCollection<T>::AddShallow(const std::string &name, const std::string src_name,
const std::vector<std::string> &field_names) {
return Add(name, containers_[src_name], field_names, true);
}
template <typename T>
std::shared_ptr<T> &DataCollection<T>::Add(const std::string &name,
const std::string src_name) {
return Add(name, containers_[src_name], {}, false);
}
template <typename T>
std::shared_ptr<T> &DataCollection<T>::AddShallow(const std::string &name,
const std::string src_name) {
return Add(name, containers_[src_name], {}, true);
}
template <typename T>
std::shared_ptr<T> &
DataCollection<T>::Add(const std::string &name,
const std::vector<std::string> &field_names, const bool shallow) {
return Add(name, containers_["base"], field_names, false);
}
template <typename T>
std::shared_ptr<T> &DataCollection<T>::Add(const std::string &name,
const std::vector<std::string> &field_names) {
return Add(name, containers_["base"], field_names, false);
}
template <typename T>
std::shared_ptr<T> &
DataCollection<T>::AddShallow(const std::string &name,
const std::vector<std::string> &field_names) {
return Add(name, containers_["base"], field_names, true);
}

template <>
std::shared_ptr<MeshData<Real>> &
DataCollection<MeshData<Real>>::GetOrAdd(const std::string &mbd_label,
DataCollection<MeshData<Real>>::GetOrAddByPartition(const std::string &mbd_label,
const int &partition_id) {
const std::string label = mbd_label + "_part-" + std::to_string(partition_id);
auto it = containers_.find(label);
Expand Down
33 changes: 33 additions & 0 deletions src/interface/state_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,4 +530,37 @@ StateDescriptor::GetVariableNames(const Metadata::FlagCollection &flags) {
return GetVariableNames({}, flags, {});
}

// Get the total length of this StateDescriptor's variables when packed
int StateDescriptor::GetPackDimension(const std::vector<std::string> &req_names,
const Metadata::FlagCollection &flags,
const std::vector<int> &sparse_ids) {
std::vector<std::string> names = GetVariableNames(req_names, flags, sparse_ids);
int nvar = 0;
for (auto name : names) {
const auto &meta = metadataMap_[VarID(name)];
int var_len = 0;
if (meta.Shape().size() < 1) {
var_len = 1;
} else {
var_len = meta.Shape()[0];
}
nvar += var_len;
}
return nvar;
}
int StateDescriptor::GetPackDimension(const std::vector<std::string> &req_names,
const std::vector<int> &sparse_ids) {
return GetPackDimension(req_names, Metadata::FlagCollection(), sparse_ids);
}
int StateDescriptor::GetPackDimension(const Metadata::FlagCollection &flags,
const std::vector<int> &sparse_ids) {
return GetPackDimension({}, flags, sparse_ids);
}
int StateDescriptor::GetPackDimension(const std::vector<std::string> &req_names) {
return GetPackDimension(req_names, Metadata::FlagCollection(), {});
}
int StateDescriptor::GetPackDimension(const Metadata::FlagCollection &flags) {
return GetPackDimension({}, flags, {});
}

} // namespace parthenon
10 changes: 10 additions & 0 deletions src/interface/state_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,16 @@ class StateDescriptor {
std::vector<std::string> GetVariableNames(const std::vector<std::string> &req_names);
std::vector<std::string> GetVariableNames(const Metadata::FlagCollection &flags);

int GetPackDimension(const std::vector<std::string> &req_names,
const Metadata::FlagCollection &flags,
const std::vector<int> &sparse_ids);
int GetPackDimension(const std::vector<std::string> &req_names,
const std::vector<int> &sparse_ids);
int GetPackDimension(const Metadata::FlagCollection &flags,
const std::vector<int> &sparse_ids);
int GetPackDimension(const std::vector<std::string> &req_names);
int GetPackDimension(const Metadata::FlagCollection &flags);

std::size_t
RefinementFuncID(const refinement::RefinementFunctions_t &funcs) const noexcept {
return refinementFuncMaps_.funcs_to_ids.at(funcs);
Expand Down
12 changes: 6 additions & 6 deletions src/mesh/mesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ Mesh::Mesh(ParameterInput *pin, ApplicationInput *app_in, Packages_t &packages,
// private members:
num_mesh_threads_(pin->GetOrAddInteger("parthenon/mesh", "num_threads", 1)),
tree(this), use_uniform_meshgen_fn_{true, true, true, true}, lb_flag_(true),
lb_automatic_(),
lb_manual_(), MeshGenerator_{nullptr, UniformMeshGeneratorX1,
UniformMeshGeneratorX2, UniformMeshGeneratorX3},
lb_automatic_(), lb_manual_(),
MeshGenerator_{nullptr, UniformMeshGeneratorX1, UniformMeshGeneratorX2,
UniformMeshGeneratorX3},
MeshBndryFnctn{nullptr, nullptr, nullptr, nullptr, nullptr, nullptr} {
std::stringstream msg;
RegionSize block_size;
Expand Down Expand Up @@ -541,9 +541,9 @@ Mesh::Mesh(ParameterInput *pin, ApplicationInput *app_in, RestartReader &rr,
// private members:
num_mesh_threads_(pin->GetOrAddInteger("parthenon/mesh", "num_threads", 1)),
tree(this), use_uniform_meshgen_fn_{true, true, true, true}, lb_flag_(true),
lb_automatic_(),
lb_manual_(), MeshGenerator_{nullptr, UniformMeshGeneratorX1,
UniformMeshGeneratorX2, UniformMeshGeneratorX3},
lb_automatic_(), lb_manual_(),
MeshGenerator_{nullptr, UniformMeshGeneratorX1, UniformMeshGeneratorX2,
UniformMeshGeneratorX3},
MeshBndryFnctn{nullptr, nullptr, nullptr, nullptr, nullptr, nullptr} {
std::stringstream msg;
RegionSize block_size;
Expand Down

0 comments on commit 94984ec

Please sign in to comment.