✨ Implement neural autoregressive flow (NAF)
francois-rozet committed May 25, 2022
1 parent 40f6a94 commit 81a22d3
Showing 4 changed files with 231 additions and 73 deletions.
130 changes: 64 additions & 66 deletions lampe/distributions/transforms.py
@@ -15,7 +15,7 @@


class PermutationTransform(Transform):
r"""Transform via a permutation of the elements.
r"""Creates a transformation that permutes the elements.
Arguments:
order: The permutation order, with shape :math:`(*, D)`.
@@ -48,7 +48,7 @@ def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor:


class CosTransform(Transform):
r"""Transform via the mapping :math:`f(x) = -\cos(x)`."""
r"""Creates a transformation :math:`f(x) = -\cos(x)`."""

domain = constraints.interval(0, math.pi)
codomain = constraints.interval(-1, 1)
@@ -68,7 +68,7 @@ def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor:


class SinTransform(Transform):
r"""Transform via the mapping :math:`f(x) = \sin(x)`."""
r"""Creates a transformation :math:`f(x) = \sin(x)`."""

domain = constraints.interval(-math.pi / 2, math.pi / 2)
codomain = constraints.interval(-1, 1)
@@ -88,7 +88,7 @@ def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor:


class MonotonicAffineTransform(Transform):
r"""Transform via the mapping :math:`f(x) = x \times \text{softplus}(\alpha) + \beta`.
r"""Creates a transformation :math:`f(x) = x \times \text{softplus}(\alpha) + \beta`.
Arguments:
shift: The shift term :math:`\beta`, with shape :math:`(*,)`.
@@ -123,69 +123,8 @@ def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor:
return torch.log(self.scale).expand(x.shape)
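
A minimal usage sketch of the updated class, assuming the constructor takes the shift and scale tensors in the order listed in the docstring and that the class is exposed via lampe.distributions:

import torch
from lampe.distributions import MonotonicAffineTransform  # assumed export path

shift = torch.zeros(3)                     # beta
scale = torch.zeros(3)                     # alpha; softplus(0) is about 0.69, so the slope stays positive
t = MonotonicAffineTransform(shift, scale)

x = torch.randn(3)
y = t(x)                                   # y = x * softplus(alpha) + beta
print(t.log_abs_det_jacobian(x, y))        # log softplus(alpha), independent of x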


class MonotonicSOSTransform(Transform):
r"""Transform via a sum of sigmoids
.. math:: f(x) = x + \sum_{i = 1}^K \gamma_i \sigma(x \times \alpha_i + \beta_i)
Arguments:
shift: The shift terms :math:`\beta_i`, with shape :math:`(*, K)`.
scale: The unconstrained scale factors :math:`\alpha_i`, with shape :math:`(*, K)`.
amplitude: The unconstrained amplitudes :math:`\gamma_i`, with shape :math:`(*, K)`.
bisect: The number of bisection steps for the inverse transformation.
eps: A numerical stability term.
"""

domain = constraints.real
codomain = constraints.real
bijective = True
sign = +1

def __init__(
self,
shift: Tensor,
scale: Tensor,
amplitude: Tensor,
bisect: int = 16,
eps: float = 1e-3,
**kwargs,
):
super().__init__(**kwargs)

self.shift = shift
self.scale = F.softplus(scale) + eps
self.amplitude = F.softplus(amplitude) + eps

self.bisect = bisect

def _call(self, x: Tensor) -> Tensor:
return x + (self.amplitude * torch.sigmoid(x[..., None] * self.scale + self.shift)).sum(dim=-1)

def _inverse(self, y: Tensor) -> Tensor:
xa, xb = y - self.amplitude.sum(dim=-1), y
ya, yb = self._call(xa), self._call(xb)

for _ in range(self.bisect):
xc = (xa + xb) / 2
yc = self._call(xc)

xa, ya, xb, yb = torch.where(
yc <= y,
torch.stack((xc, yc, xb, yb)),
torch.stack((xa, ya, xc, yc)),
)

return (xa + xb) / 2

def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor:
sigmoids = torch.sigmoid(x[..., None] * self.scale + self.shift)
jacobian = 1 + (self.amplitude * self.scale * sigmoids * (1 - sigmoids)).sum(dim=-1)

return torch.log(jacobian)


class MonotonicRationalQuadraticSplineTransform(Transform):
r"""Transform via a monotonic rational-quadratic spline mapping.
r"""Creates a monotonic rational-quadratic spline transformation.
References:
Neural Spline Flows (Durkan et al., 2019)
@@ -302,6 +241,65 @@ def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor:
return torch.log(jacobian) * mask


class MonotonicTransform(Transform):
r"""Creates a transformation from a monotonic univariate function :math:`f(x)`.
The inverse function :math:`f^{-1}` is approximated using the bisection method.
Wikipedia:
https://en.wikipedia.org/wiki/Bisection_method
Arguments:
f: A monotonic univariate function :math:`f(x)`.
bound: The domain bound :math:`B`.
eps: The numerical tolerance for the inverse transformation.
"""

domain = constraints.real
codomain = constraints.real
bijective = True
sign = +1

def __init__(
self,
f: Callable[[Tensor], Tensor],
bound: float = 1e1,
eps: float = 1e-6,
**kwargs,
):
super().__init__(**kwargs)

self.f = f
self.bound = bound
self.eps = eps

def _call(self, x: Tensor) -> Tensor:
return self.f(self.bound * torch.tanh(x / self.bound))

def _inverse(self, y: Tensor) -> Tensor:
a = torch.full_like(y, -self.bound)
b = torch.full_like(y, self.bound)

for _ in range(int(math.log2(self.bound / self.eps))):
c = (a + b) / 2

mask = self.f(c) < y

a = torch.where(mask, c, a)
b = torch.where(mask, b, c)

x = (a + b) / 2

return self.bound * torch.atanh(x / self.bound)

def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor:
return torch.log(torch.autograd.functional.jacobian(
func=lambda x: self._call(x).sum(),
inputs=x,
create_graph=torch.is_grad_enabled(),
))
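
A round-trip sketch of the new class, assuming it is re-exported by lampe.distributions (the flows module imports it from there): the forward pass squashes its input into (-B, B) with tanh before applying f, and the inverse is recovered by bisection up to the tolerance eps.

import torch
from lampe.distributions import MonotonicTransform  # assumed export path

def f(x):
    return x ** 3 + x                      # any strictly increasing univariate function

t = MonotonicTransform(f)

x = torch.randn(5)
y = t(x)                                   # f(bound * tanh(x / bound))
x_hat = t.inv(y)                           # bisection on [-bound, bound], then atanh
print(torch.allclose(x, x_hat, atol=1e-4))  # True up to the bisection tolerance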


class AutoregressiveTransform(Transform):
r"""Tranform via an autoregressive mapping.
85 changes: 81 additions & 4 deletions lampe/nn/__init__.py
@@ -8,7 +8,7 @@
from typing import *


__all__ = ['MLP', 'ResBlock', 'ResMLP', 'MaskedMLP']
__all__ = ['MLP', 'ResBlock', 'ResMLP', 'MaskedMLP', 'MonotonicMLP']


class Affine(nn.Module):
@@ -258,10 +258,10 @@ def forward(self, x: Tensor) -> Tensor:


class MaskedMLP(MLP):
r"""Creates a masked multi-layer perceptron (MaskedMLP).
r"""Creates a masked multi-layer perceptron.
The resulting MLP is a transformation :math:`y = f(x)` such that the Jacobian
entry :math:`\frac{\partial y_j}{\partial x_i}` is null if :math:`A_{ij} = 0`.
The resulting MLP is a transformation :math:`y = f(x)` whose Jacobian entries
:math:`\frac{\partial y_j}{\partial x_i}` are null if :math:`A_{ij} = 0`.
Arguments:
adjacency: The adjacency matrix :math:`A \in \{0, 1\}^{M \times N}`.
@@ -316,3 +316,80 @@ def __init__(
mask = mask[:, indices]

self[i] = MaskedLinear(adjacency=mask)
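
A hedged sketch of what the masking buys you; the hidden-layer argument (name and position) is an assumption, only the adjacency semantics come from the docstring above:

import torch
from lampe.nn import MaskedMLP

# Strictly lower-triangular adjacency: each output may only depend on preceding inputs.
adjacency = torch.tril(torch.ones(4, 4, dtype=torch.bool), diagonal=-1)
net = MaskedMLP(adjacency, [16, 16])       # hidden feature sizes are an assumption

x = torch.randn(4)
J = torch.autograd.functional.jacobian(net, x)
print(J)                                   # zeros appear exactly where the adjacency forbids a dependency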


class MonotonicLinear(nn.Linear):
r"""Creates a monotonic linear layer.
.. math:: y = x |W|^T + b
Arguments:
args: Positional arguments passed to :class:`torch.nn.Linear`.
kwargs: Keyword arguments passed to :class:`torch.nn.Linear`.
"""

def forward(self, x: Tensor) -> Tensor:
return F.linear(x, self.weight.abs(), self.bias)
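
Since MonotonicLinear only overrides forward, it behaves like a drop-in nn.Linear whose effective weights are |W|; a quick sketch (the class is defined in lampe/nn/__init__.py but not listed in __all__, so the direct import is an assumption):

import torch
from lampe.nn import MonotonicLinear       # not in __all__, direct import assumed to work

layer = MonotonicLinear(2, 1)
x = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
print(layer(x))                            # increasing an input never decreases the output, since |W| >= 0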


class TwoWayELU(nn.ELU):
r"""Creates a layer that splits the input into two groups and applies
:math:`\text{ELU}(x)` to the first and :math:`-\text{ELU}(-x)` to the second.
Arguments:
args: Positional arguments passed to :class:`torch.nn.ELU`.
kwargs: Keyword arguments passed to :class:`torch.nn.ELU`.
"""

def forward(self, x: Tensor) -> Tensor:
x0, x1 = torch.split(x, x.shape[-1] // 2, dim=-1)

return torch.cat((
super().forward(x0),
-super().forward(-x1),
), dim=-1)
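
A small sketch of the two-way activation: its derivative is positive everywhere, and the two halves saturate in opposite directions (again, the direct import is an assumption, the class is not in __all__):

import torch
from lampe.nn import TwoWayELU             # not in __all__, direct import assumed to work

act = TwoWayELU()
x = torch.tensor([-2.0, -1.0, 1.0, 2.0])
print(act(x))                              # first half: ELU(x), range (-1, inf); second half: -ELU(-x), range (-inf, 1)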


class MonotonicMLP(MLP):
r"""Creates a monotonic multi-layer perceptron.
The resulting MLP is a transformation :math:`y = f(x)` whose Jacobian entries
:math:`\frac{\partial y_j}{\partial x_i}` are positive.
Arguments:
args: Positional arguments passed to :class:`MLP`.
kwargs: Keyword arguments passed to :class:`MLP`.
Example:
>>> net = MonotonicMLP(4, 4, [16, 32])
>>> net
MonotonicMLP(
(0): MonotonicLinear(in_features=4, out_features=16, bias=True)
(1): TwoWayELU(alpha=1.0)
(2): MonotonicLinear(in_features=16, out_features=32, bias=True)
(3): TwoWayELU(alpha=1.0)
(4): MonotonicLinear(in_features=32, out_features=4, bias=True)
)
>>> x = torch.randn(4)
>>> torch.autograd.functional.jacobian(net, x).t()
tensor([[0.8742, 0.9439, 0.9759, 1.1040],
[0.8969, 0.9716, 0.9866, 1.1321],
[1.0780, 1.1651, 1.2056, 1.3674],
[0.8596, 0.9400, 0.9502, 1.0916]])
"""

def __init__(
self,
*args,
**kwargs,
):
kwargs['activation'] = 'ELU'
kwargs['batchnorm'] = False

super().__init__(*args, **kwargs)

for i, layer in enumerate(self):
if isinstance(layer, nn.Linear):
layer.__class__ = MonotonicLinear
elif isinstance(layer, nn.ELU):
layer.__class__ = TwoWayELU
87 changes: 85 additions & 2 deletions lampe/nn/flows.py
@@ -8,14 +8,15 @@
from torch import Tensor, Size
from typing import *

from . import MLP, MaskedMLP
from . import MLP, MaskedMLP, MonotonicMLP
from ..distributions import *
from ..utils import broadcast


__all__ = [
'DistributionModule', 'TransformModule', 'FlowModule',
'MaskedAutoregressiveTransform', 'MAF',
'NeuralAutoregressiveTransform', 'NAF',
]


@@ -234,7 +235,6 @@ def __init__(
features: int,
context: int = 0,
transforms: int = 3,
linear: bool = False,
**kwargs,
):
increasing = torch.arange(features)
@@ -253,3 +253,86 @@ def __init__(
base = Buffer(DiagNormal, torch.zeros(features), torch.ones(features))

super().__init__(transforms, base)


class NeuralAutoregressiveTransform(MaskedAutoregressiveTransform):
r"""Creates a neural autoregressive transform.
The monotonic neural network is parameterized by its internal (positive) weights,
which are independent of the context and the features. To modulate its behavior,
it receives as input a signal that is autoregressively dependent on the features
and context.
Arguments:
features: The number of features.
context: The number of context features.
signal: The number of signal features of the monotonic network.
monotone: Keyword arguments passed to :class:`lampe.nn.MonotonicMLP`.
kwargs: Keyword arguments passed to :class:`MaskedAutoregressiveTransform`.
"""

def __init__(
self,
features: int,
context: int = 0,
signal: int = 8,
monotone: Dict[str, Any] = {},
**kwargs,
):
super().__init__(
features=features,
context=context,
univariate=self.univariate,
shapes=[(signal,)],
**kwargs,
)

self.transform = MonotonicMLP(1 + signal, 1, **monotone)

def univariate(self, signal: Tensor) -> MonotonicTransform:
def f(x: Tensor) -> Tensor:
return self.transform(
torch.cat((x[..., None], signal), dim=-1)
).squeeze(-1)

return MonotonicTransform(f)


class NAF(FlowModule):
r"""Creates a neural autoregressive flow (NAF).
References:
Neural Autoregressive Flows
(Huang et al., 2018)
https://arxiv.org/abs/1804.00779
Arguments:
features: The number of features.
context: The number of context features.
transforms: The number of autoregressive transforms.
kwargs: Keyword arguments passed to :class:`NeuralAutoregressiveTransform`.
"""

def __init__(
self,
features: int,
context: int = 0,
transforms: int = 3,
**kwargs,
):
increasing = torch.arange(features)
decreasing = torch.flipud(increasing)

transforms = [
NeuralAutoregressiveTransform(
features=features,
context=context,
order=decreasing if i % 2 else increasing,
**kwargs,
)
for i in range(transforms)
]

base = Buffer(DiagNormal, torch.zeros(features), torch.ones(features))

super().__init__(transforms, base)
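
An end-to-end sketch of the new flow, assuming the FlowModule convention that calling the module on the context returns a conditional distribution with the usual log_prob/sample interface (as with MAF):

import torch
from lampe.nn.flows import NAF

flow = NAF(3, context=5)                   # 3 features, 5 context features
theta = torch.randn(3)
x = torch.randn(5)

log_p = flow(x).log_prob(theta)            # log-density of theta under the flow conditioned on x
sample = flow(x).sample()                  # sampling inverts each monotonic network by bisection,
                                           # so it is approximate and slower than log_prob
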
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@

setuptools.setup(
name='lampe',
version='0.4.2',
version='0.4.3',
packages=setuptools.find_packages(),
description='Likelihood-free AMortized Posterior Estimation with PyTorch',
keywords='parameter inference bayes posterior amortized likelihood ratio mcmc torch',
