Skip to content

Commit

Permalink
- Update Docstrings following previous changes related to local addit…
Browse files Browse the repository at this point in the history
…ional code file support

- Rename "source" (DataSource) to "location" (str) in __init__ for AdditionalCode and InputDataset.
  The method uses this arg to create the "source" attribute using source=DataSource(location).
  This avoids requiring the user to create and supply a DataSource object which itself is just
  instantiated from a location string.
- Rename "local_path" to "local_dir" in AdditionalCode.get to avoid confusion with AdditionalCode.local_path
- Add an attribute AdditionalCode.modified_namelists to keep track of edited namelists based on templates
- Add logic in AdditionalCode.get() that makes copies of any namelist files with the suffix _TEMPLATE
  to a file without this suffix and adds it to AdditionalCode.modified_namelists, e.g.:
  AdditionalCode.namelists = ["roms.in_TEMPLATE",] -> AdditionalCode.modified_namelists = ["roms.in",]
- Update ROMSComponent.pre_run() and ROMSComponent.run() to reflect this namelist handling
- Minor changes to __str__ in component.py (addressing inconsistencies)
- Update cstar_example_notebook.ipynb to reflect all of the above changes
  • Loading branch information
“Dafydd committed Aug 29, 2024
1 parent 7d14f7f commit 64b1541
Show file tree
Hide file tree
Showing 6 changed files with 287 additions and 279 deletions.
116 changes: 72 additions & 44 deletions cstar/base/additional_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,59 +11,73 @@ class AdditionalCode:
"""
Additional code contributing to a unique instance of a base model, e.g. namelists, source modifications, etc.
Additional code is assumed to be kept in a git-controlled repository (`source_repo`), and obtaining the code
is handled by git commands.
Additional code is assumed to be kept in a single directory or repository (described by the `source` attribute)
with this structure:
<additional_code_dir>
├── namelists
| └ <base_model_name>
| ├ <namelist_file_1>
| | ...
| └ <namelist_file_N>
└── source_mods
└ <base_model_name>
├ <source_code_file_1>
| ...
└ <source_code_file_N>
Attributes:
-----------
base_model: BaseModel
The base model with which this additional code is associated
source_repo: str
URL pointing to a git-controlled repository containing the additional code
checkout_target: str
A tag, git hash, or other target to check out the source repo at the correct point in its history
source_mods: str or list of strs
Path(s) from the top level of `source_repo` to any code that is needed to compile a unique instance of the base model
source: DataSource
Describes the location and type of source data (e.g. repository,directory)
checkout_target: Optional, str
Used if source.source_type is 'repository'. A tag, git hash, or other target to check out.
source_mods: Optional, str or list of strs
Path(s) relative to the top level of `source.location` to any code that is needed to compile a unique instance of the base model
namelists: str or list of strs
Path(s) from the top level of `source_repo` to any code that is needed at runtime for the base model
Path(s) relative to the top level of `source.location` to any code that is needed at runtime for the base model
exists_locally: bool, default None
True if the additional code has been fetched to the local machine, set when `check_exists_locally()` method is called
Set to True if source.location_type is 'path', or if AdditionalCode.get() has been called.
Is also set by the `check_exists_locally()` method.
local_path: str, default None
The path to where the additional code has been fetched locally, set when the `get()` method is called
The local path to the additional code. Set when `get()` method is called, or if source.location_type is 'path'.
Methods:
--------
get(local_path):
Clone the `source_repo` repository to a temporary directory, checkout `checkout_target`,
and move files associated with this AdditionalCode instance to `local_path`.
check_exists_locally(local_path):
Verify whether the files associated with this AdditionalCode instance can be found at `local_path`
get(local_dir):
Fetch the directory containing this additional code and copy it to `local_dir`.
If source.source_type is 'repository', and source.location_type is 'url',
clone repository to a temporary directory, checkout `checkout_target`,
and move files associated with this AdditionalCode instance to `local_dir`.
check_exists_locally(local_dir):
Verify whether the files associated with this AdditionalCode instance can be found at `local_dir`
"""

