Commit cd77001: Move to working metric groups

golmschenk committed Oct 2, 2024
1 parent 626e61b commit cd77001
Showing 2 changed files with 44 additions and 16 deletions.
32 changes: 16 additions & 16 deletions src/qusi/internal/module.py
@@ -48,7 +48,7 @@ def new(
         model: Module,
         optimizer: Optimizer | None,
         loss_metric: Module | None = None,
-        logging_metrics: list[Module] | None = None,
+        logging_metrics: ModuleList | None = None,
     ) -> Self:
         if optimizer is None:
             optimizer = AdamW(model.parameters())
@@ -66,21 +66,21 @@ def new(
         train_metric_group = MetricGroup.new(loss_metric, state_based_logging_metrics, functional_logging_metrics)
         validation_metric_group = MetricGroup.new(loss_metric, state_based_logging_metrics, functional_logging_metrics)
         instance = cls(model=model, optimizer=optimizer, train_metric_group=train_metric_group,
-                       validation_metric_groups=[validation_metric_group])
+                       validation_metric_groups=ModuleList([validation_metric_group]))
         return instance

     def __init__(
             self,
             model: Module,
             optimizer: Optimizer,
             train_metric_group: MetricGroup,
-            validation_metric_groups: list[MetricGroup],
+            validation_metric_groups: ModuleList,
     ):
         super().__init__()
         self.model: Module = model
         self._optimizer: Optimizer = optimizer
-        self.train_metric_group = train_metric_group
-        self.validation_metric_groups: list[MetricGroup] = validation_metric_groups
+        self.train_metric_group: MetricGroup = train_metric_group
+        self.validation_metric_groups: ModuleList | list[MetricGroup] = validation_metric_groups

     def forward(self, inputs: Any) -> Any:
         return self.model(inputs)
@@ -104,37 +104,37 @@ def compute_loss_and_metrics(self, batch: tuple[Any, Any], metric_group: MetricGroup
         return loss

     def on_train_epoch_end(self) -> None:
-        self.log_loss_and_metrics()
+        self.log_loss_and_metrics(self.train_metric_group)

-    def log_loss_and_metrics(self, logging_name_prefix: str = ''):
-        for state_based_logging_metric in self.train_state_based_logging_metrics:
+    def log_loss_and_metrics(self, metric_group: MetricGroup, logging_name_prefix: str = ''):
+        for state_based_logging_metric in metric_group.state_based_logging_metrics:
             state_based_logging_metric_name = get_metric_name(state_based_logging_metric)
             self.log(name=logging_name_prefix + state_based_logging_metric_name,
                      value=state_based_logging_metric.compute(), sync_dist=True)
             state_based_logging_metric.reset()
         for functional_logging_metric_index, functional_logging_metric in enumerate(
-                self.train_functional_logging_metrics):
+                metric_group.functional_logging_metrics):
             functional_logging_metric_name = get_metric_name(functional_logging_metric)
-            functional_logging_metric_cycle_total = float(self._train_functional_logging_metric_cycle_totals[
+            functional_logging_metric_cycle_total = float(metric_group.functional_logging_metric_cycle_totals[
                 functional_logging_metric_index])

-            functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / self._train_steps_run_in_phase
+            functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / metric_group.steps_run_in_phase
             self.log(name=logging_name_prefix + functional_logging_metric_name,
                      value=functional_logging_metric_cycle_mean,
                      sync_dist=True)
-        mean_cycle_loss = self._train_loss_cycle_total / self._train_steps_run_in_phase
+        mean_cycle_loss = metric_group.loss_cycle_total / metric_group.steps_run_in_phase
         self.log(name=logging_name_prefix + 'loss',
                  value=mean_cycle_loss, sync_dist=True)
-        self._train_loss_cycle_total = 0
-        self._train_functional_logging_metric_cycle_totals = np.zeros(len(self.train_functional_logging_metrics),
+        metric_group.loss_cycle_total = 0
+        metric_group.functional_logging_metric_cycle_totals = np.zeros(len(metric_group.functional_logging_metrics),
                                                                       dtype=np.float32)
-        self._train_steps_run_in_phase = 0
+        metric_group.steps_run_in_phase = 0

     def validation_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT:
         return self.compute_loss_and_metrics(batch, self.validation_metric_groups[0])

     def on_validation_epoch_end(self) -> None:
-        self.log_loss_and_metrics(logging_name_prefix='val_')
+        self.log_loss_and_metrics(self.validation_metric_groups[0], logging_name_prefix='val_')

     def configure_optimizers(self):
         return self._optimizer
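The core of this change is swapping the plain Python list of validation metric groups for a torch.nn.ModuleList. Under standard PyTorch semantics, a parent Module does not register child modules that are held inside a plain Python list, so their parameters and buffers are invisible to state_dict(), .to(device), and Lightning's automatic device placement; a ModuleList makes them visible. A minimal sketch of the difference (the classes below are invented for illustration, not code from this repository):

    from torch.nn import Linear, Module, ModuleList


    class WithPlainList(Module):
        def __init__(self):
            super().__init__()
            self.layers = [Linear(4, 4)]  # hidden from the parent module


    class WithModuleList(Module):
        def __init__(self):
            super().__init__()
            self.layers = ModuleList([Linear(4, 4)])  # registered as a child module


    print(len(list(WithPlainList().parameters())))   # 0 -- parameters not found
    print(len(list(WithModuleList().parameters())))  # 2 -- weight and bias found

Since the metric groups now carry per-phase state (loss_cycle_total, steps_run_in_phase, and the functional metric cycle totals), registering them as proper child modules is presumably what makes them "working" in the sense of the commit title.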
28 changes: 28 additions & 0 deletions tests/end_to_end_tests/test_toy_lightning_train_session.py
@@ -0,0 +1,28 @@
+import os
+from functools import partial
+
+from qusi.internal.light_curve_dataset import (
+    default_light_curve_observation_post_injection_transform,
+)
+from qusi.internal.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel
+from qusi.internal.toy_light_curve_collection import get_toy_dataset
+from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration
+from qusi.internal.lightning_train_session import train_session
+
+
+def test_toy_train_session():
+    os.environ["WANDB_MODE"] = "disabled"
+    model = SingleDenseLayerBinaryClassificationModel.new(input_size=100)
+    dataset = get_toy_dataset()
+    dataset.post_injection_transform = partial(
+        default_light_curve_observation_post_injection_transform, length=100
+    )
+    train_hyperparameter_configuration = TrainHyperparameterConfiguration.new(
+        batch_size=3, cycles=2, train_steps_per_cycle=5, validation_steps_per_cycle=5
+    )
+    train_session(
+        train_datasets=[dataset],
+        validation_datasets=[dataset],
+        model=model,
+        hyperparameter_configuration=train_hyperparameter_configuration,
+    )
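The new end-to-end test is a standard pytest test, so it can presumably be run directly with an invocation such as the following (the WANDB_MODE=disabled line in the test keeps Weights & Biases from attempting to log during the short toy training run):

    pytest tests/end_to_end_tests/test_toy_lightning_train_session.py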
