diff --git a/src/qusi/internal/lightning_train_session.py b/src/qusi/internal/lightning_train_session.py index 68b87a6..5980d95 100644 --- a/src/qusi/internal/lightning_train_session.py +++ b/src/qusi/internal/lightning_train_session.py @@ -4,6 +4,7 @@ from warnings import warn import lightning +from lightning.pytorch.loggers import WandbLogger from torch.nn import BCELoss, Module from torch.optim import Optimizer from torch.utils.data import DataLoader @@ -66,6 +67,8 @@ def train_session( system_configuration = TrainSystemConfiguration.new() if loss_metric is None: loss_metric = BCELoss() + if logging_configuration is None: + logging_configuration = TrainLoggingConfiguration.new() if logging_metrics is None: logging_metrics = [BinaryAccuracy(), BinaryAUROC()] @@ -100,9 +103,11 @@ def train_session( lightning_model = QusiLightningModule.new(model=model, optimizer=optimizer, loss_metric=loss_metric, logging_metrics=logging_metrics) + wandb_logger = WandbLogger(project=logging_configuration.wandb_project, entity=logging_configuration.wandb_entity) trainer = lightning.Trainer( max_epochs=hyperparameter_configuration.cycles, limit_train_batches=hyperparameter_configuration.train_steps_per_cycle, limit_val_batches=hyperparameter_configuration.validation_steps_per_cycle, + logger=[wandb_logger] ) trainer.fit(model=lightning_model, train_dataloaders=train_dataloader, val_dataloaders=validation_dataloaders)