Skip to content

Commit

Permalink
Add validation logging
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed Sep 3, 2024
1 parent fd6b344 commit d15a6f9
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions src/qusi/internal/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
self._functional_logging_metric_cycle_totals: npt.NDArray = np.zeros(len(self.functional_logging_metrics),
dtype=np.float32)
self._loss_cycle_total: int = 0
self._steps_run_in_cycle: int = 0
self._steps_run_in_phase: int = 0

def forward(self, inputs: Any) -> Any:
return self.model(inputs)
Expand All @@ -76,7 +76,7 @@ def compute_loss_and_metrics(self, batch):
functional_logging_metric_value = functional_logging_metric(predicted, target)
self._functional_logging_metric_cycle_totals[
functional_logging_metric_index] += functional_logging_metric_value
self._steps_run_in_cycle += 1
self._steps_run_in_phase += 1
return loss

def on_train_epoch_end(self) -> None:
Expand All @@ -93,19 +93,22 @@ def log_loss_and_metrics(self, logging_name_prefix: str = ''):
functional_logging_metric_cycle_total = float(self._functional_logging_metric_cycle_totals[
functional_logging_metric_index])

functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / self._steps_run_in_cycle
functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / self._steps_run_in_phase
self.log(name=logging_name_prefix + functional_logging_metric_name,
value=functional_logging_metric_cycle_mean,
sync_dist=True)
mean_cycle_loss = self._loss_cycle_total / self._steps_run_in_cycle
mean_cycle_loss = self._loss_cycle_total / self._steps_run_in_phase
self.log(name=logging_name_prefix + 'loss',
value=mean_cycle_loss, sync_dist=True)
self._loss_cycle_total = 0
self._functional_logging_metric_cycle_totals = np.zeros(len(self.functional_logging_metrics), dtype=np.float32)
self._steps_run_in_cycle = 0
self._steps_run_in_phase = 0

def validation_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT:
return self.compute_loss_and_metrics(batch)

def on_validation_epoch_end(self) -> None:
self.log_loss_and_metrics(logging_name_prefix='val_')

def configure_optimizers(self):
return self._optimizer

0 comments on commit d15a6f9

Please sign in to comment.