Skip to content

Commit

Permalink
Add conditional mean and cov (#39)
Browse files Browse the repository at this point in the history
* Add conditional mean and cov

* Change log_prob test to asymmetric A

* Doc
  • Loading branch information
SamDuffield authored Sep 24, 2024
1 parent 2a43ef6 commit ea286e6
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 12 deletions.
32 changes: 32 additions & 0 deletions tests/test_conditional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import jax
from jax import numpy as jnp

import thermox


def test_mean_and_cov():
jax.config.update("jax_enable_x64", True)
dim = 2
t = 1.0

A = jnp.array([[3, 2.5], [2, 4.0]])
b = jax.random.normal(jax.random.PRNGKey(1), (dim,))
x0 = jax.random.normal(jax.random.PRNGKey(2), (dim,))
D = 2 * jnp.eye(dim)

mean = thermox.conditional.mean(t, x0, A, b, D)
samples = jax.vmap(
lambda k: thermox.sample(k, jnp.array([0.0, t]), x0, A, b, D)[-1]
)(jax.random.split(jax.random.PRNGKey(0), 1000000))
assert mean.shape == (dim,)
assert jnp.allclose(mean, jnp.mean(samples, axis=0), atol=1e-2)

cov = thermox.conditional.covariance(t, A, D)
assert cov.shape == (dim, dim)
assert jnp.allclose(cov, jnp.cov(samples.T), atol=1e-3)

mean_and_cov = thermox.conditional.mean_and_covariance(t, x0, A, b, D)
assert mean_and_cov[0].shape == (dim,)
assert mean_and_cov[1].shape == (dim, dim)
assert jnp.allclose(mean_and_cov[0], mean, atol=1e-5)
assert jnp.allclose(mean_and_cov[1], cov, atol=1e-5)
20 changes: 9 additions & 11 deletions tests/test_log_prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,42 +91,40 @@ def test_MLE():
D_true = jnp.array([[1, 0.3, -0.1], [0.3, 1, 0.2], [-0.1, 0.2, 1.0]])

nts = 300
ts = jnp.linspace(0, 10, nts)
ts = jnp.linspace(0, 100, nts)
x0 = jnp.zeros_like(b_true)

n_trajecs = 3
n_trajecs = 5
rks = jax.random.split(jax.random.PRNGKey(0), n_trajecs)

samps = jax.vmap(lambda key: thermox.sample(key, ts, x0, A_true, b_true, D_true))(
rks
)

A_sqrt_init = jnp.tril(jnp.eye(3) + jax.random.normal(rks[0], (3, 3)) * 1e-1)
A_init = jnp.eye(3) + jax.random.normal(rks[0], (3, 3)) * 1e-1
b_init = jnp.zeros(3)
D_sqrt_init = jnp.eye(3)

log_prob_true = thermox.log_prob(ts, samps[0], A_true, b_true, D_true)
log_prob_init = thermox.log_prob(
ts, samps[0], A_sqrt_init @ A_sqrt_init.T, b_init, D_sqrt_init @ D_sqrt_init.T
ts, samps[0], A_init, b_init, D_sqrt_init @ D_sqrt_init.T
)

assert log_prob_true > log_prob_init

# Gradient descent
def loss(params):
A_sqrt, b, D_sqrt = params
A_sqrt = jnp.tril(A_sqrt)
A, b, D_sqrt = params
D_sqrt = jnp.tril(D_sqrt)
A = A_sqrt @ A_sqrt.T
D = D_sqrt @ D_sqrt.T
return -jax.vmap(lambda s: thermox.log_prob(ts, s, A, b, D))(
samps
).mean() / len(ts)

val_and_g = jax.jit(jax.value_and_grad(loss))

ps = (A_sqrt_init, b_init, D_sqrt_init)
ps_true = (jnp.linalg.cholesky(A_true), b_true, jnp.linalg.cholesky(D_true))
ps = (A_init, b_init, D_sqrt_init)
ps_true = (A_true, b_true, jnp.linalg.cholesky(D_true))

v, g = val_and_g(ps)
v_true, g_true = val_and_g(ps_true)
Expand All @@ -138,7 +136,7 @@ def loss(params):
n_steps = 20000
neg_log_probs = jnp.zeros(n_steps)

optimizer = optax.adam(1e-2)
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(ps)

for i in range(n_steps):
Expand All @@ -149,7 +147,7 @@ def loss(params):
ps = optax.apply_updates(ps, updates)
neg_log_probs = neg_log_probs.at[i].set(neg_log_prob)

A_recover = ps[0] @ ps[0].T
A_recover = ps[0]
b_recover = ps[1]
D_recover = ps[2] @ ps[2].T

Expand Down
1 change: 1 addition & 0 deletions thermox/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from thermox import linalg
from thermox import conditional
from thermox.sampler import sample
from thermox.prob import log_prob
from thermox.utils import preprocess
Expand Down
98 changes: 98 additions & 0 deletions thermox/conditional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from jax import numpy as jnp
from jax import Array

from thermox.utils import (
ProcessedDriftMatrix,
ProcessedDiffusionMatrix,
handle_matrix_inputs,
)
from thermox.sampler import expm_vp


def mean(
t: float,
x0: Array,
A: Array | ProcessedDriftMatrix,
b: Array,
D: Array | ProcessedDiffusionMatrix,
) -> Array:
"""Computes the mean of p(x_t | x_0)
For x_t evolving according to the SDE:
dx = - A * (x - b) dt + sqrt(D) dW
Args:
ts: Times at which samples are collected. Includes time for x0.
x0: Initial state of the process.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
Note: If a thermox.ProcessedDriftMatrix instance is used as input,
must be the transformed drift matrix, A_y, given by thermox.preprocess,
not thermox.utils.preprocess_drift_matrix.
b: Drift displacement vector.
D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
"""
A_y, D = handle_matrix_inputs(A, D)

y0 = D.sqrt_inv @ (x0 - b)
return b + D.sqrt @ expm_vp(A_y, y0, t)


def covariance(
t: float,
A: Array | ProcessedDriftMatrix,
D: Array | ProcessedDiffusionMatrix,
) -> Array:
"""Computes the covariance of p(x_t | x_0)
For x evolving according to the SDE:
dx = - A * (x - b) dt + sqrt(D) dW
Args:
ts: Times at which samples are collected. Includes time for x0.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
Note: If a thermox.ProcessedDriftMatrix instance is used as input,
must be the transformed drift matrix, A_y, given by thermox.preprocess,
not thermox.utils.preprocess_drift_matrix.
D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
"""
A_y, D = handle_matrix_inputs(A, D)

identity_diffusion_cov = (
A_y.sym_eigvecs
@ jnp.diag((1 - jnp.exp(-2 * A_y.sym_eigvals * t)) / (2 * A_y.sym_eigvals))
@ A_y.sym_eigvecs.T
)
return D.sqrt @ identity_diffusion_cov @ D.sqrt.T


def mean_and_covariance(
t: float,
x0: Array,
A: Array | ProcessedDriftMatrix,
b: Array,
D: Array | ProcessedDiffusionMatrix,
) -> tuple[Array, Array]:
"""Computes the mean and covariance of p(x_t | x_0)
For x evolving according to the SDE:
dx = - A * (x - b) dt + sqrt(D) dW
Args:
ts: Times at which samples are collected. Includes time for x0.
x0: Initial state of the process.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
Note: If a thermox.ProcessedDriftMatrix instance is used as input,
must be the transformed drift matrix, A_y, given by thermox.preprocess,
not thermox.utils.preprocess_drift_matrix.
b: Drift displacement vector.
D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
"""
A, D = handle_matrix_inputs(A, D)
mean_val = mean(t, x0, A, b, D)
covariance_val = covariance(t, A, D)
return mean_val, covariance_val
2 changes: 1 addition & 1 deletion thermox/prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def log_prob(
Args:
ts: Times at which samples are collected. Includes time for x0.
xs: Initial state of the process.
xs: States of the process.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
Note: If a thermox.ProcessedDriftMatrix instance is used as input,
must be the transformed drift matrix, A_y, given by thermox.preprocess,
Expand Down

0 comments on commit ea286e6

Please sign in to comment.