Skip to content

Commit

Permalink
Fix tensor device assignment in windowed weighted sample covariance a…
Browse files Browse the repository at this point in the history
…nd update dimension indexing in SoftmaxND classes
  • Loading branch information
jejon committed Nov 19, 2024
1 parent c83a740 commit c275b1c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/landmarker/heatmap/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ def windowed_weigthed_sample_cov(
]
.unsqueeze(0)
.unsqueeze(0),
torch.tensor([[[window, window]]], dtype=torch.float),
torch.tensor([[[window, window]]], dtype=torch.float).to(heatmap.device),
spatial_dims=spatial_dims,
activation=activation,
)
Expand Down
4 changes: 2 additions & 2 deletions src/landmarker/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
class SoftmaxND(nn.Module):
def __init__(self, spatial_dims):
super().__init__()
self.dim = (2, 3) if spatial_dims == 2 else (2, 3, 4)
self.dim = (-2, -1) if spatial_dims == 2 else (-3, -2, -2)

def forward(self, x):
out = torch.exp(x)
Expand All @@ -15,7 +15,7 @@ def forward(self, x):
class LogSoftmaxND(nn.Module):
def __init__(self, spatial_dims):
super().__init__()
self.dim = (2, 3) if spatial_dims == 2 else (2, 3, 4)
self.dim = (-2, -1) if spatial_dims == 2 else (-3, -2, -1)
self.spatial_dims = spatial_dims

def forward(self, x):
Expand Down

0 comments on commit c275b1c

Please sign in to comment.