diff --git a/src/landmarker/__init__.py b/src/landmarker/__init__.py
index 8d1b254..9762f1f 100755
--- a/src/landmarker/__init__.py
+++ b/src/landmarker/__init__.py
@@ -2,7 +2,7 @@
 Landmarker
 """
 
-__version__ = "0.1.1-alpha"
+__version__ = "0.1.2-alpha"
 
 __all__ = [
     "data",
diff --git a/src/landmarker/losses/losses.py b/src/landmarker/losses/losses.py
index b77afab..1f2009c 100644
--- a/src/landmarker/losses/losses.py
+++ b/src/landmarker/losses/losses.py
@@ -13,7 +13,7 @@
     HeatmapGenerator,
     LaplacianHeatmapGenerator,
 )
-from landmarker.models.utils import SoftmaxND
+from landmarker.models.utils import LogSoftmaxND
 
 
 class GeneralizedNormalHeatmapLoss(nn.Module):
@@ -563,7 +563,7 @@ def __init__(self, spatial_dims: int = 2, apply_softmax: bool = True, reduction:
         self.spatial_dims = spatial_dims
         self.apply_softmax = apply_softmax
         if self.apply_softmax:
-            self.softmax = SoftmaxND(spatial_dims)
+            self.log_softmax = LogSoftmaxND(spatial_dims)
         if spatial_dims not in [2, 3]:
             raise ValueError("spatial_dims must be 2 or 3")
         self.reduction = reduction
@@ -572,7 +572,10 @@ def __init__(self, spatial_dims: int = 2, apply_softmax: bool = True, reduction:
 
     def forward(self, output, target):
         if self.apply_softmax:
-            output = self.softmax(output)
-        nll = -target * torch.log(output.double())
+            output = self.log_softmax(output)
+        else:
+            output = torch.log(output.double())
+        # `output` now holds log-probabilities, so no second log here.
+        nll = -target * output.double()
         if self.spatial_dims == 2:
             dim = (2, 3)
diff --git a/src/landmarker/models/utils.py b/src/landmarker/models/utils.py
index eba8eaf..def6ee1 100644
--- a/src/landmarker/models/utils.py
+++ b/src/landmarker/models/utils.py
@@ -10,3 +10,19 @@ def __init__(self, spatial_dims):
     def forward(self, x):
         out = torch.exp(x)
         return out / torch.sum(out, dim=self.dim, keepdim=True)
+
+
+class LogSoftmaxND(nn.Module):
+    """Numerically stable log-softmax over the spatial dimensions."""
+
+    def __init__(self, spatial_dims):
+        super().__init__()
+        self.dim = (2, 3) if spatial_dims == 2 else (2, 3, 4)
+
+    def forward(self, x):
+        # Log-sum-exp trick: subtracting the spatial max before
+        # exponentiating keeps torch.exp from overflowing.
+        out_max = torch.amax(x, dim=self.dim, keepdim=True)
+        x_exp = torch.exp(x - out_max)
+        x_exp_sum = torch.sum(x_exp, dim=self.dim, keepdim=True)
+        return x - out_max - torch.log(x_exp_sum)
diff --git a/tests/test_loss.py b/tests/test_loss.py
index d3ddb59..09c3a1f 100644
--- a/tests/test_loss.py
+++ b/tests/test_loss.py
@@ -12,6 +12,7 @@
     EuclideanDistanceVarianceReg,
     GeneralizedNormalHeatmapLoss,
     MultivariateGaussianNLLLoss,
+    NLLLoss,
     StackedLoss,
     StarLoss,
 )
@@ -670,3 +671,23 @@ def test_stacked_loss_3d():
 
     # check that the output is non-negative
     assert (loss >= 0).all()
+
+
+def test_nll_loss_2d():
+    """Test the NLLLoss class."""
+    pred = torch.rand(2, 3, 64, 64)
+    target = torch.rand(2, 3, 64, 64)
+
+    # "mean" and "sum" reduce the loss to a scalar.
+    loss_fn = NLLLoss(reduction="mean")
+    loss = loss_fn(pred, target)
+    assert loss.shape == torch.Size([])
+
+    loss_fn = NLLLoss(reduction="sum")
+    loss = loss_fn(pred, target)
+    assert loss.shape == torch.Size([])
+
+    # "none" keeps one value per (batch, landmark) pair.
+    loss_fn = NLLLoss(reduction="none")
+    loss = loss_fn(pred, target)
+    assert loss.shape == torch.Size([2, 3])
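
Sanity check (a minimal sketch, not part of the diff): assuming the `landmarker` package layout above is importable, the new `LogSoftmaxND` can be compared against `torch.nn.functional.log_softmax` over the flattened spatial dimensions, and the overflow that motivates the log-sum-exp formulation can be reproduced with the old `SoftmaxND`:

```python
# Sketch: assumes the landmarker package from the diff above is installed.
import torch
import torch.nn.functional as F

from landmarker.models.utils import LogSoftmaxND, SoftmaxND

x = torch.randn(2, 3, 8, 8) * 200  # large logits: exp(200) overflows float32

# Reference: flatten the spatial dims and use the built-in log_softmax.
ref = F.log_softmax(x.flatten(2), dim=-1).view_as(x)
out = LogSoftmaxND(spatial_dims=2)(x)
assert torch.allclose(out, ref, atol=1e-4)

# The exp-then-normalize path of SoftmaxND overflows on the same input,
# so taking its log yields non-finite entries; LogSoftmaxND does not.
old = torch.log(SoftmaxND(spatial_dims=2)(x))
print(torch.isfinite(old).all(), torch.isfinite(out).all())  # False, True
```

Flattening the spatial dims and reducing over the tuple `(2, 3)` are equivalent operations, which is why the two results agree up to float error.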