diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py
index ff3fb916..d31975ec 100644
--- a/src/anemoi/training/train/forecaster.py
+++ b/src/anemoi/training/train/forecaster.py
@@ -44,6 +44,7 @@ def __init__(
         config: DictConfig,
         graph_data: HeteroData,
         statistics: dict,
+        statistics_tendencies: dict,
         data_indices: IndexCollection,
         metadata: dict,
     ) -> None:
@@ -57,6 +58,8 @@ def __init__(
             Graph object
         statistics : dict
             Statistics of the training data
+        statistics_tendencies : dict
+            Statistics of the training data tendencies
         data_indices : IndexCollection
             Indices of the training data,
         metadata : dict
@@ -69,12 +72,21 @@ def __init__(

         self.model = AnemoiModelInterface(
             statistics=statistics,
+            statistics_tendencies=statistics_tendencies,
             data_indices=data_indices,
             metadata=metadata,
             graph_data=graph_data,
             config=DotDict(map_config_to_primitives(OmegaConf.to_container(config, resolve=True))),
         )

+        # Flexible stepping function definition
+        self.step_functions = {
+            "residual": self._step_residual,
+            "tendency": self._step_tendency,
+        }
+        self.prediction_mode = "tendency" if self.model.tendency_mode else "residual"
+        LOGGER.info("Using stepping mode: %s", self.prediction_mode)
+
         self.data_indices = data_indices

         self.save_hyperparameters()
@@ -84,8 +96,16 @@ def __init__(

         self.logger_enabled = config.diagnostics.log.wandb.enabled or config.diagnostics.log.mlflow.enabled

+        # TODO (rilwan-ade): restructure this so that, like the feature weighting, it can be loaded
+        # configurably from a "get_loss_scaling" function (use the method from the other branch)
+        tendency_variance = (
+            torch.from_numpy(self.model.statistics_tendencies["stdev"][self.data_indices.data.output.full])
+            if self.model.tendency_mode
+            else None
+        )
+
         self.metric_ranges, loss_scaling = self.metrics_loss_scaling(config, data_indices)
-        self.loss = WeightedMSELoss(node_weights=self.loss_weights, data_variances=loss_scaling)
+        self.loss = WeightedMSELoss(node_weights=self.loss_weights, data_variances=loss_scaling, tendency_variances=tendency_variance)
         self.metrics = WeightedMSELoss(node_weights=self.loss_weights, ignore_nans=True)

         if config.training.loss_gradient_scaling:
@@ -187,8 +207,7 @@ def advance_input(
         x[:, -1, :, :, self.data_indices.model.input.forcing] = batch[
             :,
             self.multi_step + rollout_step,
-            :,
-            :,
+            ...,
             self.data_indices.data.input.forcing,
         ]
         return x
@@ -198,10 +217,19 @@ def _step(
         batch: torch.Tensor,
         batch_idx: int,
         validation_mode: bool = False,
+        in_place_proc: bool = True,
+    ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]:
+        return self.step_functions[self.prediction_mode](batch, batch_idx, validation_mode, in_place_proc)
+
+    def _step_residual(
+        self,
+        batch: torch.Tensor,
+        batch_idx: int,
+        validation_mode: bool = False,
+        in_place_proc: bool = True,
     ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]:
-        del batch_idx
         loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False)
-        batch = self.model.pre_processors(batch)  # normalized in-place
+        batch = self.model.pre_processors_state(batch, in_place=in_place_proc)  # normalized in place when in_place_proc is True
         metrics = {}

         # start rollout
@@ -210,6 +238,7 @@ def _step(
         y_preds = []
         for rollout_step in range(self.rollout):
             # prediction at rollout step rollout_step, shape = (bs, latlon, nvar)
+            # if rollout_step > 0: torch.cuda.empty_cache()  # uncomment if rollout fails with OOM
             y_pred = self(x)

             y = batch[:, self.multi_step + rollout_step, ..., self.data_indices.data.output.full]
@@ -219,38 +248,89 @@ def _step(
             x = self.advance_input(x, y_pred, batch, rollout_step)
 
             if validation_mode:
-                metrics_next, y_preds_next = self.calculate_val_metrics(
-                    y_pred,
-                    y,
-                    rollout_step,
-                    enable_plot=self.enable_plot,
-                )
+                metrics_next, y_preds_next = self.calculate_val_metrics(y_pred, y, rollout_step, enable_plot=self.enable_plot)
                 metrics.update(metrics_next)
                 y_preds.extend(y_preds_next)

         # scale loss
         loss *= 1.0 / self.rollout
+
         return loss, metrics, y_preds

-    def calculate_val_metrics(
+    def _step_tendency(
         self,
-        y_pred: torch.Tensor,
-        y: torch.Tensor,
-        rollout_step: int,
-        enable_plot: bool = False,
-    ) -> tuple[dict, list]:
+        batch: torch.Tensor,
+        batch_idx: int,
+        validation_mode: bool = False,
+        in_place_proc: bool = True,
+    ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]:
+        loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False)
         metrics = {}
+
+        # x (non-processed)
+        x = batch[:, 0 : self.multi_step, ..., self.data_indices.data.input.full]  # (bs, multi_step, latlon, nvar)
+
         y_preds = []
-        y_postprocessed = self.model.post_processors(y, in_place=False)
-        y_pred_postprocessed = self.model.post_processors(y_pred, in_place=False)
-        for mkey, indices in self.metric_ranges.items():
-            metrics[f"{mkey}_{rollout_step + 1}"] = self.metrics(
-                y_pred_postprocessed[..., indices],
-                y_postprocessed[..., indices],
+        for rollout_step in range(self.rollout):
+
+            # normalise inputs
+            x_in = self.model.pre_processors_state(x, in_place=False, data_index=self.data_indices.data.input.full)
+
+            # prediction (normalized tendency)
+            tendency_pred = self(x_in)
+
+            # re-construct the non-processed predicted state
+            y_pred = self.model.add_tendency_to_state(x[:, -1, ...], tendency_pred)
+
+            # target is the full (non-processed) state
+            y_target = batch[:, self.multi_step + rollout_step, ..., self.data_indices.data.output.full]
+
+            # calculate loss in normalized space
+            loss += checkpoint(
+                self.loss,
+                self.model.pre_processors_state(y_pred, in_place=False, data_index=self.data_indices.data.output.full),
+                self.model.pre_processors_state(y_target, in_place=False, data_index=self.data_indices.data.output.full),
+                use_reentrant=False,
             )
+            # TODO: we should also try the loss directly on the non-processed states
+            # loss += checkpoint(self.loss, y_pred, y_target, use_reentrant=False)
+
+            # advance input using the non-processed x, y_pred and batch
+            x = self.advance_input(x, y_pred, batch, rollout_step)
+
+            if validation_mode:
+                # y_pred and y_target are already in physical space, so pass them as post-processed values
+                metrics_next, _ = self.calculate_val_metrics(
+                    None,
+                    None,
+                    rollout_step,
+                    self.enable_plot,
+                    y_pred_postprocessed=y_pred,
+                    y_postprocessed=y_target,
+                )
+
+                metrics.update(metrics_next)
+
+                y_preds.append(y_pred)
+
+        # scale loss
+        loss *= 1.0 / self.rollout
+
+        return loss, metrics, y_preds
+
+    def calculate_val_metrics(self, y_pred, y, rollout_step, enable_plot=False, y_pred_postprocessed=None, y_postprocessed=None):
+        metrics = {}
+        y_preds = []
+        if y_postprocessed is None:
+            y_postprocessed = self.model.post_processors_state(y, in_place=False)
+        if y_pred_postprocessed is None:
+            y_pred_postprocessed = self.model.post_processors_state(y_pred, in_place=False)
+
+        for mkey, indices in self.metric_ranges.items():
+            metrics[f"{mkey}_{rollout_step + 1}"] = self.metrics(y_pred_postprocessed[..., indices], y_postprocessed[..., indices])
             if enable_plot:
-                y_preds.append(y_pred)
+                y_preds.append(y_pred_postprocessed)
         return metrics, y_preds

     def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
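
A note for reviewers on the tendency path: `_step_tendency` keeps `x`, `y_pred`, and the batch in physical (non-processed) space and only normalizes around the model call, so the state reconstruction happens in physical units via `self.model.add_tendency_to_state`. That method lives on `AnemoiModelInterface` and its internals are not shown in this diff; the sketch below is only a plausible reading, assuming tendencies are standardized with the `statistics_tendencies` mean/stdev (the helper and the `tendency_mean`/`tendency_stdev` names are illustrative, not the actual implementation):

```python
import torch


def add_tendency_to_state_sketch(
    state: torch.Tensor,          # last raw input state, shape (bs, latlon, nvar)
    tendency_norm: torch.Tensor,  # normalized tendency predicted by the model
    tendency_mean: torch.Tensor,  # per-variable tendency mean, shape (nvar,)
    tendency_stdev: torch.Tensor, # per-variable tendency stdev, shape (nvar,)
) -> torch.Tensor:
    """De-normalize the predicted tendency and add it to the previous state."""
    tendency = tendency_norm * tendency_stdev + tendency_mean
    return state + tendency


# Example: one "rollout step" on random data.
bs, latlon, nvar = 2, 16, 4
state = torch.randn(bs, latlon, nvar)
tendency_norm = torch.randn(bs, latlon, nvar)
y_pred = add_tendency_to_state_sketch(state, tendency_norm, torch.zeros(nvar), torch.ones(nvar))
assert y_pred.shape == state.shape
```

Because the tendency path already holds physical-space tensors, the refactored `calculate_val_metrics` lets callers pass them directly via `y_pred_postprocessed=`/`y_postprocessed=`, while the residual path keeps passing normalized `y_pred`/`y` and lets the method post-process internally.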
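On the new `tendency_variances` argument to `WeightedMSELoss`: the diff passes the per-variable tendency standard deviations only in tendency mode. This class is not part of the diff, so the following is a minimal sketch of how such a term could enter an area-weighted MSE, assuming it rescales the per-variable squared error the same way `data_variances` does; the real `WeightedMSELoss` in anemoi-training may combine the terms differently:

```python
import torch
from torch import nn


class WeightedMSESketch(nn.Module):
    """Minimal sketch: area-weighted MSE with optional per-variable scaling."""

    def __init__(
        self,
        node_weights: torch.Tensor,                      # (latlon,) area weight per grid point
        data_variances: torch.Tensor | None = None,      # (nvar,) per-variable scaling
        tendency_variances: torch.Tensor | None = None,  # (nvar,) tendency stdev, tendency mode only
    ) -> None:
        super().__init__()
        self.register_buffer("node_weights", node_weights)
        self.data_variances = data_variances
        self.tendency_variances = tendency_variances

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        squared_error = (pred - target) ** 2  # (bs, latlon, nvar)
        if self.data_variances is not None:
            squared_error = squared_error * self.data_variances
        if self.tendency_variances is not None:
            squared_error = squared_error * self.tendency_variances
        # weight each grid point by its area weight, then average
        weighted = squared_error * self.node_weights[None, :, None]
        return weighted.sum() / (self.node_weights.sum() * pred.shape[0] * pred.shape[-1])


loss_fn = WeightedMSESketch(
    node_weights=torch.ones(16),
    data_variances=torch.ones(4),
    tendency_variances=torch.full((4,), 0.5),
)
print(loss_fn(torch.randn(2, 16, 4), torch.randn(2, 16, 4)))
```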
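Finally, `_step_tendency` wraps the loss in `torch.utils.checkpoint.checkpoint`, so the loss's intermediate activations are recomputed during the backward pass rather than stored, which helps since this path adds two extra pre-processor passes per rollout step. A self-contained demo of the call pattern used in the diff, with a stand-in `mse` in place of `self.loss`:

```python
import torch
from torch.utils.checkpoint import checkpoint


def mse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return ((pred - target) ** 2).mean()


y_pred = torch.randn(4, 16, 8, requires_grad=True)
y_target = torch.randn(4, 16, 8)

# use_reentrant=False selects the non-reentrant implementation, which handles
# inputs that do not all require grad; activations inside `mse` are recomputed
# during backward instead of being kept in memory.
loss = checkpoint(mse, y_pred, y_target, use_reentrant=False)
loss.backward()
print(y_pred.grad.shape)
```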