Skip to content

Commit

Permalink
Add binary AUROC metric
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed May 9, 2024
1 parent 95a5b90 commit d8945e8
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ dependencies = [
"sphinx>=6.1.3",
"backports.strenum",
"typing_extensions",
"myst-parser"
"myst-parser",
"torcheval>=0.0.7",
]

[build-system]
Expand Down
4 changes: 2 additions & 2 deletions src/qusi/internal/train_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch.nn import BCELoss, Module
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchmetrics.classification import BinaryAccuracy
from torcheval.metrics import BinaryAccuracy, BinaryAUROC

import wandb
from qusi.internal.light_curve_dataset import InterleavedDataset, LightCurveDataset
Expand Down Expand Up @@ -49,7 +49,7 @@ def train_session(
if loss_function is None:
loss_function = BCELoss()
if metric_functions is None:
metric_functions = [BinaryAccuracy()]
metric_functions = [BinaryAccuracy(), BinaryAUROC()]
set_up_default_logger()
wandb_init(
process_rank=0,
Expand Down

0 comments on commit d8945e8

Please sign in to comment.