Skip to content

Commit

Permalink
Add wandb logger
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed Sep 3, 2024
1 parent fddc037 commit fd6b344
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/qusi/internal/lightning_train_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from warnings import warn

import lightning
from lightning.pytorch.loggers import WandbLogger
from torch.nn import BCELoss, Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -66,6 +67,8 @@ def train_session(
system_configuration = TrainSystemConfiguration.new()
if loss_metric is None:
loss_metric = BCELoss()
if logging_configuration is None:
logging_configuration = TrainLoggingConfiguration.new()
if logging_metrics is None:
logging_metrics = [BinaryAccuracy(), BinaryAUROC()]

Expand Down Expand Up @@ -100,9 +103,11 @@ def train_session(

lightning_model = QusiLightningModule.new(model=model, optimizer=optimizer, loss_metric=loss_metric,
logging_metrics=logging_metrics)
wandb_logger = WandbLogger(project=logging_configuration.wandb_project, entity=logging_configuration.wandb_entity)
trainer = lightning.Trainer(
max_epochs=hyperparameter_configuration.cycles,
limit_train_batches=hyperparameter_configuration.train_steps_per_cycle,
limit_val_batches=hyperparameter_configuration.validation_steps_per_cycle,
logger=[wandb_logger]
)
trainer.fit(model=lightning_model, train_dataloaders=train_dataloader, val_dataloaders=validation_dataloaders)

0 comments on commit fd6b344

Please sign in to comment.