Refactor code to use LogSoftmaxND instead of SoftmaxND in losses.py
jejon committed Oct 3, 2024
1 parent 09e34f0 commit 5a992f6
Showing 4 changed files with 53 additions and 4 deletions.
src/landmarker/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
 Landmarker
 """
 
-__version__ = "0.1.1-alpha"
+__version__ = "0.1.2-alpha"
 
 __all__ = [
     "data",
src/landmarker/losses/losses.py (8 changes: 5 additions & 3 deletions)
@@ -13,7 +13,7 @@
     HeatmapGenerator,
     LaplacianHeatmapGenerator,
 )
-from landmarker.models.utils import SoftmaxND
+from landmarker.models.utils import LogSoftmaxND
 
 
 class GeneralizedNormalHeatmapLoss(nn.Module):
@@ -563,7 +563,7 @@ def __init__(self, spatial_dims: int = 2, apply_softmax: bool = True, reduction:
         self.spatial_dims = spatial_dims
         self.apply_softmax = apply_softmax
         if self.apply_softmax:
-            self.softmax = SoftmaxND(spatial_dims)
+            self.log_softmax = LogSoftmaxND(spatial_dims)
         if spatial_dims not in [2, 3]:
             raise ValueError("spatial_dims must be 2 or 3")
         self.reduction = reduction
@@ -572,7 +572,9 @@ def __init__(self, spatial_dims: int = 2, apply_softmax: bool = True, reduction:
 
     def forward(self, output, target):
         if self.apply_softmax:
-            output = self.softmax(output)
+            output = self.log_softmax(output)
+        else:
+            output = torch.log(output.double())
         nll = -target * torch.log(output.double())
         if self.spatial_dims == 2:
             dim = (2, 3)
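Computing a softmax and then taking its logarithm, as the old SoftmaxND path did, is numerically fragile: torch.exp saturates to inf for large-magnitude logits, and the subsequent log turns the result into nan or -inf. LogSoftmaxND instead folds the logarithm into the normalization via the log-sum-exp trick. A minimal standalone sketch of the difference (an illustration for this note, not code from the commit):

import torch

x = torch.tensor([[1000.0, 0.0, -1000.0]])

# Naive log(softmax(x)): exp(1000) overflows to inf, so the normalized
# values degrade to nan / 0 and their logs to nan / -inf.
naive = torch.log(torch.exp(x) / torch.exp(x).sum(dim=-1, keepdim=True))
print(naive)   # tensor([[nan, -inf, -inf]])

# Log-sum-exp trick: subtract the max before exponentiating, which is
# what LogSoftmaxND does over the spatial dimensions.
x_max = x.max(dim=-1, keepdim=True).values
stable = x - x_max - torch.log(torch.exp(x - x_max).sum(dim=-1, keepdim=True))
print(stable)  # tensor([[0., -1000., -2000.]])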
src/landmarker/models/utils.py (19 changes: 19 additions & 0 deletions)
@@ -10,3 +10,22 @@ def __init__(self, spatial_dims):
     def forward(self, x):
         out = torch.exp(x)
         return out / torch.sum(out, dim=self.dim, keepdim=True)
+
+
+class LogSoftmaxND(nn.Module):
+    def __init__(self, spatial_dims):
+        super().__init__()
+        self.dim = (2, 3) if spatial_dims == 2 else (2, 3, 4)
+        self.spatial_dims = spatial_dims
+
+    def forward(self, x):
+        if self.spatial_dims == 2:
+            out_max, _ = torch.max(x, dim=-1, keepdim=True)
+            out_max, _ = torch.max(out_max, dim=-2, keepdim=True)
+        else:
+            out_max, _ = torch.max(x, dim=-1, keepdim=True)
+            out_max, _ = torch.max(out_max, dim=-1, keepdim=True)
+            out_max, _ = torch.max(out_max, dim=-1, keepdim=True)
+        x_exp = torch.exp(x - out_max)
+        x_exp_sum = torch.sum(x_exp, dim=self.dim, keepdim=True)
+        return x - out_max - torch.log(x_exp_sum)
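For 2-D inputs this is the standard log-sum-exp stabilization over the flattened spatial grid, so it should agree with torch.nn.functional.log_softmax applied to a flattened view. (In the 3-D branch the second and third reductions reuse dim=-1, which is already of size 1 after the first keepdim max, so the sketch below only exercises the 2-D case.) A quick sanity check, assuming NCHW input; this snippet is written for this note and is not part of the commit:

import torch
from landmarker.models.utils import LogSoftmaxND

x = torch.randn(2, 3, 8, 8) * 10  # scaled-up logits to stress stability
out = LogSoftmaxND(spatial_dims=2)(x)

# Reference: flatten H and W, apply the built-in log_softmax, reshape back.
ref = torch.nn.functional.log_softmax(x.flatten(2), dim=-1).view_as(x)
assert torch.allclose(out, ref, atol=1e-4)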
tests/test_loss.py (28 changes: 28 additions & 0 deletions)
@@ -12,6 +12,7 @@
     EuclideanDistanceVarianceReg,
     GeneralizedNormalHeatmapLoss,
     MultivariateGaussianNLLLoss,
+    NLLLoss,
     StackedLoss,
     StarLoss,
 )
@@ -670,3 +671,30 @@ def test_stacked_loss_3d():
 
     # check that the output is non-negative
     assert (loss >= 0).all()
+
+
+def test_NLLLoss_2d():
+    """Test the NLLLoss class."""
+    reduction = "mean"
+    pred = torch.rand(2, 3, 64, 64)
+    target = torch.rand(2, 3, 64, 64)
+
+    loss_fn = NLLLoss(reduction=reduction)
+    expected_output_shape = torch.Size([])
+    loss = loss_fn(pred, target)
+    assert loss.shape == expected_output_shape
+
+    pred = torch.rand(2, 3, 64, 64)
+    target = torch.rand(2, 3, 64, 64)
+
+    loss_fn = NLLLoss(reduction="sum")
+    expected_output_shape = torch.Size([])
+
+    loss = loss_fn(pred, target)
+    assert loss.shape == expected_output_shape
+
+    loss_fn = NLLLoss(reduction="none")
+    expected_output_shape = torch.Size([2, 3])
+
+    loss = loss_fn(pred, target)
+    assert loss.shape == expected_output_shape
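The shape assertions encode the usual reduction contract for heatmap losses: the per-pixel negative log-likelihood is first reduced over the spatial dimensions, leaving one value per batch element and landmark, which reduction="none" returns directly and "mean"/"sum" collapse to a scalar. A standalone sketch of that contract (shapes only; my own illustration, not the library implementation):

import torch

def reduce_nll(nll: torch.Tensor, reduction: str) -> torch.Tensor:
    # nll holds per-pixel values of shape (B, C, H, W);
    # collapse the spatial dimensions first.
    per_landmark = nll.sum(dim=(2, 3))  # -> (B, C)
    if reduction == "none":
        return per_landmark
    return per_landmark.mean() if reduction == "mean" else per_landmark.sum()

nll = torch.rand(2, 3, 64, 64)
assert reduce_nll(nll, "none").shape == torch.Size([2, 3])
assert reduce_nll(nll, "mean").shape == torch.Size([])  # scalar
assert reduce_nll(nll, "sum").shape == torch.Size([])   # scalar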
