Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix and extend flow derivatives and losses #117

Merged
merged 4 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 37 additions & 167 deletions src/deepali/core/bspline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
r"""Functions for B-spline interpolation."""

from itertools import combinations_with_replacement, permutations, product
from typing import Callable, Dict, Optional, Sequence, Tuple, Union, overload
from typing import Callable, Optional, Sequence, Tuple, Union, overload

import torch
from torch import Size, Tensor
Expand All @@ -10,7 +9,6 @@
from .enum import PaddingMode, SpatialDim, SpatialDimArg
from .grid import Grid
from .image import conv, conv1d
from .itertools import is_even_permutation
from .kernels import cubic_bspline1d
from .tensor import move_dim
from .typing import DType, ScalarOrTuple
Expand Down Expand Up @@ -260,11 +258,11 @@ def cubic_bspline_interpolation_weights(

def evaluate_cubic_bspline(
data: Tensor,
stride: ScalarOrTuple[int],
stride: Optional[ScalarOrTuple[int]] = None,
size: Optional[Size] = None,
shape: Optional[Size] = None,
kernel: Optional[Union[Tensor, Sequence[Tensor]]] = None,
derivative: ScalarOrTuple[int] = 0,
derivative: Optional[ScalarOrTuple[int]] = None,
transpose: bool = False,
) -> Tensor:
r"""Evaluate cubic B-spline function.
Expand All @@ -274,11 +272,16 @@ def evaluate_cubic_bspline(
stride: Number of output grid points between control points plus one. This is the stride of the
transposed convolution used to upsample the control point displacements to the output size.
If a sequence of values is given, these must be the strides for the different spatial
dimensions in the order ``(sx, ...)``.
dimensions in the order ``(sx, ...)``. When a ``kernel`` is specified and ``transpose=False``,
this argument is ignored. Otherwise it is required.
size: Spatial size of output tensor in the order ``(nx, ...)``.
shape: Spatial size of output tensor in the order ``(..., nx)``.
kernel: Precomputed cubic B-spline interpolation kernel. When multiple 1D kernels are given,
these must be in the order ``(kx, ...)``.
derivative: Order of partial derivatives along each spatial dimension. When a scalar is given,
the same order is applied to all spatial dimensions. When a ``kernel`` is specified and
``transpose=False``, this argument is ignored. For ``transpose=True``, the evaluation of
partial derivatives is not implemented, and hence this argument must be None or 0.
transpose: Whether to use separable transposed convolution as implemented in AIRLab.
When ``False``, a more efficient implementation using multi-channel convolution followed
by a reshuffling of the output is performed. This more efficient and also more accurate
Expand All @@ -303,14 +306,20 @@ def evaluate_cubic_bspline(
D = data.ndim - 2
N = data.shape[0]
C = data.shape[1]
if isinstance(stride, int):
stride = [stride] * D
# Implementation inspired by AIRLab
if transpose:
if stride is None:
raise ValueError("evaluate_cubic_bspline() 'stride' is required when transpose=True")
if isinstance(stride, int):
stride = [stride] * D
if any(s < 1 for s in stride):
raise ValueError("evaluate_cubic_bspline() 'stride' must be positive")
if kernel is None:
if derivative != 0:
if (isinstance(derivative, int) and derivative != 0) or (
isinstance(derivative, Sequence) and any(order != 0 for order in derivative)
):
raise NotImplementedError(
"evaluate_cubic_bspline() 'derivative' must be 0 when kernel=None and transpose=True"
"evaluate_cubic_bspline() 'derivative' order cannot be non-zero when kernel=None and transpose=True"
)
kernels = {}
for s in stride:
Expand All @@ -334,8 +343,19 @@ def evaluate_cubic_bspline(
# Implementation adapted from MIRTK
else:
if kernel is None:
if stride is None:
stride = 1
if isinstance(stride, int):
stride = [stride] * D
if any(s < 1 for s in stride):
raise ValueError("evaluate_cubic_bspline() 'stride' must be positive")
if derivative is None:
derivative = 0
kernel = cubic_bspline_interpolation_weights(
stride=stride, derivative=derivative, dtype=data.dtype, device=data.device
stride=stride,
derivative=derivative,
dtype=data.dtype,
device=data.device,
)
elif isinstance(kernel, Tensor):
kernel = [kernel] * D
Expand All @@ -347,6 +367,12 @@ def evaluate_cubic_bspline(
dims = tuple(SpatialDim(dim).tensor_dim(data.ndim) for dim in range(D))
conv_fn: Callable[..., Tensor] = [F.conv1d, F.conv2d, F.conv3d][D - 1]
for dim, w in zip(dims, kernel):
if w.ndim == 1:
w = w.unsqueeze(0)
if w.ndim != 2:
raise ValueError(
"evaluate_cubic_bspline() 'kernel' must be 1- or 2-dimensional tensors"
)
weight = w.reshape((w.shape[0], 1, w.shape[1]) + (1,) * (D - 1))
weight = weight.tile((C,) + (1,) * (weight.ndim - 1))
output = move_dim(output, dim, 2)
Expand All @@ -359,162 +385,6 @@ def evaluate_cubic_bspline(
return output


def cubic_bspline_jacobian_det(data: Tensor, stride: ScalarOrTuple[int]) -> Tensor:
r"""Evaluate Jacobian determinant of cubic B-spline free-form deformation."""
if not isinstance(data, Tensor):
raise TypeError("cubic_bspline_jacobian_det() 'data' must be torch.Tensor")
if not torch.is_floating_point(data):
raise TypeError("cubic_bspline_jacobian_det() 'data' must have floating point dtype")
if data.ndim < 3:
raise ValueError("cubic_bspline_jacobian_det() 'data' must have shape (N, C, ..., X)")
D = data.ndim - 2
C = data.shape[1]
if C != D:
raise ValueError(
f"cubic_bspline_jacobian_det() 'data' mismatch between number of channels ({C}) and spatial dimensions ({D})"
)
jac: Optional[Tensor] = None
for perm in permutations(range(D)):
term: Optional[Tensor] = None
for i, j in zip(range(D), perm):
derivative = [1 if dim == j else 0 for dim in range(D)]
du = evaluate_cubic_bspline(data.narrow(1, i, 1), stride=stride, derivative=derivative)
if i == j:
du = du.add_(1) # T(x) = x + u(x)
term = du if term is None else term.mul_(du)
assert term is not None
if jac is None:
jac = term
elif is_even_permutation(perm):
jac = jac.add_(term)
else:
jac = jac.sub_(term)
assert jac is not None
return jac


def cubic_bspline_jacobian_dict(
data: Tensor,
stride: ScalarOrTuple[int],
size: Optional[Size] = None,
shape: Optional[Size] = None,
add_identity: bool = False,
) -> Dict[Tuple[int, int], Tensor]:
r"""Evaluate Jacobian of cubic B-spline free-form deformation.

Args:
data: Cubic B-spline interpolation coefficients as tensor of shape ``(N, D, ..., X)``,
where ``D`` is the number of spatial dimensions.
stride: Number of output grid points between control points plus one. If a sequence of
values is given, these must be the strides for the different spatial dimensions in
the order ``(sx, ...)``.
size: Spatial size of output tensor in the order ``(nx, ...)``.
shape: Spatial size of output tensor in the order ``(..., nx)``.
add_identity: Whether to calculate derivatives of :math:`u(x)` (False) or the free-form
deformation given by :math:`x + u(x)` (True), where :math:`u` is the cubic B-spline
function, by adding the identity matrix to the Jacobian of :math:`u`.

Returns:
Dictionary of spatial derivatives with keys corresponding to (row, col) indices.

"""
if not isinstance(data, Tensor):
raise TypeError("cubic_bspline_jacobian_dict() 'data' must be torch.Tensor")
if not torch.is_floating_point(data):
raise TypeError("cubic_bspline_jacobian_dict() 'data' must have floating point dtype")
if data.ndim < 3:
raise ValueError("cubic_bspline_jacobian_dict() 'data' must have shape (N, C, ..., X)")
if size is not None:
if shape is not None:
raise ValueError(
"cubic_bspline_jacobian_dict() 'size' and 'shape' are mutually exclusive"
)
shape = Size(reversed(size))
D = data.ndim - 2
C = data.shape[1]
if C != D:
raise ValueError(
f"cubic_bspline_jacobian_dict() 'data' mismatch between number of channels ({C}) and spatial dimensions ({D})"
)
jac = {}
for i, j in combinations_with_replacement(range(D), 2):
derivative = [1 if dim == j else 0 for dim in range(D)]
coeff = data.narrow(1, i, 1)
deriv = evaluate_cubic_bspline(coeff, shape=shape, stride=stride, derivative=derivative)
if add_identity and i == j:
deriv = deriv.add_(1) # T(x) = x + u(x)
jac[(i, j)] = deriv
return {(i, j): jac[(i, j) if i < j else (j, i)] for i, j in product(range(D), repeat=2)}


def cubic_bspline_jacobian_matrix(
data: Tensor,
stride: ScalarOrTuple[int],
size: Optional[Size] = None,
shape: Optional[Size] = None,
add_identity: bool = False,
) -> Tensor:
r"""Evaluate Jacobian of cubic B-spline free-form deformation.

Args:
data: Cubic B-spline interpolation coefficients as tensor of shape ``(N, D, ..., X)``,
where ``D`` is the number of spatial dimensions.
stride: Number of output grid points between control points plus one. If a sequence of
values is given, these must be the strides for the different spatial dimensions in
the order ``(sx, ...)``.
size: Spatial size of output tensor in the order ``(nx, ...)``.
shape: Spatial size of output tensor in the order ``(..., nx)``.
add_identity: Whether to calculate derivatives of :math:`u(x)` (False) or the free-form
deformation given by :math:`x + u(x)` (True), where :math:`u` is the cubic B-spline
function, by adding the identity matrix to the Jacobian of :math:`u`.

Returns:
Full Jacobian matrices as tensors of shape ``(N, ..., X, D, D)``.

"""
N = data.shape[0]
D = data.ndim - 2
jac = cubic_bspline_jacobian_dict(
data, stride=stride, shape=shape, size=size, add_identity=add_identity
)
mat = torch.cat([jac[(i, j)] for i, j in product(range(D), repeat=2)], dim=1)
mat = move_dim(mat, 1, -1)
mat = mat.reshape((N,) + jac[(0, 0)].shape[2:] + (D, D))
return mat.contiguous()


def cubic_bspline_jacobian_triu(
data: Tensor,
stride: ScalarOrTuple[int],
size: Optional[Size] = None,
shape: Optional[Size] = None,
add_identity: bool = False,
) -> Tensor:
r"""Evaluate Jacobian of cubic B-spline free-form deformation.

Args:
data: Cubic B-spline interpolation coefficients as tensor of shape ``(N, D, ..., X)``,
where ``D`` is the number of spatial dimensions.
stride: Number of output grid points between control points plus one. If a sequence of
values is given, these must be the strides for the different spatial dimensions in
the order ``(sx, ...)``.
size: Spatial size of output tensor in the order ``(nx, ...)``.
shape: Spatial size of output tensor in the order ``(..., nx)``.
add_identity: Whether to calculate derivatives of :math:`u(x)` (False) or the free-form
deformation given by :math:`x + u(x)` (True), where :math:`u` is the cubic B-spline
function, by adding the identity matrix to the Jacobian of :math:`u`.

Returns:
Flattened upper triangular Jacobian matrices as tensors of shape ``(N, D * (D + 1) / 2, ..., X)``.

"""
D = data.ndim - 2
jac = cubic_bspline_jacobian_dict(
data, stride=stride, shape=shape, size=size, add_identity=add_identity
)
return torch.cat([jac[(i, j)] for i, j in combinations_with_replacement(range(D), 2)], dim=1)


def subdivide_cubic_bspline(
data: Tensor, dims: Optional[Union[SpatialDimArg, Sequence[SpatialDimArg]]] = None
) -> Tensor:
Expand Down
Loading