Skip to content

Commit

Permalink
Flexible scan
Browse files Browse the repository at this point in the history
  • Loading branch information
SamDuffield committed Jun 4, 2024
1 parent f2a5528 commit 5b254bf
Showing 1 changed file with 61 additions and 3 deletions.
64 changes: 61 additions & 3 deletions thermox/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def sample_identity_diffusion(
x0: Array,
A: Array | ProcessedDriftMatrix,
b: Array,
associative_scan: bool = True,
) -> Array:
"""Collects samples from the Ornstein-Uhlenbeck process, defined as:
Expand All @@ -27,21 +28,76 @@ def sample_identity_diffusion(
Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2)
where T=len(ts).
Uses jax.lax.associative_scan, so will run in time O(log(T) * d^2) on a GPU/TPU with
O(T) cores.
If associative_scan = True then jax.lax.associative_scan is used, so will run in
time O(log(T) * d^2) on a GPU/TPU with O(T) cores.
Args:
key: Jax PRNGKey.
ts: Times at which samples are collected. Includes time for x0.
x0: Initial state of the process.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
b: Drift displacement vector.
associative_scan: If True, uses jax.lax.associative_scan.
Returns:
Array-like, desired samples.
shape: (len(ts), ) + x0.shape
"""
if associative_scan:
return _sample_identity_diffusion_associative_scan(key, ts, x0, A, b)
else:
return _sample_identity_diffusion_scan(key, ts, x0, A, b)


def _sample_identity_diffusion_scan(
key: Array,
ts: Array,
x0: Array,
A: Array | ProcessedDriftMatrix,
b: Array,
) -> Array:
if isinstance(A, Array):
A = preprocess_drift_matrix(A)

def expm_vp(v, dt):
out = A.eigvecs_inv @ v
out = jnp.exp(-A.eigvals * dt) * out
out = A.eigvecs @ out
return out.real

def transition_mean(x, dt):
return b + expm_vp(x - b, dt)

def transition_cov_sqrt_vp(v, dt):
diag = ((1 - jnp.exp(-2 * A.sym_eigvals * dt)) / (2 * A.sym_eigvals)) ** 0.5
out = diag * v
out = A.sym_eigvecs @ out
return out.real

def next_x(x, dt, tkey):
randv = jax.random.normal(tkey, shape=x.shape)
return transition_mean(x, dt) + transition_cov_sqrt_vp(randv, dt)

def scan_body(x_and_key, dt):
x, rk = x_and_key
rk, rk_use = jax.random.split(rk)
x = next_x(x, dt, rk_use)
return (x, rk), x

dts = jnp.diff(ts)

xs = jax.lax.scan(scan_body, (x0, key), dts)[1]
xs = jnp.concatenate([jnp.expand_dims(x0, axis=0), xs], axis=0)
return xs


def _sample_identity_diffusion_associative_scan(
key: Array,
ts: Array,
x0: Array,
A: Array | ProcessedDriftMatrix,
b: Array,
) -> Array:
if isinstance(A, Array):
A = preprocess_drift_matrix(A)

Expand Down Expand Up @@ -89,6 +145,7 @@ def sample(
A: Array | ProcessedDriftMatrix,
b: Array,
D: Array | ProcessedDiffusionMatrix,
associative_scan: bool = True,
) -> Array:
"""Collects samples from the Ornstein-Uhlenbeck process, defined as:
Expand Down Expand Up @@ -116,6 +173,7 @@ def sample(
not thermox.utils.preprocess_drift_matrix.
b: Drift displacement vector.
D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
associative_scan: If True, uses jax.lax.associative_scan.
Returns:
Array-like, desired samples.
Expand All @@ -125,5 +183,5 @@ def sample(

y0 = D.sqrt_inv @ x0
b_y = D.sqrt_inv @ b
ys = sample_identity_diffusion(key, ts, y0, A_y, b_y)
ys = sample_identity_diffusion(key, ts, y0, A_y, b_y, associative_scan)
return jax.vmap(jnp.matmul, in_axes=(None, 0))(D.sqrt, ys)

0 comments on commit 5b254bf

Please sign in to comment.