From c275b1cb092165942157efcc0e96e32b6ab15c8e Mon Sep 17 00:00:00 2001 From: JefJ <26346574+jejon@users.noreply.github.com> Date: Tue, 19 Nov 2024 19:26:38 +0100 Subject: [PATCH] Fix tensor device assignment in windowed weighted sample covariance and update dimension indexing in SoftmaxND classes --- src/landmarker/heatmap/decoder.py | 2 +- src/landmarker/models/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/landmarker/heatmap/decoder.py b/src/landmarker/heatmap/decoder.py index 6fe2e5c..aa00825 100644 --- a/src/landmarker/heatmap/decoder.py +++ b/src/landmarker/heatmap/decoder.py @@ -752,7 +752,7 @@ def windowed_weigthed_sample_cov( ] .unsqueeze(0) .unsqueeze(0), - torch.tensor([[[window, window]]], dtype=torch.float), + torch.tensor([[[window, window]]], dtype=torch.float).to(heatmap.device), spatial_dims=spatial_dims, activation=activation, ) diff --git a/src/landmarker/models/utils.py b/src/landmarker/models/utils.py index d44647b..cb62cb1 100644 --- a/src/landmarker/models/utils.py +++ b/src/landmarker/models/utils.py @@ -5,7 +5,7 @@ class SoftmaxND(nn.Module): def __init__(self, spatial_dims): super().__init__() - self.dim = (2, 3) if spatial_dims == 2 else (2, 3, 4) + self.dim = (-2, -1) if spatial_dims == 2 else (-3, -2, -2) def forward(self, x): out = torch.exp(x) @@ -15,7 +15,7 @@ def forward(self, x): class LogSoftmaxND(nn.Module): def __init__(self, spatial_dims): super().__init__() - self.dim = (2, 3) if spatial_dims == 2 else (2, 3, 4) + self.dim = (-2, -1) if spatial_dims == 2 else (-3, -2, -1) self.spatial_dims = spatial_dims def forward(self, x):