Skip to content

Commit

Permalink
Rename the all metrics to train metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed Sep 25, 2024
1 parent d15a6f9 commit 5b4c90e
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/qusi/internal/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def __init__(
self.model: Module = model
self._optimizer: Optimizer = optimizer
self.loss_metric: Module = loss_metric
self.state_based_logging_metrics: ModuleList = state_based_logging_metrics
self.functional_logging_metrics: list[Module] = functional_logging_metrics
self._functional_logging_metric_cycle_totals: npt.NDArray = np.zeros(len(self.functional_logging_metrics),
dtype=np.float32)
self.train_state_based_logging_metrics: ModuleList = state_based_logging_metrics
self.train_functional_logging_metrics: list[Module] = functional_logging_metrics
self._train_functional_logging_metric_cycle_totals: npt.NDArray = np.zeros(
len(self.train_functional_logging_metrics), dtype=np.float32)
self._loss_cycle_total: int = 0
self._steps_run_in_phase: int = 0

Expand All @@ -70,11 +70,11 @@ def compute_loss_and_metrics(self, batch):
predicted = self(inputs)
loss = self.loss_metric(predicted, target)
self._loss_cycle_total += loss
for state_based_logging_metric in self.state_based_logging_metrics:
for state_based_logging_metric in self.train_state_based_logging_metrics:
state_based_logging_metric(predicted, target)
for functional_logging_metric_index, functional_logging_metric in enumerate(self.functional_logging_metrics):
for functional_logging_metric_index, functional_logging_metric in enumerate(self.train_functional_logging_metrics):
functional_logging_metric_value = functional_logging_metric(predicted, target)
self._functional_logging_metric_cycle_totals[
self._train_functional_logging_metric_cycle_totals[
functional_logging_metric_index] += functional_logging_metric_value
self._steps_run_in_phase += 1
return loss
Expand All @@ -83,14 +83,14 @@ def on_train_epoch_end(self) -> None:
self.log_loss_and_metrics()

def log_loss_and_metrics(self, logging_name_prefix: str = ''):
for state_based_logging_metric in self.state_based_logging_metrics:
for state_based_logging_metric in self.train_state_based_logging_metrics:
state_based_logging_metric_name = get_metric_name(state_based_logging_metric)
self.log(name=logging_name_prefix + state_based_logging_metric_name,
value=state_based_logging_metric.compute(), sync_dist=True)
state_based_logging_metric.reset()
for functional_logging_metric_index, functional_logging_metric in enumerate(self.functional_logging_metrics):
for functional_logging_metric_index, functional_logging_metric in enumerate(self.train_functional_logging_metrics):
functional_logging_metric_name = get_metric_name(functional_logging_metric)
functional_logging_metric_cycle_total = float(self._functional_logging_metric_cycle_totals[
functional_logging_metric_cycle_total = float(self._train_functional_logging_metric_cycle_totals[
functional_logging_metric_index])

functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / self._steps_run_in_phase
Expand All @@ -101,7 +101,7 @@ def log_loss_and_metrics(self, logging_name_prefix: str = ''):
self.log(name=logging_name_prefix + 'loss',
value=mean_cycle_loss, sync_dist=True)
self._loss_cycle_total = 0
self._functional_logging_metric_cycle_totals = np.zeros(len(self.functional_logging_metrics), dtype=np.float32)
self._train_functional_logging_metric_cycle_totals = np.zeros(len(self.train_functional_logging_metrics), dtype=np.float32)
self._steps_run_in_phase = 0

def validation_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT:
Expand Down

0 comments on commit 5b4c90e

Please sign in to comment.