Skip to content

Commit

Permalink
Use balanced accuracy as base metric
Browse files Browse the repository at this point in the history
  • Loading branch information
JKomorniczak committed Jan 19, 2024
1 parent d26c542 commit 860e913
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ ax.plot(results.T, label=['Inner', 'Outer', 'Halfpoint', 'Overall'])
ax.legend()
ax.grid(ls=':')
ax.set_xlabel('epochs')
ax.set_ylabel('Weighted accurracy')
ax.set_ylabel('Balanced accurracy')
ax.set_xlim(0,epochs)
```

Expand Down
2 changes: 1 addition & 1 deletion docs/source/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ The results of the single processing can be visualized using `matplotlib` librar
ax.legend()
ax.grid(ls=':')
ax.set_xlabel('epochs')
ax.set_ylabel('Weighted accurracy')
ax.set_ylabel('Balanced accurracy')
ax.set_xlim(0,epochs)
.. image:: _static/example.png
Expand Down
Binary file modified example.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion torchosr/_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""``torchosr``"""

__version__ = "0.1.2"
__version__ = "0.1.3"
9 changes: 5 additions & 4 deletions torchosr/models/GSL.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,15 @@ def test(self, dataloader, loss_fn, conf=False):
outer_targets = []
overall_targets = []

# Define metric for open set detection
# Define metrics for open set detection
try:
inner_metric = Accuracy(task="multiclass", num_classes=self.n_known, average='weighted')
overall_metric = Accuracy(task="multiclass", num_classes=self.n_known+1, average='weighted')
inner_metric = Accuracy(task="multiclass", num_classes=self.n_known, average='macro')
overall_metric = Accuracy(task="multiclass", num_classes=self.n_known+1, average='macro')
except:
inner_metric = None
overall_metric = None
outer_metric = Accuracy(task='binary', average='weighted')

outer_metric = Accuracy(task='multiclass', num_classes=2, average='macro')

with torch.no_grad():
for X, y in dataloader:
Expand Down
6 changes: 3 additions & 3 deletions torchosr/models/Openmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,13 @@ def test(self, dataloader, loss_fn, conf=False):

# Define metrics for open set detection
try:
inner_metric = Accuracy(task="multiclass", num_classes=self.n_known, average='weighted')
overall_metric = Accuracy(task="multiclass", num_classes=self.n_known+1, average='weighted')
inner_metric = Accuracy(task="multiclass", num_classes=self.n_known, average='macro')
overall_metric = Accuracy(task="multiclass", num_classes=self.n_known+1, average='macro')
except:
inner_metric = None
overall_metric = None

outer_metric = Accuracy(task='binary', average='weighted')
outer_metric = Accuracy(task='multiclass', num_classes=2, average='macro')

with torch.no_grad():
for X, y in dataloader:
Expand Down
7 changes: 4 additions & 3 deletions torchosr/models/TSoftmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ def test(self, dataloader, loss_fn, conf=False):

# Define metric for open set detection
try:
inner_metric = Accuracy(task="multiclass", num_classes=self.n_known, average='weighted')
overall_metric = Accuracy(task="multiclass", num_classes=self.n_known+1, average='weighted')
inner_metric = Accuracy(task="multiclass", num_classes=self.n_known, average='macro')
overall_metric = Accuracy(task="multiclass", num_classes=self.n_known+1, average='macro')
except:
inner_metric = None
outer_metric = Accuracy(task='binary', average='weighted')

outer_metric = Accuracy(task='multiclass', num_classes=2, average='macro')

with torch.no_grad():
for X, y in dataloader:
Expand Down

0 comments on commit 860e913

Please sign in to comment.