Commit 626e61b

Switch to metric groups

golmschenk committed Oct 2, 2024
1 parent 5b4c90e commit 626e61b

Showing 2 changed files with 115 additions and 29 deletions.
84 changes: 55 additions & 29 deletions src/qusi/internal/module.py
@@ -1,3 +1,4 @@
+import copy
 from typing import Any
 
 import numpy as np
@@ -13,6 +14,33 @@
 from qusi.internal.logging import get_metric_name
 
 
+class MetricGroup(Module):
+    def __init__(self, loss_metric: Module, state_based_logging_metrics: ModuleList,
+                 functional_logging_metrics: ModuleList):
+        super().__init__()
+        self.loss_metric: Module = loss_metric
+        self.state_based_logging_metrics: ModuleList = state_based_logging_metrics
+        self.functional_logging_metrics: ModuleList = functional_logging_metrics
+        self.loss_cycle_total: float = 0
+        self.steps_run_in_phase: int = 0
+        self.functional_logging_metric_cycle_totals: npt.NDArray = np.zeros(
+            len(self.functional_logging_metrics), dtype=np.float32)
+
+    @classmethod
+    def new(
+            cls,
+            loss_metric: Module,
+            state_based_logging_metrics: ModuleList,
+            functional_logging_metrics: ModuleList
+    ) -> Self:
+        loss_metric_: Module = copy.deepcopy(loss_metric)
+        state_based_logging_metrics_: ModuleList = copy.deepcopy(state_based_logging_metrics)
+        functional_logging_metrics_: ModuleList = copy.deepcopy(functional_logging_metrics)
+        instance = cls(loss_metric=loss_metric_, state_based_logging_metrics=state_based_logging_metrics_,
+                       functional_logging_metrics=functional_logging_metrics_)
+        return instance
+
+
 class QusiLightningModule(LightningModule):
     @classmethod
     def new(
@@ -29,54 +57,50 @@ def new(
         if logging_metrics is None:
             logging_metrics = [BinaryAccuracy(), BinaryAUROC()]
         state_based_logging_metrics: ModuleList = ModuleList()
-        functional_logging_metrics: list[Module] = []
+        functional_logging_metrics: ModuleList = ModuleList()
         for logging_metric in logging_metrics:
             if isinstance(logging_metric, Metric):
                 state_based_logging_metrics.append(logging_metric)
             else:
                 functional_logging_metrics.append(logging_metric)
-        instance = cls(model=model, optimizer=optimizer, loss_metric=loss_metric,
-                       state_based_logging_metrics=state_based_logging_metrics,
-                       functional_logging_metrics=functional_logging_metrics)
+        train_metric_group = MetricGroup.new(loss_metric, state_based_logging_metrics, functional_logging_metrics)
+        validation_metric_group = MetricGroup.new(loss_metric, state_based_logging_metrics, functional_logging_metrics)
+        instance = cls(model=model, optimizer=optimizer, train_metric_group=train_metric_group,
+                       validation_metric_groups=[validation_metric_group])
         return instance
 
     def __init__(
             self,
             model: Module,
             optimizer: Optimizer,
-            loss_metric: Module,
-            state_based_logging_metrics: ModuleList,
-            functional_logging_metrics: list[Module],
+            train_metric_group: MetricGroup,
+            validation_metric_groups: list[MetricGroup],
     ):
         super().__init__()
         self.model: Module = model
         self._optimizer: Optimizer = optimizer
-        self.loss_metric: Module = loss_metric
-        self.train_state_based_logging_metrics: ModuleList = state_based_logging_metrics
-        self.train_functional_logging_metrics: list[Module] = functional_logging_metrics
-        self._train_functional_logging_metric_cycle_totals: npt.NDArray = np.zeros(
-            len(self.train_functional_logging_metrics), dtype=np.float32)
-        self._loss_cycle_total: int = 0
-        self._steps_run_in_phase: int = 0
+        self.train_metric_group = train_metric_group
+        self.validation_metric_groups: list[MetricGroup] = validation_metric_groups
 
     def forward(self, inputs: Any) -> Any:
         return self.model(inputs)
 
     def training_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT:
-        return self.compute_loss_and_metrics(batch)
+        return self.compute_loss_and_metrics(batch, self.train_metric_group)
 
-    def compute_loss_and_metrics(self, batch):
+    def compute_loss_and_metrics(self, batch: tuple[Any, Any], metric_group: MetricGroup):
         inputs, target = batch
         predicted = self(inputs)
-        loss = self.loss_metric(predicted, target)
-        self._loss_cycle_total += loss
-        for state_based_logging_metric in self.train_state_based_logging_metrics:
+        loss = metric_group.loss_metric(predicted, target)
+        metric_group.loss_cycle_total += loss
+        for state_based_logging_metric in metric_group.state_based_logging_metrics:
             state_based_logging_metric(predicted, target)
-        for functional_logging_metric_index, functional_logging_metric in enumerate(self.train_functional_logging_metrics):
+        for functional_logging_metric_index, functional_logging_metric in enumerate(
+                metric_group.functional_logging_metrics):
             functional_logging_metric_value = functional_logging_metric(predicted, target)
-            self._train_functional_logging_metric_cycle_totals[
+            metric_group.functional_logging_metric_cycle_totals[
                 functional_logging_metric_index] += functional_logging_metric_value
-        self._steps_run_in_phase += 1
+        metric_group.steps_run_in_phase += 1
         return loss
 
     def on_train_epoch_end(self) -> None:
@@ -88,24 +112,26 @@ def log_loss_and_metrics(self, logging_name_prefix: str = ''):
             self.log(name=logging_name_prefix + state_based_logging_metric_name,
                      value=state_based_logging_metric.compute(), sync_dist=True)
             state_based_logging_metric.reset()
-        for functional_logging_metric_index, functional_logging_metric in enumerate(self.train_functional_logging_metrics):
+        for functional_logging_metric_index, functional_logging_metric in enumerate(
+                self.train_functional_logging_metrics):
             functional_logging_metric_name = get_metric_name(functional_logging_metric)
             functional_logging_metric_cycle_total = float(self._train_functional_logging_metric_cycle_totals[
                 functional_logging_metric_index])
 
-            functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / self._steps_run_in_phase
+            functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / self._train_steps_run_in_phase
             self.log(name=logging_name_prefix + functional_logging_metric_name,
                      value=functional_logging_metric_cycle_mean,
                      sync_dist=True)
-        mean_cycle_loss = self._loss_cycle_total / self._steps_run_in_phase
+        mean_cycle_loss = self._train_loss_cycle_total / self._train_steps_run_in_phase
         self.log(name=logging_name_prefix + 'loss',
                  value=mean_cycle_loss, sync_dist=True)
-        self._loss_cycle_total = 0
-        self._train_functional_logging_metric_cycle_totals = np.zeros(len(self.train_functional_logging_metrics), dtype=np.float32)
-        self._steps_run_in_phase = 0
+        self._train_loss_cycle_total = 0
+        self._train_functional_logging_metric_cycle_totals = np.zeros(len(self.train_functional_logging_metrics),
+                                                                      dtype=np.float32)
+        self._train_steps_run_in_phase = 0
 
     def validation_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT:
-        return self.compute_loss_and_metrics(batch)
+        return self.compute_loss_and_metrics(batch, self.validation_metric_groups[0])
 
     def on_validation_epoch_end(self) -> None:
         self.log_loss_and_metrics(logging_name_prefix='val_')
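A minimal sketch of how the refactored module might be constructed after this commit. The network, optimizer, and loss below are placeholders, and the exact signature of QusiLightningModule.new is only partially visible in the hunks above, so treat the keyword names as assumptions rather than the library's confirmed API:

# Illustrative sketch only; the model, optimizer, and loss are placeholders,
# and the keyword names of `new` are assumed from the diff above.
from torch.nn import BCEWithLogitsLoss, Linear
from torch.optim import AdamW
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC

from qusi.internal.module import QusiLightningModule

model = Linear(in_features=10, out_features=1)  # stand-in network
module = QusiLightningModule.new(
    model=model,
    optimizer=AdamW(model.parameters()),
    loss_metric=BCEWithLogitsLoss(),
    logging_metrics=[BinaryAccuracy(), BinaryAUROC()],
)
# MetricGroup.new deep-copies each metric, so the train and validation
# phases hold independent metric state instead of sharing one set.
assert module.train_metric_group is not module.validation_metric_groups[0]

The deep copy in MetricGroup.new is what lets new pass the same metric list to both phases without one phase polluting the other's accumulated state.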
60 changes: 60 additions & 0 deletions tests/unit_tests/test_lightning_module.py
@@ -0,0 +1,60 @@
+from unittest.mock import Mock
+
+import torch
+from torch.nn import ModuleList, Module
+from torchmetrics import MeanSquaredError
+
+from qusi.internal.module import QusiLightningModule, MetricGroup
+
+
+class MockStateBasedMetric(Mock, MeanSquaredError):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
+class MockFunctionalMetric(Mock, Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.side_effect: Module = MeanSquaredError()
+
+
+def create_fake_qusi_lightning_module() -> QusiLightningModule:
+    qusi_lightning_module_mock = QusiLightningModule(
+        model=Mock(return_value=torch.tensor([1])), optimizer=Mock(), train_metric_group=Mock(),
+        validation_metric_groups=[Mock()]
+    )
+    return qusi_lightning_module_mock
+
+
+def create_fake_metric_group() -> MetricGroup:
+    fake_metric_group = MetricGroup(loss_metric=Mock(return_value=torch.tensor([1])),
+                                    state_based_logging_metrics=ModuleList([MockStateBasedMetric()]),
+                                    functional_logging_metrics=ModuleList([MockFunctionalMetric()]))
+    return fake_metric_group
+
+
+def test_compute_loss_and_metrics_calls_passed_loss_metric():
+    fake_qusi_lightning_module0 = create_fake_qusi_lightning_module()
+    fake_metric_group = create_fake_metric_group()
+    batch = (torch.tensor([3]), torch.tensor([4]))
+    assert not fake_metric_group.loss_metric.called
+    fake_qusi_lightning_module0.compute_loss_and_metrics(batch=batch, metric_group=fake_metric_group)
+    assert fake_metric_group.loss_metric.called
+
+
+def test_compute_loss_and_metrics_uses_correct_phase_state_metric():
+    fake_qusi_lightning_module0 = create_fake_qusi_lightning_module()
+    fake_metric_group = create_fake_metric_group()
+    batch = (torch.tensor([3]), torch.tensor([4]))
+    assert not fake_metric_group.state_based_logging_metrics[0].called
+    fake_qusi_lightning_module0.compute_loss_and_metrics(batch=batch, metric_group=fake_metric_group)
+    assert fake_metric_group.state_based_logging_metrics[0].called
+
+
+def test_compute_loss_and_metrics_uses_correct_phase_functional_metric():
+    fake_qusi_lightning_module0 = create_fake_qusi_lightning_module()
+    fake_metric_group = create_fake_metric_group()
+    batch = (torch.tensor([3]), torch.tensor([4]))
+    assert not fake_metric_group.functional_logging_metrics[0].called
+    fake_qusi_lightning_module0.compute_loss_and_metrics(batch=batch, metric_group=fake_metric_group)
+    assert fake_metric_group.functional_logging_metrics[0].called
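The mock metric classes above pair Mock's call recording with a real torchmetrics computation. A standalone sketch of that pattern, illustrative and not part of the commit:

from unittest.mock import Mock

import torch
from torchmetrics import MeanSquaredError

# The mock records the call while `side_effect` delegates the actual
# computation to a real metric, so both behaviors can be asserted on.
metric_mock = Mock(side_effect=MeanSquaredError())
value = metric_mock(torch.tensor([1.0]), torch.tensor([3.0]))
assert metric_mock.called          # the call was recorded
assert value == torch.tensor(4.0)  # real MSE: (1 - 3)**2 == 4

This is why the tests can assert `.called` on a metric while compute_loss_and_metrics still receives a genuine tensor back from it.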
