From 5b254bf5a9f7652f300656669676bda77d402ac3 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Tue, 4 Jun 2024 10:52:37 +0100 Subject: [PATCH] Flexible scan --- thermox/sampler.py | 64 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 61 insertions(+), 3 deletions(-) diff --git a/thermox/sampler.py b/thermox/sampler.py index 1943d11..47db085 100644 --- a/thermox/sampler.py +++ b/thermox/sampler.py @@ -17,6 +17,7 @@ def sample_identity_diffusion( x0: Array, A: Array | ProcessedDriftMatrix, b: Array, + associative_scan: bool = True, ) -> Array: """Collects samples from the Ornstein-Uhlenbeck process, defined as: @@ -27,8 +28,8 @@ def sample_identity_diffusion( Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2) where T=len(ts). - Uses jax.lax.associative_scan, so will run in time O(log(T) * d^2) on a GPU/TPU with - O(T) cores. + If associative_scan = True then jax.lax.associative_scan is used, so will run in + time O(log(T) * d^2) on a GPU/TPU with O(T) cores. Args: key: Jax PRNGKey. @@ -36,12 +37,67 @@ def sample_identity_diffusion( x0: Initial state of the process. A: Drift matrix (Array or thermox.ProcessedDriftMatrix). b: Drift displacement vector. + associative_scan: If True, uses jax.lax.associative_scan. Returns: Array-like, desired samples. 
shape: (len(ts), ) + x0.shape """ + if associative_scan: + return _sample_identity_diffusion_associative_scan(key, ts, x0, A, b) + else: + return _sample_identity_diffusion_scan(key, ts, x0, A, b) + + +def _sample_identity_diffusion_scan( + key: Array, + ts: Array, + x0: Array, + A: Array | ProcessedDriftMatrix, + b: Array, +) -> Array: + if isinstance(A, Array): + A = preprocess_drift_matrix(A) + + def expm_vp(v, dt): + out = A.eigvecs_inv @ v + out = jnp.exp(-A.eigvals * dt) * out + out = A.eigvecs @ out + return out.real + + def transition_mean(x, dt): + return b + expm_vp(x - b, dt) + def transition_cov_sqrt_vp(v, dt): + diag = ((1 - jnp.exp(-2 * A.sym_eigvals * dt)) / (2 * A.sym_eigvals)) ** 0.5 + out = diag * v + out = A.sym_eigvecs @ out + return out.real + + def next_x(x, dt, tkey): + randv = jax.random.normal(tkey, shape=x.shape) + return transition_mean(x, dt) + transition_cov_sqrt_vp(randv, dt) + + def scan_body(x_and_key, dt): + x, rk = x_and_key + rk, rk_use = jax.random.split(rk) + x = next_x(x, dt, rk_use) + return (x, rk), x + + dts = jnp.diff(ts) + + xs = jax.lax.scan(scan_body, (x0, key), dts)[1] + xs = jnp.concatenate([jnp.expand_dims(x0, axis=0), xs], axis=0) + return xs + + +def _sample_identity_diffusion_associative_scan( + key: Array, + ts: Array, + x0: Array, + A: Array | ProcessedDriftMatrix, + b: Array, +) -> Array: if isinstance(A, Array): A = preprocess_drift_matrix(A) @@ -89,6 +145,7 @@ def sample( A: Array | ProcessedDriftMatrix, b: Array, D: Array | ProcessedDiffusionMatrix, + associative_scan: bool = True, ) -> Array: """Collects samples from the Ornstein-Uhlenbeck process, defined as: @@ -116,6 +173,7 @@ def sample( not thermox.utils.preprocess_drift_matrix. b: Drift displacement vector. D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). + associative_scan: If True, uses jax.lax.associative_scan. Returns: Array-like, desired samples. 
@@ -125,5 +183,5 @@ def sample( y0 = D.sqrt_inv @ x0 b_y = D.sqrt_inv @ b - ys = sample_identity_diffusion(key, ts, y0, A_y, b_y) + ys = sample_identity_diffusion(key, ts, y0, A_y, b_y, associative_scan) return jax.vmap(jnp.matmul, in_axes=(None, 0))(D.sqrt, ys)
# Review notes (commentary only; the patch text above is unchanged).
#
# What this patch does to thermox/sampler.py:
# * Adds a trailing parameter `associative_scan: bool = True` to both
#   `sample_identity_diffusion` and `sample`, and documents it in each
#   docstring. `sample` forwards the flag positionally as the sixth argument
#   of `sample_identity_diffusion`.
# * Turns `sample_identity_diffusion` into a dispatcher: it returns
#   `_sample_identity_diffusion_associative_scan(key, ts, x0, A, b)` when the
#   flag is true and `_sample_identity_diffusion_scan(key, ts, x0, A, b)`
#   otherwise.
# * Inserts the header of `_sample_identity_diffusion_associative_scan`
#   directly above the pre-existing context lines
#   `if isinstance(A, Array): A = preprocess_drift_matrix(A)`. The previous
#   implementation therefore appears to be kept under the new private name;
#   the rest of that body is outside the hunk, so confirm against the file.
# * Adds `_sample_identity_diffusion_scan`, a sequential sampler:
#   - preprocesses `A` with `preprocess_drift_matrix` when it is a raw Array;
#   - iterates over `dts = jnp.diff(ts)` with `jax.lax.scan`, carrying the
#     pair `(x, key)` and splitting the key once per step;
#   - each step returns
#     `transition_mean(x, dt) + transition_cov_sqrt_vp(randv, dt)` with
#     `randv = jax.random.normal(tkey, shape=x.shape)`; both `expm_vp` and
#     `transition_cov_sqrt_vp` keep only the `.real` part of their result;
#   - prepends `x0` to the stacked scan outputs, so the result has one more
#     leading entry than `dts`, matching the documented
#     `(len(ts), ) + x0.shape`.
#
# NOTE(review): `transition_cov_sqrt_vp` divides by `2 * A.sym_eigvals`; a zero
#   entry would give 0/0 there. Presumably excluded by the assumptions on the
#   drift matrix -- confirm.
# NOTE(review): the sequential path splits `key` once per step; how the
#   associative path consumes `key` is not visible in this patch, so the two
#   settings presumably do not yield identical samples for the same key --
#   verify before comparing their outputs in tests.
# NOTE(review): this file has lost its line breaks (the whole patch sits on
#   three physical lines), so it cannot be applied as-is; regenerate it with
#   `git format-patch` rather than repairing it by hand.