From d15a6f9f73106233e2bd15a5f9453e95c514a5b5 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Mon, 2 Sep 2024 23:00:57 -0400 Subject: [PATCH] Add validation logging --- src/qusi/internal/module.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/qusi/internal/module.py b/src/qusi/internal/module.py index 506b58f..cac86a4 100644 --- a/src/qusi/internal/module.py +++ b/src/qusi/internal/module.py @@ -57,7 +57,7 @@ def __init__( self._functional_logging_metric_cycle_totals: npt.NDArray = np.zeros(len(self.functional_logging_metrics), dtype=np.float32) self._loss_cycle_total: int = 0 - self._steps_run_in_cycle: int = 0 + self._steps_run_in_phase: int = 0 def forward(self, inputs: Any) -> Any: return self.model(inputs) @@ -76,7 +76,7 @@ def compute_loss_and_metrics(self, batch): functional_logging_metric_value = functional_logging_metric(predicted, target) self._functional_logging_metric_cycle_totals[ functional_logging_metric_index] += functional_logging_metric_value - self._steps_run_in_cycle += 1 + self._steps_run_in_phase += 1 return loss def on_train_epoch_end(self) -> None: @@ -93,19 +93,22 @@ def log_loss_and_metrics(self, logging_name_prefix: str = ''): functional_logging_metric_cycle_total = float(self._functional_logging_metric_cycle_totals[ functional_logging_metric_index]) - functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / self._steps_run_in_cycle + functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / self._steps_run_in_phase self.log(name=logging_name_prefix + functional_logging_metric_name, value=functional_logging_metric_cycle_mean, sync_dist=True) - mean_cycle_loss = self._loss_cycle_total / self._steps_run_in_cycle + mean_cycle_loss = self._loss_cycle_total / self._steps_run_in_phase self.log(name=logging_name_prefix + 'loss', value=mean_cycle_loss, sync_dist=True) self._loss_cycle_total = 0 self._functional_logging_metric_cycle_totals = np.zeros(len(self.functional_logging_metrics), dtype=np.float32) - self._steps_run_in_cycle = 0 + self._steps_run_in_phase = 0 def validation_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT: return self.compute_loss_and_metrics(batch) + def on_validation_epoch_end(self) -> None: + self.log_loss_and_metrics(logging_name_prefix='val_') + def configure_optimizers(self): return self._optimizer