From 3340a50892725f32c42b0efd29a4a9e411d09d61 Mon Sep 17 00:00:00 2001 From: denismelanson <59967315+denismelanson@users.noreply.github.com> Date: Tue, 14 May 2024 20:34:48 +0000 Subject: [PATCH 1/8] Fix how sample function handles Processed inputs --- thermox/sampler.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/thermox/sampler.py b/thermox/sampler.py index 54feeba..0371c29 100644 --- a/thermox/sampler.py +++ b/thermox/sampler.py @@ -103,12 +103,13 @@ def sample( Array-like, desired samples. shape: (len(ts), ) + x0.shape """ - if isinstance(A, Array) and isinstance(D, Array): + if isinstance(A, Array) or isinstance(D, Array): + if isinstance(A, ProcessedDriftMatrix): + A = A.val + if isinstance(D, ProcessedDiffusionMatrix): + D = D.val A_y, D = preprocess(A, D) - assert isinstance(A_y, ProcessedDriftMatrix) - assert isinstance(D, ProcessedDiffusionMatrix) - y0 = D.sqrt_inv @ x0 b_y = D.sqrt_inv @ b ys = sample_identity_diffusion(key, ts, y0, A_y, b_y) From 1bca838338ad4e72b15484c8e0d196953ac4229b Mon Sep 17 00:00:00 2001 From: denismelanson <59967315+denismelanson@users.noreply.github.com> Date: Wed, 15 May 2024 17:40:54 +0000 Subject: [PATCH 2/8] Fix how sample handles processed input --- tests/test_sampler.py | 41 +++++++++++++++++++++++++++++++++++++++++ thermox/sampler.py | 2 ++ 2 files changed, 43 insertions(+) create mode 100644 tests/test_sampler.py diff --git a/tests/test_sampler.py b/tests/test_sampler.py new file mode 100644 index 0000000..df6709d --- /dev/null +++ b/tests/test_sampler.py @@ -0,0 +1,41 @@ +import jax +from jax import numpy as jnp + +import thermox + + +def test_sample_array_input(): + key = jax.random.PRNGKey(0) + dim = 3 + dt = 0.1 + ts = jnp.arange(0, 10_000, dt) + + A = jnp.array([[3, 2], [2, 4.0]]) + b, x0 = jnp.zeros(dim), jnp.zeros(dim) + D = 2 * jnp.eye(dim) + + samples = thermox.sample(key, ts, x0, A, b, D) + + samp_cov = jnp.cov(samples.T) + samp_mean = jnp.mean(samples.T, axis=1) + assert jnp.allclose(A @ samp_cov, jnp.eye(2), atol=1e-1) + assert jnp.allclose(samp_mean, b, atol=1e-1) + +def test_sample_processed_input(): + key = jax.random.PRNGKey(0) + dim = 3 + dt = 0.1 + ts = jnp.arange(0, 10_000, dt) + + A = jnp.array([[3, 2], [2, 4.0]]) + b, x0 = jnp.zeros(dim), jnp.zeros(dim) + D = 2 * jnp.eye(dim) + + A, D = thermox.utils.preprocess(A, D) + + samples = thermox.sample(key, ts, x0, A, b, D) + + samp_cov = jnp.cov(samples.T) + samp_mean = jnp.mean(samples.T, axis=1) + assert jnp.allclose(A @ samp_cov, jnp.eye(2), atol=1e-1) + assert jnp.allclose(samp_mean, b, atol=1e-1) diff --git a/thermox/sampler.py b/thermox/sampler.py index 0371c29..ff37123 100644 --- a/thermox/sampler.py +++ b/thermox/sampler.py @@ -109,6 +109,8 @@ def sample( if isinstance(D, ProcessedDiffusionMatrix): D = D.val A_y, D = preprocess(A, D) + else: + A_y = A y0 = D.sqrt_inv @ x0 b_y = D.sqrt_inv @ b From 6359a18253082e6493c026da5b1085d8a33f9612 Mon Sep 17 00:00:00 2001 From: denismelanson <59967315+denismelanson@users.noreply.github.com> Date: Wed, 15 May 2024 17:47:34 +0000 Subject: [PATCH 3/8] Fix style --- tests/test_sampler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index df6709d..58ecfcf 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -21,6 +21,7 @@ def test_sample_array_input(): assert jnp.allclose(A @ samp_cov, jnp.eye(2), atol=1e-1) assert jnp.allclose(samp_mean, b, atol=1e-1) + def test_sample_processed_input(): key = jax.random.PRNGKey(0) dim = 3 From 5afaed87ded5a2e2b4b107b0a78b3f67dd6cde4e Mon Sep 17 00:00:00 2001 From: denismelanson <59967315+denismelanson@users.noreply.github.com> Date: Wed, 15 May 2024 18:36:24 +0000 Subject: [PATCH 4/8] Add sampler tests --- tests/test_sampler.py | 48 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 58ecfcf..0a43d6e 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -6,7 +6,7 @@ def test_sample_array_input(): key = jax.random.PRNGKey(0) - dim = 3 + dim = 2 dt = 0.1 ts = jnp.arange(0, 10_000, dt) @@ -24,7 +24,7 @@ def test_sample_array_input(): def test_sample_processed_input(): key = jax.random.PRNGKey(0) - dim = 3 + dim = 2 dt = 0.1 ts = jnp.arange(0, 10_000, dt) @@ -32,9 +32,49 @@ def test_sample_processed_input(): b, x0 = jnp.zeros(dim), jnp.zeros(dim) D = 2 * jnp.eye(dim) - A, D = thermox.utils.preprocess(A, D) + A_proc, D_proc = thermox.utils.preprocess(A, D) - samples = thermox.sample(key, ts, x0, A, b, D) + samples = thermox.sample(key, ts, x0, A_proc, b, D_proc) + + samp_cov = jnp.cov(samples.T) + samp_mean = jnp.mean(samples.T, axis=1) + assert jnp.allclose(A @ samp_cov, jnp.eye(2), atol=1e-1) + assert jnp.allclose(samp_mean, b, atol=1e-1) + + +def test_sample_processed_drift_array_diffusion_input(): + key = jax.random.PRNGKey(0) + dim = 2 + dt = 0.1 + ts = jnp.arange(0, 10_000, dt) + + A = jnp.array([[3, 2], [2, 4.0]]) + b, x0 = jnp.zeros(dim), jnp.zeros(dim) + D = 2 * jnp.eye(dim) + + A_proc, D_proc = thermox.utils.preprocess(A, D) + + samples = thermox.sample(key, ts, x0, A_proc, b, D) + + samp_cov = jnp.cov(samples.T) + samp_mean = jnp.mean(samples.T, axis=1) + assert jnp.allclose(A @ samp_cov, jnp.eye(2), atol=1e-1) + assert jnp.allclose(samp_mean, b, atol=1e-1) + + +def test_sample_array_drift_processed_diffusion_input(): + key = jax.random.PRNGKey(0) + dim = 2 + dt = 0.1 + ts = jnp.arange(0, 10_000, dt) + + A = jnp.array([[3, 2], [2, 4.0]]) + b, x0 = jnp.zeros(dim), jnp.zeros(dim) + D = 2 * jnp.eye(dim) + + A_proc, D_proc = thermox.utils.preprocess(A, D) + + samples = thermox.sample(key, ts, x0, A, b, D_proc) samp_cov = jnp.cov(samples.T) samp_mean = jnp.mean(samples.T, axis=1) From 6d46b4c005ed3503bfdfdec5f6c56dcebf17c64e Mon Sep 17 00:00:00 2001 From: denismelanson <59967315+denismelanson@users.noreply.github.com> Date: Wed, 22 May 2024 14:42:55 +0000 Subject: [PATCH 5/8] Add function to handle array inputs --- tests/test_sampler.py | 60 ----------------------------------------- tests/test_utils.py | 63 +++++++++++++++++++++++++++++++++++++++++++ thermox/prob.py | 12 ++++----- thermox/sampler.py | 36 +++++++++++-------------- thermox/utils.py | 33 ++++++++++++++++++++--- 5 files changed, 113 insertions(+), 91 deletions(-) create mode 100644 tests/test_utils.py diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 0a43d6e..14fb46f 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -20,63 +20,3 @@ def test_sample_array_input(): samp_mean = jnp.mean(samples.T, axis=1) assert jnp.allclose(A @ samp_cov, jnp.eye(2), atol=1e-1) assert jnp.allclose(samp_mean, b, atol=1e-1) - - -def test_sample_processed_input(): - key = jax.random.PRNGKey(0) - dim = 2 - dt = 0.1 - ts = jnp.arange(0, 10_000, dt) - - A = jnp.array([[3, 2], [2, 4.0]]) - b, x0 = jnp.zeros(dim), jnp.zeros(dim) - D = 2 * jnp.eye(dim) - - A_proc, D_proc = thermox.utils.preprocess(A, D) - - samples = thermox.sample(key, ts, x0, A_proc, b, D_proc) - - samp_cov = jnp.cov(samples.T) - samp_mean = jnp.mean(samples.T, axis=1) - assert jnp.allclose(A @ samp_cov, jnp.eye(2), atol=1e-1) - assert jnp.allclose(samp_mean, b, atol=1e-1) - - -def test_sample_processed_drift_array_diffusion_input(): - key = jax.random.PRNGKey(0) - dim = 2 - dt = 0.1 - ts = jnp.arange(0, 10_000, dt) - - A = jnp.array([[3, 2], [2, 4.0]]) - b, x0 = jnp.zeros(dim), jnp.zeros(dim) - D = 2 * jnp.eye(dim) - - A_proc, D_proc = thermox.utils.preprocess(A, D) - - samples = thermox.sample(key, ts, x0, A_proc, b, D) - - samp_cov = jnp.cov(samples.T) - samp_mean = jnp.mean(samples.T, axis=1) - assert jnp.allclose(A @ samp_cov, jnp.eye(2), atol=1e-1) - assert jnp.allclose(samp_mean, b, atol=1e-1) - - -def test_sample_array_drift_processed_diffusion_input(): - key = jax.random.PRNGKey(0) - dim = 2 - dt = 0.1 - ts = jnp.arange(0, 10_000, dt) - - A = jnp.array([[3, 2], [2, 4.0]]) - b, x0 = jnp.zeros(dim), jnp.zeros(dim) - D = 2 * jnp.eye(dim) - - A_proc, D_proc = thermox.utils.preprocess(A, D) - - samples = thermox.sample(key, ts, x0, A, b, D_proc) - - samp_cov = jnp.cov(samples.T) - samp_mean = jnp.mean(samples.T, axis=1) - assert jnp.allclose(A @ samp_cov, jnp.eye(2), atol=1e-1) - assert jnp.allclose(samp_mean, b, atol=1e-1) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..8f3d5ac --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,63 @@ +import jax +from jax import numpy as jnp + +from thermox.utils import ( + handle_matrix_inputs, + ProcessedDriftMatrix, + ProcessedDiffusionMatrix, + preprocess, + preprocess_diffusion_matrix, + preprocess_drift_matrix +) + + +def test_handle_matrix_inputs_arrays(): + A = jnp.array([[1, 3], [1, 4]]) + D = jnp.array([[9, 4], [4, 20]]) + + a, d = preprocess(A, D) + + A_star, D_star = handle_matrix_inputs(A, D) + + assert isinstance(A_star, ProcessedDriftMatrix) + assert isinstance(D_star, ProcessedDiffusionMatrix) + assert jnp.all(a.val == A_star.val) + + +def test_handle_matrix_inputs_processed(): + A = jnp.array([[1, 3], [1, 4]]) + D = jnp.array([[9, 4], [4, 20]]) + + a, d = preprocess(A, D) + + A_star, D_star = handle_matrix_inputs(a, d) + + assert isinstance(A_star, ProcessedDriftMatrix) + assert isinstance(D_star, ProcessedDiffusionMatrix) + assert jnp.all(a.val == A_star.val) + + +def test_handle_matrix_inputs_array_drift_processed_diffusion(): + A = jnp.array([[1, 3], [1, 4]]) + D = jnp.array([[9, 4], [4, 20]]) + + a, d = preprocess(A, D) + + A_star, D_star = handle_matrix_inputs(A, d) + + assert isinstance(A_star, ProcessedDriftMatrix) + assert isinstance(D_star, ProcessedDiffusionMatrix) + assert jnp.all(a.val == A_star.val) + + +def test_handle_matrix_inputs_array_diffusion_processed_drift(): + A = jnp.array([[1, 3], [1, 4]]) + D = jnp.array([[9, 4], [4, 20]]) + + a, d = preprocess(A, D) + + A_star, D_star = handle_matrix_inputs(a, D) + + assert isinstance(A_star, ProcessedDriftMatrix) + assert isinstance(D_star, ProcessedDiffusionMatrix) + assert not jnp.all(a.val == A_star.val) diff --git a/thermox/prob.py b/thermox/prob.py index f0ea53c..7e2f7ce 100644 --- a/thermox/prob.py +++ b/thermox/prob.py @@ -3,7 +3,7 @@ from jax import Array, vmap from thermox.utils import ( - preprocess, + handle_matrix_inputs, preprocess_drift_matrix, ProcessedDriftMatrix, ProcessedDiffusionMatrix, @@ -100,18 +100,16 @@ def log_prob( ts: array-like, times at which samples are collected. Includes time for x0. xs: initial state of the process. A: drift matrix (Array or thermox.ProcessedDriftMatrix). + Note : If a thermox.ProcessedDriftMatrix instance is used as input, + must be transformed drift matrix, A_y, given by thermox.preprocess, + not thermox.preprocess_drift_matrix. b: drift displacement vector. D: diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). Returns: Scalar log probability of given xs. """ - if isinstance(A, Array) or isinstance(D, Array): - if isinstance(A, ProcessedDriftMatrix): - A = A.val - if isinstance(D, ProcessedDiffusionMatrix): - D = D.val - A_y, D = preprocess(A, D) + A_y, D = handle_matrix_inputs(A, D) ys = vmap(jnp.matmul, in_axes=(None, 0))(D.sqrt_inv, xs) b_y = D.sqrt_inv @ b diff --git a/thermox/sampler.py b/thermox/sampler.py index ff37123..78499d3 100644 --- a/thermox/sampler.py +++ b/thermox/sampler.py @@ -4,7 +4,7 @@ from jax import Array from thermox.utils import ( - preprocess, + handle_matrix_inputs, preprocess_drift_matrix, ProcessedDriftMatrix, ProcessedDiffusionMatrix, @@ -28,11 +28,11 @@ def sample_identity_diffusion( where T=len(ts). Args: - key: jax PRNGKey. - ts: array-like, times at which samples are collected. Includes time for x0. - x0: initial state of the process. - A: drift matrix (Array or thermox.ProcessedDriftMatrix). - b: drift displacement vector. + key: Jax PRNGKey. + ts: Times at which samples are collected. Includes time for x0. + x0: Initial state of the process. + A: Drift matrix (Array or thermox.ProcessedDriftMatrix). + b: Drift displacement vector. Returns: Array-like, desired samples. @@ -92,25 +92,21 @@ def sample( where T=len(ts). Args: - key: jax PRNGKey. - ts: array-like, times at which samples are collected. Includes time for x0. - x0: initial state of the process. - A: drift matrix (Array or thermox.ProcessedDriftMatrix). - b: drift displacement vector. - D: diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). + key: Jax PRNGKey. + ts: Times at which samples are collected. Includes time for x0. + x0: Initial state of the process. + A: Drift matrix (Array or thermox.ProcessedDriftMatrix). + Note : If a thermox.ProcessedDriftMatrix instance is used as input, + must be transformed drift matrix, A_y, given by thermox.preprocess, + not thermox.preprocess_drift_matrix. + b: Drift displacement vector. + D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). Returns: Array-like, desired samples. shape: (len(ts), ) + x0.shape """ - if isinstance(A, Array) or isinstance(D, Array): - if isinstance(A, ProcessedDriftMatrix): - A = A.val - if isinstance(D, ProcessedDiffusionMatrix): - D = D.val - A_y, D = preprocess(A, D) - else: - A_y = A + A_y, D = handle_matrix_inputs(A, D) y0 = D.sqrt_inv @ x0 b_y = D.sqrt_inv @ b diff --git a/thermox/utils.py b/thermox/utils.py index 38ac601..eb89b4c 100644 --- a/thermox/utils.py +++ b/thermox/utils.py @@ -21,7 +21,7 @@ def preprocess_drift_matrix(A: Array) -> ProcessedDriftMatrix: """Preprocesses matrix A (calculates eigendecompositions of A and (A+A^T)/2) Args: - A: drift matrix. + A: Drift matrix. Returns: ProcessedDriftMatrix containing eigendeomcomposition of A and (A+A^T)/2. @@ -59,7 +59,7 @@ def preprocess_diffusion_matrix(D: Array) -> ProcessedDiffusionMatrix: """Preprocesses diffusion matrix D (calculates D^0.5 and D^-0.5 via Cholesky) Args: - D: diffusion matrix. + D: Diffusion matrix. Returns: ProcessedDiffusionMatrix containing D^0.5 and D^-0.5. @@ -77,8 +77,8 @@ def preprocess( D^0.5 and D^-0.5) Args: - A: drift matrix. - D: diffusion matrix. + A: Drift matrix. + D: Diffusion matrix. Returns: ProcessedDriftMatrix containing eigendecomposition of A_y and (A_y+A_y^T)/2. @@ -89,3 +89,28 @@ def preprocess( A_y = PD.sqrt_inv @ A @ PD.sqrt PA_y = preprocess_drift_matrix(A_y) return PA_y, PD + + +def handle_matrix_inputs( + A: Array | ProcessedDriftMatrix, D: Array | ProcessedDiffusionMatrix +) -> Tuple[ProcessedDriftMatrix, ProcessedDiffusionMatrix]: + """Checks the type of the input drift matrix, A, and diffusion matrix, D, + and ensures that they are processed in the correct way. + Helper function for sample and log_prob functions. + + Args: + A: Drift matrix. + D: Diffusion matrix. + + Returns: + ProcessedDriftMatrix containing eigendecomposition of A_y and (A_y+A_y^T)/2. + where A_y = D^-0.5 @ A @ D^0.5 + ProcessedDiffusionMatrix containing D^0.5 and D^-0.5. + """ + if isinstance(A, Array) or isinstance(D, Array): + if isinstance(A, ProcessedDriftMatrix): + A = A.val + if isinstance(D, ProcessedDiffusionMatrix): + D = D.val + A, D = preprocess(A, D) + return A, D From 6163f834c697456376202541e9b4e9d10352988f Mon Sep 17 00:00:00 2001 From: denismelanson <59967315+denismelanson@users.noreply.github.com> Date: Wed, 22 May 2024 14:46:50 +0000 Subject: [PATCH 6/8] Fix style --- tests/test_utils.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 8f3d5ac..b10dbfe 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,3 @@ -import jax from jax import numpy as jnp from thermox.utils import ( @@ -6,15 +5,13 @@ ProcessedDriftMatrix, ProcessedDiffusionMatrix, preprocess, - preprocess_diffusion_matrix, - preprocess_drift_matrix ) def test_handle_matrix_inputs_arrays(): A = jnp.array([[1, 3], [1, 4]]) D = jnp.array([[9, 4], [4, 20]]) - + a, d = preprocess(A, D) A_star, D_star = handle_matrix_inputs(A, D) @@ -27,7 +24,7 @@ def test_handle_matrix_inputs_arrays(): def test_handle_matrix_inputs_processed(): A = jnp.array([[1, 3], [1, 4]]) D = jnp.array([[9, 4], [4, 20]]) - + a, d = preprocess(A, D) A_star, D_star = handle_matrix_inputs(a, d) @@ -40,7 +37,7 @@ def test_handle_matrix_inputs_processed(): def test_handle_matrix_inputs_array_drift_processed_diffusion(): A = jnp.array([[1, 3], [1, 4]]) D = jnp.array([[9, 4], [4, 20]]) - + a, d = preprocess(A, D) A_star, D_star = handle_matrix_inputs(A, d) @@ -53,7 +50,7 @@ def test_handle_matrix_inputs_array_drift_processed_diffusion(): def test_handle_matrix_inputs_array_diffusion_processed_drift(): A = jnp.array([[1, 3], [1, 4]]) D = jnp.array([[9, 4], [4, 20]]) - + a, d = preprocess(A, D) A_star, D_star = handle_matrix_inputs(a, D) From 1b5fb2ad73c5172a745a249726f06a33419571e1 Mon Sep 17 00:00:00 2001 From: denismelanson <59967315+denismelanson@users.noreply.github.com> Date: Thu, 23 May 2024 15:02:03 +0000 Subject: [PATCH 7/8] Fix sample and log_prob docstring --- thermox/prob.py | 29 +++++++++++++++++------------ thermox/sampler.py | 10 +++++++--- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/thermox/prob.py b/thermox/prob.py index 7e2f7ce..7dc08f4 100644 --- a/thermox/prob.py +++ b/thermox/prob.py @@ -28,10 +28,10 @@ def log_prob_identity_diffusion( Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs O(T * d^2). Args: - ts: array-like, times at which samples are collected. Includes time for x0. - xs: initial state of the process. - A: drift matrix (Array or thermox.ProcessedDriftMatrix). - b: drift displacement vector. + ts: Times at which samples are collected. Includes time for x0. + xs: Initial state of the process. + A: Drift matrix (Array or thermox.ProcessedDriftMatrix). + b: Drift displacement vector. Returns: Scalar log probability of given xs. """ @@ -94,17 +94,22 @@ def log_prob( Assumes x(t_0) is given deterministically. - Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs O(T * d^2). + Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs O(T * d^2), + where T=len(ts). + + By default, this function does the preprocessing on A and D before the evaluation. + However, the preprocessing can be done externally using thermox.preprocess + the output of which can be used as A and D here, this will skip the preprocessing. Args: - ts: array-like, times at which samples are collected. Includes time for x0. - xs: initial state of the process. - A: drift matrix (Array or thermox.ProcessedDriftMatrix). + ts: Times at which samples are collected. Includes time for x0. + xs: Initial state of the process. + A: Drift matrix (Array or thermox.ProcessedDriftMatrix). Note : If a thermox.ProcessedDriftMatrix instance is used as input, - must be transformed drift matrix, A_y, given by thermox.preprocess, - not thermox.preprocess_drift_matrix. - b: drift displacement vector. - D: diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). + must be the transformed drift matrix, A_y, given by thermox.preprocess, + not thermox.utils._preprocess_drift_matrix. + b: Drift displacement vector. + D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). Returns: Scalar log probability of given xs. diff --git a/thermox/sampler.py b/thermox/sampler.py index 78499d3..af20aa7 100644 --- a/thermox/sampler.py +++ b/thermox/sampler.py @@ -88,17 +88,21 @@ def sample( by using exact diagonalization. - Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2) + Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2), where T=len(ts). + By default, this function does the preprocessing on A and D before the evaluation. + However, the preprocessing can be done externally using thermox.preprocess + the output of which can be used as A and D here, this will skip the preprocessing. + Args: key: Jax PRNGKey. ts: Times at which samples are collected. Includes time for x0. x0: Initial state of the process. A: Drift matrix (Array or thermox.ProcessedDriftMatrix). Note : If a thermox.ProcessedDriftMatrix instance is used as input, - must be transformed drift matrix, A_y, given by thermox.preprocess, - not thermox.preprocess_drift_matrix. + must be the transformed drift matrix, A_y, given by thermox.preprocess, + not thermox.utils._preprocess_drift_matrix. b: Drift displacement vector. D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). From fee0aac86016d0fd63ce97bdeaef72aa447f497a Mon Sep 17 00:00:00 2001 From: denismelanson <59967315+denismelanson@users.noreply.github.com> Date: Thu, 23 May 2024 15:59:10 +0000 Subject: [PATCH 8/8] Fix more docstrings --- tests/test_utils.py | 2 +- thermox/prob.py | 2 +- thermox/sampler.py | 2 +- thermox/utils.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index b10dbfe..25f4673 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -14,7 +14,7 @@ def test_handle_matrix_inputs_arrays(): a, d = preprocess(A, D) - A_star, D_star = handle_matrix_inputs(A, D) + A_star, D_star = preprocess(A, D) assert isinstance(A_star, ProcessedDriftMatrix) assert isinstance(D_star, ProcessedDiffusionMatrix) diff --git a/thermox/prob.py b/thermox/prob.py index 7dc08f4..9f28418 100644 --- a/thermox/prob.py +++ b/thermox/prob.py @@ -107,7 +107,7 @@ def log_prob( A: Drift matrix (Array or thermox.ProcessedDriftMatrix). Note : If a thermox.ProcessedDriftMatrix instance is used as input, must be the transformed drift matrix, A_y, given by thermox.preprocess, - not thermox.utils._preprocess_drift_matrix. + not thermox.utils.preprocess_drift_matrix. b: Drift displacement vector. D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). diff --git a/thermox/sampler.py b/thermox/sampler.py index af20aa7..c4faa96 100644 --- a/thermox/sampler.py +++ b/thermox/sampler.py @@ -102,7 +102,7 @@ def sample( A: Drift matrix (Array or thermox.ProcessedDriftMatrix). Note : If a thermox.ProcessedDriftMatrix instance is used as input, must be the transformed drift matrix, A_y, given by thermox.preprocess, - not thermox.utils._preprocess_drift_matrix. + not thermox.utils.preprocess_drift_matrix. b: Drift displacement vector. D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). diff --git a/thermox/utils.py b/thermox/utils.py index eb89b4c..ffb4c08 100644 --- a/thermox/utils.py +++ b/thermox/utils.py @@ -99,8 +99,8 @@ def handle_matrix_inputs( Helper function for sample and log_prob functions. Args: - A: Drift matrix. - D: Diffusion matrix. + A: Drift matrix (Array or thermox.ProcessedDriftMatrix). + D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). Returns: ProcessedDriftMatrix containing eigendecomposition of A_y and (A_y+A_y^T)/2.