Skip to content

Commit

Permalink
Merge pull request #25 from normal-computing/fix_solve_processed_inputs
Browse files Browse the repository at this point in the history
Fix how sample function handles Processed inputs
  • Loading branch information
SamDuffield authored May 28, 2024
2 parents a5b9638 + fee0aac commit 2f395b0
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 39 deletions.
22 changes: 22 additions & 0 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import jax
from jax import numpy as jnp

import thermox


def test_sample_array_input():
key = jax.random.PRNGKey(0)
dim = 2
dt = 0.1
ts = jnp.arange(0, 10_000, dt)

A = jnp.array([[3, 2], [2, 4.0]])
b, x0 = jnp.zeros(dim), jnp.zeros(dim)
D = 2 * jnp.eye(dim)

samples = thermox.sample(key, ts, x0, A, b, D)

samp_cov = jnp.cov(samples.T)
samp_mean = jnp.mean(samples.T, axis=1)
assert jnp.allclose(A @ samp_cov, jnp.eye(2), atol=1e-1)
assert jnp.allclose(samp_mean, b, atol=1e-1)
60 changes: 60 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from jax import numpy as jnp

from thermox.utils import (
handle_matrix_inputs,
ProcessedDriftMatrix,
ProcessedDiffusionMatrix,
preprocess,
)


def test_handle_matrix_inputs_arrays():
A = jnp.array([[1, 3], [1, 4]])
D = jnp.array([[9, 4], [4, 20]])

a, d = preprocess(A, D)

A_star, D_star = preprocess(A, D)

assert isinstance(A_star, ProcessedDriftMatrix)
assert isinstance(D_star, ProcessedDiffusionMatrix)
assert jnp.all(a.val == A_star.val)


def test_handle_matrix_inputs_processed():
A = jnp.array([[1, 3], [1, 4]])
D = jnp.array([[9, 4], [4, 20]])

a, d = preprocess(A, D)

A_star, D_star = handle_matrix_inputs(a, d)

assert isinstance(A_star, ProcessedDriftMatrix)
assert isinstance(D_star, ProcessedDiffusionMatrix)
assert jnp.all(a.val == A_star.val)


def test_handle_matrix_inputs_array_drift_processed_diffusion():
A = jnp.array([[1, 3], [1, 4]])
D = jnp.array([[9, 4], [4, 20]])

a, d = preprocess(A, D)

A_star, D_star = handle_matrix_inputs(A, d)

assert isinstance(A_star, ProcessedDriftMatrix)
assert isinstance(D_star, ProcessedDiffusionMatrix)
assert jnp.all(a.val == A_star.val)


def test_handle_matrix_inputs_array_diffusion_processed_drift():
A = jnp.array([[1, 3], [1, 4]])
D = jnp.array([[9, 4], [4, 20]])

a, d = preprocess(A, D)

A_star, D_star = handle_matrix_inputs(a, D)

assert isinstance(A_star, ProcessedDriftMatrix)
assert isinstance(D_star, ProcessedDiffusionMatrix)
assert not jnp.all(a.val == A_star.val)
37 changes: 20 additions & 17 deletions thermox/prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from jax import Array, vmap

