diff --git a/tests/test_sampler.py b/tests/test_sampler.py new file mode 100644 index 0000000..14fb46f --- /dev/null +++ b/tests/test_sampler.py @@ -0,0 +1,22 @@ +import jax +from jax import numpy as jnp + +import thermox + + +def test_sample_array_input(): + key = jax.random.PRNGKey(0) + dim = 2 + dt = 0.1 + ts = jnp.arange(0, 10_000, dt) + + A = jnp.array([[3, 2], [2, 4.0]]) + b, x0 = jnp.zeros(dim), jnp.zeros(dim) + D = 2 * jnp.eye(dim) + + samples = thermox.sample(key, ts, x0, A, b, D) + + samp_cov = jnp.cov(samples.T) + samp_mean = jnp.mean(samples.T, axis=1) + assert jnp.allclose(A @ samp_cov, jnp.eye(2), atol=1e-1) + assert jnp.allclose(samp_mean, b, atol=1e-1) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..25f4673 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,60 @@ +from jax import numpy as jnp + +from thermox.utils import ( + handle_matrix_inputs, + ProcessedDriftMatrix, + ProcessedDiffusionMatrix, + preprocess, +) + + +def test_handle_matrix_inputs_arrays(): + A = jnp.array([[1, 3], [1, 4]]) + D = jnp.array([[9, 4], [4, 20]]) + + a, d = preprocess(A, D) + + A_star, D_star = preprocess(A, D) + + assert isinstance(A_star, ProcessedDriftMatrix) + assert isinstance(D_star, ProcessedDiffusionMatrix) + assert jnp.all(a.val == A_star.val) + + +def test_handle_matrix_inputs_processed(): + A = jnp.array([[1, 3], [1, 4]]) + D = jnp.array([[9, 4], [4, 20]]) + + a, d = preprocess(A, D) + + A_star, D_star = handle_matrix_inputs(a, d) + + assert isinstance(A_star, ProcessedDriftMatrix) + assert isinstance(D_star, ProcessedDiffusionMatrix) + assert jnp.all(a.val == A_star.val) + + +def test_handle_matrix_inputs_array_drift_processed_diffusion(): + A = jnp.array([[1, 3], [1, 4]]) + D = jnp.array([[9, 4], [4, 20]]) + + a, d = preprocess(A, D) + + A_star, D_star = handle_matrix_inputs(A, d) + + assert isinstance(A_star, ProcessedDriftMatrix) + assert isinstance(D_star, ProcessedDiffusionMatrix) + assert jnp.all(a.val == A_star.val) + + +def test_handle_matrix_inputs_array_diffusion_processed_drift(): + A = jnp.array([[1, 3], [1, 4]]) + D = jnp.array([[9, 4], [4, 20]]) + + a, d = preprocess(A, D) + + A_star, D_star = handle_matrix_inputs(a, D) + + assert isinstance(A_star, ProcessedDriftMatrix) + assert isinstance(D_star, ProcessedDiffusionMatrix) + assert not jnp.all(a.val == A_star.val) diff --git a/thermox/prob.py b/thermox/prob.py index f0ea53c..9f28418 100644 --- a/thermox/prob.py +++ b/thermox/prob.py @@ -3,7 +3,7 @@ from jax import Array, vmap from thermox.utils import ( - preprocess, + handle_matrix_inputs, preprocess_drift_matrix, ProcessedDriftMatrix, ProcessedDiffusionMatrix, @@ -28,10 +28,10 @@ def log_prob_identity_diffusion( Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs O(T * d^2). Args: - ts: array-like, times at which samples are collected. Includes time for x0. - xs: initial state of the process. - A: drift matrix (Array or thermox.ProcessedDriftMatrix). - b: drift displacement vector. + ts: Times at which samples are collected. Includes time for x0. + xs: Initial state of the process. + A: Drift matrix (Array or thermox.ProcessedDriftMatrix). + b: Drift displacement vector. Returns: Scalar log probability of given xs. """ @@ -94,24 +94,27 @@ def log_prob( Assumes x(t_0) is given deterministically. - Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs O(T * d^2). + Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs O(T * d^2), + where T=len(ts). + + By default, this function does the preprocessing on A and D before the evaluation. + However, the preprocessing can be done externally using thermox.preprocess + the output of which can be used as A and D here, this will skip the preprocessing. Args: - ts: array-like, times at which samples are collected. Includes time for x0. - xs: initial state of the process. - A: drift matrix (Array or thermox.ProcessedDriftMatrix). - b: drift displacement vector. - D: diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). + ts: Times at which samples are collected. Includes time for x0. + xs: Initial state of the process. + A: Drift matrix (Array or thermox.ProcessedDriftMatrix). + Note : If a thermox.ProcessedDriftMatrix instance is used as input, + must be the transformed drift matrix, A_y, given by thermox.preprocess, + not thermox.utils.preprocess_drift_matrix. + b: Drift displacement vector. + D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). Returns: Scalar log probability of given xs. """ - if isinstance(A, Array) or isinstance(D, Array): - if isinstance(A, ProcessedDriftMatrix): - A = A.val - if isinstance(D, ProcessedDiffusionMatrix): - D = D.val - A_y, D = preprocess(A, D) + A_y, D = handle_matrix_inputs(A, D) ys = vmap(jnp.matmul, in_axes=(None, 0))(D.sqrt_inv, xs) b_y = D.sqrt_inv @ b diff --git a/thermox/sampler.py b/thermox/sampler.py index 54feeba..c4faa96 100644 --- a/thermox/sampler.py +++ b/thermox/sampler.py @@ -4,7 +4,7 @@ from jax import Array from thermox.utils import ( - preprocess, + handle_matrix_inputs, preprocess_drift_matrix, ProcessedDriftMatrix, ProcessedDiffusionMatrix, @@ -28,11 +28,11 @@ def sample_identity_diffusion( where T=len(ts). Args: - key: jax PRNGKey. - ts: array-like, times at which samples are collected. Includes time for x0. - x0: initial state of the process. - A: drift matrix (Array or thermox.ProcessedDriftMatrix). - b: drift displacement vector. + key: Jax PRNGKey. + ts: Times at which samples are collected. Includes time for x0. + x0: Initial state of the process. + A: Drift matrix (Array or thermox.ProcessedDriftMatrix). + b: Drift displacement vector. Returns: Array-like, desired samples. @@ -88,26 +88,29 @@ def sample( by using exact diagonalization. - Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2) + Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2), where T=len(ts). + By default, this function does the preprocessing on A and D before the evaluation. + However, the preprocessing can be done externally using thermox.preprocess + the output of which can be used as A and D here, this will skip the preprocessing. + Args: - key: jax PRNGKey. - ts: array-like, times at which samples are collected. Includes time for x0. - x0: initial state of the process. - A: drift matrix (Array or thermox.ProcessedDriftMatrix). - b: drift displacement vector. - D: diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). + key: Jax PRNGKey. + ts: Times at which samples are collected. Includes time for x0. + x0: Initial state of the process. + A: Drift matrix (Array or thermox.ProcessedDriftMatrix). + Note : If a thermox.ProcessedDriftMatrix instance is used as input, + must be the transformed drift matrix, A_y, given by thermox.preprocess, + not thermox.utils.preprocess_drift_matrix. + b: Drift displacement vector. + D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). Returns: Array-like, desired samples. shape: (len(ts), ) + x0.shape """ - if isinstance(A, Array) and isinstance(D, Array): - A_y, D = preprocess(A, D) - - assert isinstance(A_y, ProcessedDriftMatrix) - assert isinstance(D, ProcessedDiffusionMatrix) + A_y, D = handle_matrix_inputs(A, D) y0 = D.sqrt_inv @ x0 b_y = D.sqrt_inv @ b diff --git a/thermox/utils.py b/thermox/utils.py index 38ac601..ffb4c08 100644 --- a/thermox/utils.py +++ b/thermox/utils.py @@ -21,7 +21,7 @@ def preprocess_drift_matrix(A: Array) -> ProcessedDriftMatrix: """Preprocesses matrix A (calculates eigendecompositions of A and (A+A^T)/2) Args: - A: drift matrix. + A: Drift matrix. Returns: ProcessedDriftMatrix containing eigendeomcomposition of A and (A+A^T)/2. @@ -59,7 +59,7 @@ def preprocess_diffusion_matrix(D: Array) -> ProcessedDiffusionMatrix: """Preprocesses diffusion matrix D (calculates D^0.5 and D^-0.5 via Cholesky) Args: - D: diffusion matrix. + D: Diffusion matrix. Returns: ProcessedDiffusionMatrix containing D^0.5 and D^-0.5. @@ -77,8 +77,8 @@ def preprocess( D^0.5 and D^-0.5) Args: - A: drift matrix. - D: diffusion matrix. + A: Drift matrix. + D: Diffusion matrix. Returns: ProcessedDriftMatrix containing eigendecomposition of A_y and (A_y+A_y^T)/2. @@ -89,3 +89,28 @@ def preprocess( A_y = PD.sqrt_inv @ A @ PD.sqrt PA_y = preprocess_drift_matrix(A_y) return PA_y, PD + + +def handle_matrix_inputs( + A: Array | ProcessedDriftMatrix, D: Array | ProcessedDiffusionMatrix +) -> Tuple[ProcessedDriftMatrix, ProcessedDiffusionMatrix]: + """Checks the type of the input drift matrix, A, and diffusion matrix, D, + and ensures that they are processed in the correct way. + Helper function for sample and log_prob functions. + + Args: + A: Drift matrix (Array or thermox.ProcessedDriftMatrix). + D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). + + Returns: + ProcessedDriftMatrix containing eigendecomposition of A_y and (A_y+A_y^T)/2. + where A_y = D^-0.5 @ A @ D^0.5 + ProcessedDiffusionMatrix containing D^0.5 and D^-0.5. + """ + if isinstance(A, Array) or isinstance(D, Array): + if isinstance(A, ProcessedDriftMatrix): + A = A.val + if isinstance(D, ProcessedDiffusionMatrix): + D = D.val + A, D = preprocess(A, D) + return A, D