def __init__(
self,
base_model: BaseModel,
source: DataSource,
location: str,
checkout_target: Optional[str] = None,
source_mods: Optional[List[str]] = None,
namelists: Optional[List[str]] = None,
):
"""
Initialize an AdditionalCode object from a repository URL and a list of code files
Initialize an AdditionalCode object from a DataSource and a list of code files
Parameters:
-----------
base_model: BaseModel
The base model with which this additional code is associated
source_repo: str
URL pointing to a git-controlled repository containing the additional code
checkout_target: str
A tag, git hash, or other target to check out the source repo at the correct point in its history
source_mods: str or list of strs
Path(s) from the top level of `source_repo` to any code that is needed to compile a unique instance of the base model
namelists: str or list of strs
Path(s) from the top level of `source_repo` to any code that is needed at runtime for the base model
location: str
url or path pointing to the additional code directory or repository, used to set `source` attribute
checkout_target: Optional, str
Used if source.source_type is 'repository'. A tag, git hash, or other target to check out.
source_mods: Optional, str or list of strs
Path(s) relative to the top level of `source.location` to any code that is needed to compile a unique instance of the base model
namelists: Optional, str or list of strs
Path(s) relative to the top level of `source.location` to any code that is needed at runtime for the base model
Returns:
--------
Expand All @@ -74,13 +88,18 @@ def __init__(

# TODO: Type check here
self.base_model: BaseModel = base_model
self.source: DataSource = source
self.source: DataSource = DataSource(location)
self.checkout_target: Optional[str] = checkout_target
self.source_mods: Optional[List[str]] = source_mods
self.namelists: Optional[List[str]] = namelists
self.exists_locally: Optional[bool] = None
self.local_path: Optional[str] = None

# If there are namelists, make a parallel attribute to keep track of the ones we are editing
# AdditionalCode.get() determines which namelists are editable templates and updates this list
if self.namelists:
self.modified_namelists: list = []

if self.source.location_type == "path":
self.exists_locally = True
self.local_path = self.source.location
Expand All @@ -91,38 +110,33 @@ def __str__(self):
)
base_str += "\n---------------------"
base_str += f"\nBase model: {self.base_model.name}"
# FIXME update after sorting all this ish out
# base_str += f"\nAdditional code repository URL: {self.source_repo} (checkout target: {self.checkout_target})"
base_str += f"\nLocation: {self.source.location}"
if self.exists_locally is not None:
base_str += f"\n Exists locally: {self.exists_locally}"
if self.local_path is not None:
base_str += f"\n Local path: {self.local_path}"
if self.source_mods is not None:
base_str += "\nSource code modification files (paths relative to repository top level):"
base_str += (
"\nSource code modification files (paths relative to above location)):"
)
for filename in self.source_mods:
base_str += f"\n {filename}"
if self.namelists is not None:
base_str += "\nNamelist files (paths relative to repository top level):"
base_str += "\nNamelist files (paths relative to above location):"
for filename in self.namelists:
base_str += f"\n {filename}"
if filename[-9:]=="_TEMPLATE":
base_str+=f" ({filename[:-9]} will be used by C-Star based on this template)"
return base_str

def __repr__(self):
return self.__str__()

def get(self, local_dir: str):
"""
Clone `source_repo` into a temporary directory and move required files to `local_dir`.
This method:
1. Clones the `source_repo` repository into a temporary directory (deleted after call)
2. Checks out the `checkout_target` (a tag or commit hash) to move to the correct point in the commit history
3. Loops over the paths described in `source_mods` and `namelists` and
moves those files to `local_dir/source_mods/base_model.name/` and `local_dir/namelists/base_model.name`,
respectively.
Copy the required AdditionalCode files to `local_dir`
Clone the `source_repo` repository to a temporary directory, checkout `checkout_target`,
and move files associated with this AdditionalCode instance to `local_dir`.
If AdditionalCode.source describes a remote repository, this is cloned into a temporary directory first.
Parameters:
-----------
Expand All @@ -133,6 +147,7 @@ def get(self, local_dir: str):
try:
tmp_dir = None # initialise the tmp_dir variable in case we need it later

# CASE 1: Additional code is in a remote repository:
if (self.source.location_type == "url") and (
self.source.source_type == "repository"
):
Expand All @@ -151,7 +166,7 @@ def get(self, local_dir: str):
checkout_target=self.checkout_target,
)
source_dir = Path(tmp_dir)

# CASE 2: Additional code is in a local directory/repository
elif (self.source.location_type == "path") and (
(self.source.source_type == "directory")
or (self.source.source_type == "repository")
Expand All @@ -165,9 +180,10 @@ def get(self, local_dir: str):
+ "AdditionalCode.source.source_type should be "
+ "'url' and 'repository', or 'path' and 'repository', or"
+ "'path' and 'directory', not"
+ f"{self.source.location_type} and {self.source.source_type}"
+ f"'{self.source.location_type}' and '{self.source.source_type}'"
)

# Now go through the files and copy them to local_dir
for file_type in ["source_mods", "namelists"]:
file_list = getattr(self, file_type)