from thermox.utils import (
preprocess,
handle_matrix_inputs,
preprocess_drift_matrix,
ProcessedDriftMatrix,
ProcessedDiffusionMatrix,
Expand All @@ -28,10 +28,10 @@ def log_prob_identity_diffusion(
Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs O(T * d^2).
Args:
ts: array-like, times at which samples are collected. Includes time for x0.
xs: initial state of the process.
A: drift matrix (Array or thermox.ProcessedDriftMatrix).
b: drift displacement vector.
ts: Times at which samples are collected. Includes time for x0.
xs: Initial state of the process.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
b: Drift displacement vector.
Returns:
Scalar log probability of given xs.
"""
Expand Down Expand Up @@ -94,24 +94,27 @@ def log_prob(
Assumes x(t_0) is given deterministically.
Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs O(T * d^2).
Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs O(T * d^2),
where T=len(ts).
By default, this function does the preprocessing on A and D before the evaluation.
However, the preprocessing can be done externally using thermox.preprocess
the output of which can be used as A and D here, this will skip the preprocessing.
Args:
ts: array-like, times at which samples are collected. Includes time for x0.
xs: initial state of the process.
A: drift matrix (Array or thermox.ProcessedDriftMatrix).
b: drift displacement vector.
D: diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
ts: Times at which samples are collected. Includes time for x0.
xs: Initial state of the process.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
Note : If a thermox.ProcessedDriftMatrix instance is used as input,
must be the transformed drift matrix, A_y, given by thermox.preprocess,
not thermox.utils.preprocess_drift_matrix.
b: Drift displacement vector.
D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
Returns:
Scalar log probability of given xs.
"""
if isinstance(A, Array) or isinstance(D, Array):
if isinstance(A, ProcessedDriftMatrix):
A = A.val
if isinstance(D, ProcessedDiffusionMatrix):
D = D.val
A_y, D = preprocess(A, D)
A_y, D = handle_matrix_inputs(A, D)

ys = vmap(jnp.matmul, in_axes=(None, 0))(D.sqrt_inv, xs)
b_y = D.sqrt_inv @ b
Expand Down
39 changes: 21 additions & 18 deletions thermox/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jax import Array

from thermox.utils import (
preprocess,
handle_matrix_inputs,
preprocess_drift_matrix,
ProcessedDriftMatrix,
ProcessedDiffusionMatrix,
Expand All @@ -28,11 +28,11 @@ def sample_identity_diffusion(
where T=len(ts).
Args:
key: jax PRNGKey.
ts: array-like, times at which samples are collected. Includes time for x0.
x0: initial state of the process.
A: drift matrix (Array or thermox.ProcessedDriftMatrix).
b: drift displacement vector.
key: Jax PRNGKey.
ts: Times at which samples are collected. Includes time for x0.
x0: Initial state of the process.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
b: Drift displacement vector.
Returns:
Array-like, desired samples.
Expand Down Expand Up @@ -88,26 +88,29 @@ def sample(
by using exact diagonalization.
Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2)
Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2),
where T=len(ts).
By default, this function does the preprocessing on A and D before the evaluation.
However, the preprocessing can be done externally using thermox.preprocess
the output of which can be used as A and D here, this will skip the preprocessing.
Args:
key: jax PRNGKey.
ts: array-like, times at which samples are collected. Includes time for x0.
x0: initial state of the process.
A: drift matrix (Array or thermox.ProcessedDriftMatrix).
b: drift displacement vector.
D: diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
key: Jax PRNGKey.
ts: Times at which samples are collected. Includes time for x0.
x0: Initial state of the process.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
Note : If a thermox.ProcessedDriftMatrix instance is used as input,
must be the transformed drift matrix, A_y, given by thermox.preprocess,
not thermox.utils.preprocess_drift_matrix.
b: Drift displacement vector.
D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
Returns:
Array-like, desired samples.
shape: (len(ts), ) + x0.shape
"""
if isinstance(A, Array) and isinstance(D, Array):
A_y, D = preprocess(A, D)

assert isinstance(A_y, ProcessedDriftMatrix)
assert isinstance(D, ProcessedDiffusionMatrix)
A_y, D = handle_matrix_inputs(A, D)

y0 = D.sqrt_inv @ x0
b_y = D.sqrt_inv @ b
Expand Down
33 changes: 29 additions & 4 deletions thermox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def preprocess_drift_matrix(A: Array) -> ProcessedDriftMatrix:
"""Preprocesses matrix A (calculates eigendecompositions of A and (A+A^T)/2)
Args:
A: drift matrix.
A: Drift matrix.
Returns:
ProcessedDriftMatrix containing eigendeomcomposition of A and (A+A^T)/2.
Expand Down Expand Up @@ -59,7 +59,7 @@ def preprocess_diffusion_matrix(D: Array) -> ProcessedDiffusionMatrix:
"""Preprocesses diffusion matrix D (calculates D^0.5 and D^-0.5 via Cholesky)
Args:
D: diffusion matrix.
D: Diffusion matrix.
Returns:
ProcessedDiffusionMatrix containing D^0.5 and D^-0.5.
Expand All @@ -77,8 +77,8 @@ def preprocess(
D^0.5 and D^-0.5)
Args:
A: drift matrix.
D: diffusion matrix.
A: Drift matrix.
D: Diffusion matrix.
Returns:
ProcessedDriftMatrix containing eigendecomposition of A_y and (A_y+A_y^T)/2.
Expand All @@ -89,3 +89,28 @@ def preprocess(
A_y = PD.sqrt_inv @ A @ PD.sqrt
PA_y = preprocess_drift_matrix(A_y)
return PA_y, PD


def handle_matrix_inputs(
A: Array | ProcessedDriftMatrix, D: Array | ProcessedDiffusionMatrix
) -> Tuple[ProcessedDriftMatrix, ProcessedDiffusionMatrix]:
"""Checks the type of the input drift matrix, A, and diffusion matrix, D,
and ensures that they are processed in the correct way.
Helper function for sample and log_prob functions.
Args:
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
Returns:
ProcessedDriftMatrix containing eigendecomposition of A_y and (A_y+A_y^T)/2.
where A_y = D^-0.5 @ A @ D^0.5
ProcessedDiffusionMatrix containing D^0.5 and D^-0.5.
"""
if isinstance(A, Array) or isinstance(D, Array):
if isinstance(A, ProcessedDriftMatrix):
A = A.val
if isinstance(D, ProcessedDiffusionMatrix):
D = D.val
A, D = preprocess(A, D)
return A, D

0 comments on commit 2f395b0

Please sign in to comment.