From a97b4d4b97a44f4a0f701fec72fa191d301146ce Mon Sep 17 00:00:00 2001
From: Rilwan Adewoyin <18564167+Rilwan-Adewoyin@users.noreply.github.com>
Date: Wed, 4 Sep 2024 07:40:13 +0000
Subject: [PATCH] #47 implement separate post_processing objects for state and tendency

---
 src/anemoi/training/config/data/zarr.yaml | 78 +++++++++++++------
 src/anemoi/training/data/datamodule.py    |  9 +++
 .../diagnostics/callbacks/__init__.py     | 55 +++----------
 3 files changed, 74 insertions(+), 68 deletions(-)

diff --git a/src/anemoi/training/config/data/zarr.yaml b/src/anemoi/training/config/data/zarr.yaml
index 27c17edb..ba3b0a7b 100644
--- a/src/anemoi/training/config/data/zarr.yaml
+++ b/src/anemoi/training/config/data/zarr.yaml
@@ -27,24 +27,45 @@ diagnostic:
 - tp
 - cp
 
-normalizer:
-  default: "mean-std"
-  min-max:
-  max:
-  - "sdor"
-  - "slor"
-  - "z"
-  none:
-  - "cos_latitude"
-  - "cos_longitude"
-  - "sin_latitude"
-  - "sin_longitude"
-  - "cos_julian_day"
-  - "cos_local_time"
-  - "sin_julian_day"
-  - "sin_local_time"
-  - "insolation"
-  - "lsm"
+normalizers:
+  state:
+    default: "mean-std"
+    min-max:
+    max:
+    - "sdor"
+    - "slor"
+    - "z"
+    none:
+    - "cos_latitude"
+    - "cos_longitude"
+    - "sin_latitude"
+    - "sin_longitude"
+    - "cos_julian_day"
+    - "cos_local_time"
+    - "sin_julian_day"
+    - "sin_local_time"
+    - "cos_solar_zenith_angle"
+    - "lsm"
+
+  tendency:
+    default: "mean-std"
+    min-max:
+    max:
+    - "sdor"
+    - "slor"
+    - "z"
+    none:
+    - "cos_latitude"
+    - "cos_longitude"
+    - "sin_latitude"
+    - "sin_longitude"
+    - "cos_julian_day"
+    - "cos_local_time"
+    - "sin_julian_day"
+    - "sin_local_time"
+    - "cos_solar_zenith_angle"
+    - "lsm"
+
 
 imputer:
   default: "none"
@@ -52,13 +73,20 @@ imputer:
 # processors including imputers and normalizers are applied in order of definition
 processors:
   # example_imputer:
-  #   _target_: anemoi.models.preprocessing.imputer.InputImputer
-  #   _convert_: all
-  #   config: ${data.imputer}
-  normalizer:
-    _target_: anemoi.models.preprocessing.normalizer.InputNormalizer
-    _convert_: all
-    config: ${data.normalizer}
+  #   _target_: anemoi.models.preprocessing.imputer.InputImputer
+  #   _convert_: all
+  #   config: ${data.imputer}
+  state:
+    normalizer:
+      _target_: anemoi.models.preprocessing.normalizer.InputNormalizer
+      _convert_: all
+      config: ${data.normalizers.state}
+
+  tendency:
+    normalizer:
+      _target_: anemoi.models.preprocessing.normalizer.InputNormalizer
+      _convert_: all
+      config: ${data.normalizers.tendency}
 
 # Values set in the code
 num_features: null # number of features in the forecast state
diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py
index 4e1e4d1b..1e2debaa 100644
--- a/src/anemoi/training/data/datamodule.py
+++ b/src/anemoi/training/data/datamodule.py
@@ -105,6 +105,14 @@ def _check_resolution(self, resolution: str) -> None:
     def statistics(self) -> dict:
         return self.ds_train.statistics
 
+    @cached_property
+    def statistics_tendencies(self) -> dict:
+        # This is just a quick fix to work with datasets without stored tendency
+        # statistics. This should be caught in anemoi-datasets.
+        if self.config.training.tendency_mode:
+            return self.ds_train.statistics_tendencies
+        return None
+
     @cached_property
     def metadata(self) -> dict:
         return self.ds_train.metadata
@@ -165,6 +173,7 @@ def _get_dataset(
             rollout=r,
             multistep=self.config.training.multistep_input,
             timeincrement=self.timeincrement,
+            timestep=self.config.data.timestep,
             model_comm_group_rank=self.model_comm_group_rank,
             model_comm_group_id=self.model_comm_group_id,
             model_comm_num_groups=self.model_comm_num_groups,
diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py
index 5c0d8b0e..10444c19 100644
--- a/src/anemoi/training/diagnostics/callbacks/__init__.py
+++ b/src/anemoi/training/diagnostics/callbacks/__init__.py
@@ -90,8 +90,8 @@ def __init__(self, config: OmegaConf) -> None:
         self.config = config
         self.save_basedir = config.hardware.paths.plots
         self.plot_frequency = config.diagnostics.plot.frequency
-        self.post_processors = None
-        self.pre_processors = None
+        self.post_processors_state = None
+        self.pre_processors_state = None
         self.latlons = None
         init_plot_settings()
 
@@ -195,40 +195,9 @@ def _eval(
         self,
         pl_module: pl.LightningModule,
         batch: torch.Tensor,
     ) -> None:
-        loss = torch.zeros(1, dtype=batch.dtype, device=pl_module.device, requires_grad=False)
-        # NB! the batch is already normalized in-place - see pl_model.validation_step()
-        metrics = {}
-
-        # start rollout
-        x = batch[
-            :,
-            0 : pl_module.multi_step,
-            ...,
-            pl_module.data_indices.data.input.full,
-        ]  # (bs, multi_step, latlon, nvar)
-        assert (
-            batch.shape[1] >= self.rollout + pl_module.multi_step
-        ), "Batch length not sufficient for requested rollout length!"
-        with torch.no_grad():
-            for rollout_step in range(self.rollout):
-                y_pred = pl_module(x)  # prediction at rollout step rollout_step, shape = (bs, latlon, nvar)
-                y = batch[
-                    :,
-                    pl_module.multi_step + rollout_step,
-                    ...,
-                    pl_module.data_indices.data.output.full,
-                ]  # target, shape = (bs, latlon, nvar)
-                # y includes the auxiliary variables, so we must leave those out when computing the loss
-                loss += pl_module.loss(y_pred, y)
-
-                x = pl_module.advance_input(x, y_pred, batch, rollout_step)
-
-                metrics_next, _ = pl_module.calculate_val_metrics(y_pred, y, rollout_step)
-                metrics.update(metrics_next)
-
-            # scale loss
-            loss *= 1.0 / self.rollout
+        loss, metrics, _ = pl_module._step(batch, validation_mode=True, in_place_proc=False)
+
         self._log(pl_module, loss, metrics, batch.shape[0])
 
     def _log(self, pl_module: pl.LightningModule, loss: torch.Tensor, metrics: dict, bs: int) -> None:
@@ -533,9 +502,9 @@ def _plot(
         # When running in Async mode, it might happen that in the last epoch these tensors
         # have been moved to the cpu (and then the denormalising would fail as the 'input_tensor' would be on CUDA
         # but internal ones would be on the cpu), The lines below allow to address this problem
-        if self.post_processors is None:
+        if self.post_processors_state is None:
             # Copy to be used across all the training cycle
-            self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu()
+            self.post_processors_state = copy.deepcopy(pl_module.model.post_processors_state).cpu()
         if self.latlons is None:
             self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy())
         local_rank = pl_module.local_rank
@@ -546,9 +515,9 @@ def _plot(
             ...,
             pl_module.data_indices.data.output.full,
         ].cpu()
-        data = self.post_processors(input_tensor).numpy()
+        data = self.post_processors_state(input_tensor).numpy()
 
-        output_tensor = self.post_processors(
+        output_tensor = self.post_processors_state(
             torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])),
             in_place=False,
         ).numpy()
@@ -624,9 +593,9 @@ def _plot(
         if self.pre_processors is None:
             # Copy to be used across all the training cycle
             self.pre_processors = copy.deepcopy(pl_module.model.pre_processors).cpu()
-        if self.post_processors is None:
+        if self.post_processors_state is None:
             # Copy to be used across all the training cycle
-            self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu()
+            self.post_processors_state = copy.deepcopy(pl_module.model.post_processors_state).cpu()
         if self.latlons is None:
             self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy())
         local_rank = pl_module.local_rank
@@ -637,8 +606,8 @@ def _plot(
             ...,
             pl_module.data_indices.data.output.full,
         ].cpu()
-        data = self.post_processors(input_tensor).numpy()
-        output_tensor = self.post_processors(
+        data = self.post_processors_state(input_tensor).numpy()
+        output_tensor = self.post_processors_state(
             torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])),
             in_place=False,
         ).numpy()
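
For reference, below is a minimal, self-contained sketch (not part of the patch) of how the split data.normalizers / data.processors layout introduced above could be read and resolved. It assumes only that OmegaConf is available, as it already is in anemoi-training; the trimmed inline YAML is a hypothetical stand-in for config/data/zarr.yaml, and the real normalizer objects are built from the _target_ entries elsewhere in the training code.

# Sketch: resolving the separate state/tendency normalizer configs with OmegaConf.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
data:
  normalizers:
    state:
      default: "mean-std"
      none: ["cos_latitude", "lsm"]
    tendency:
      default: "mean-std"
      none: ["cos_latitude", "lsm"]
  processors:
    state:
      normalizer:
        _target_: anemoi.models.preprocessing.normalizer.InputNormalizer
        _convert_: all
        config: ${data.normalizers.state}
    tendency:
      normalizer:
        _target_: anemoi.models.preprocessing.normalizer.InputNormalizer
        _convert_: all
        config: ${data.normalizers.tendency}
"""
)

# Each processor group interpolates its own normalizer block, so state and
# tendency variables can be scaled with independent settings/statistics.
print(OmegaConf.to_container(cfg.data.processors.state.normalizer, resolve=True))
print(OmegaConf.to_container(cfg.data.processors.tendency.normalizer, resolve=True))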