Expand All @@ -186,6 +202,18 @@ def get(self, local_dir: str):
raise FileNotFoundError(
f"Error: {src_file_path} does not exist."
)
# Special case for template namelists:
if (
file_type == "namelists"
and str(src_file_path)[-9:] == "_TEMPLATE"
):
print(
f"copying {tgt_file_path} to editable namelist {str(tgt_file_path)[:-9]}"
)
shutil.copy(tgt_file_path, str(tgt_file_path)[:-9])
if hasattr(self, "modified_namelists"):
self.modified_namelists.append(f[:-9])

self.local_path = local_dir
self.exists_locally = True
finally:
Expand Down
7 changes: 4 additions & 3 deletions cstar/base/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,15 @@ def __str__(self):
if isinstance(self.input_datasets, InputDataset)
else 0
)
base_str += f"\n{NAC} AdditionalCode repositories (query using ROMSComponent.additional_code)"
base_str += f"\n{NAC} AdditionalCode repositories (query using Component.additional_code)"
base_str += (
f"\n{NID} InputDataset objects (query using ROMSComponent.input_datasets"
f"\n{NID} InputDataset objects (query using Component.input_datasets"
)

#Discretisation
disc_str = ""
if hasattr(self, "time_step") and self.time_step is not None:
disc_str += "\ntime_step: " + str(self.time_step)
disc_str += "\ntime_step: " + str(self.time_step) +"s"
if hasattr(self, "n_procs_x") and self.n_procs_x is not None:
disc_str += (
"\nn_procs_x: "
Expand Down
12 changes: 6 additions & 6 deletions cstar/base/input_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class InputDataset(ABC):
def __init__(
self,
base_model: "BaseModel",
source: DataSource,
location: str,
file_hash: Optional[str] = None,
start_date: Optional[str | dt.datetime] = None,
end_date: Optional[str | dt.datetime] = None,
Expand All @@ -51,16 +51,16 @@ def __init__(
-----------
base_model: BaseModel
The base model with which this input dataset is associated
source: str
URL or path pointing to the netCDF file containing this input dataset
file_hash: str
location: str
URL or path pointing to a file either containing this dataset or instructions for creating it.
Used to set the `source` attribute.
file_hash: str, optional
The 256 bit SHA sum associated with the file for verification
"""

self.base_model: "BaseModel" = base_model

self.source: DataSource = source
self.source: DataSource = DataSource(location)
self.file_hash: Optional[str] = file_hash

if (self.file_hash is None) and (self.source.location_type == "url"):
Expand Down
14 changes: 7 additions & 7 deletions cstar/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def from_blueprint(
else:
additional_code_info = component_info["component"]["additional_code"]

source = DataSource(additional_code_info.get("location"))
location = additional_code_info.get("location")
checkout_target = additional_code_info.get("checkout_target", None)
source_mods = (
[f for f in additional_code_info["source_mods"]]
Expand All @@ -359,7 +359,7 @@ def from_blueprint(

additional_code = AdditionalCode(
base_model=base_model,
source=source,
location=location,
checkout_target=checkout_target,
source_mods=source_mods,
namelists=namelists,
Expand All @@ -381,7 +381,7 @@ def from_blueprint(
model_grid = [
ROMSModelGrid(
base_model=base_model,
source=DataSource(f.get("location")),
location=f.get("location"),
file_hash=f.get("hash", None),
)
for f in input_dataset_info["model_grid"]["files"]
Expand All @@ -397,7 +397,7 @@ def from_blueprint(
initial_conditions = [
ROMSInitialConditions(
base_model=base_model,
source=DataSource(f.get("location")),
location=f.get("location"),
file_hash=f.get("hash", None),
start_date=f.get("start_date", None),
end_date=f.get("end_date", None),
Expand All @@ -417,7 +417,7 @@ def from_blueprint(
tidal_forcing = [
ROMSTidalForcing(
base_model=base_model,
source=DataSource(f.get("location")),
location=f.get("location"),
file_hash=f.get("hash", None),
)
for f in input_dataset_info["tidal_forcing"]["files"]
Expand All @@ -435,7 +435,7 @@ def from_blueprint(
boundary_forcing = [
ROMSBoundaryForcing(
base_model=base_model,
source=DataSource(f.get("location")),
location=f.get("location"),
file_hash=f.get("hash", None),
start_date=f.get("start_date", None),
end_date=f.get("end_date", None),
Expand All @@ -455,7 +455,7 @@ def from_blueprint(
surface_forcing = [
ROMSSurfaceForcing(
base_model=base_model,
source=DataSource(f.get("location")),
location=f.get("location"),
file_hash=f.get("hash", None),
start_date=f.get("start_date", None),
end_date=f.get("end_date", None),
Expand Down
Loading

0 comments on commit 64b1541

Please sign in to comment.