From 1b5fb2ad73c5172a745a249726f06a33419571e1 Mon Sep 17 00:00:00 2001 From: denismelanson <59967315+denismelanson@users.noreply.github.com> Date: Thu, 23 May 2024 15:02:03 +0000 Subject: [PATCH] Fix sample and log_prob docstring --- thermox/prob.py | 29 +++++++++++++++++------------ thermox/sampler.py | 10 +++++++--- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/thermox/prob.py b/thermox/prob.py index 7e2f7ce..7dc08f4 100644 --- a/thermox/prob.py +++ b/thermox/prob.py @@ -28,10 +28,10 @@ def log_prob_identity_diffusion( Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs O(T * d^2). Args: - ts: array-like, times at which samples are collected. Includes time for x0. - xs: initial state of the process. - A: drift matrix (Array or thermox.ProcessedDriftMatrix). - b: drift displacement vector. + ts: Times at which samples are collected. Includes time for x0. + xs: Initial state of the process. + A: Drift matrix (Array or thermox.ProcessedDriftMatrix). + b: Drift displacement vector. Returns: Scalar log probability of given xs. """ @@ -94,17 +94,22 @@ def log_prob( Assumes x(t_0) is given deterministically. - Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs O(T * d^2). + Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs O(T * d^2), + where T=len(ts). + + By default, this function does the preprocessing on A and D before the evaluation. + However, the preprocessing can be done externally using thermox.preprocess + the output of which can be used as A and D here, this will skip the preprocessing. Args: - ts: array-like, times at which samples are collected. Includes time for x0. - xs: initial state of the process. - A: drift matrix (Array or thermox.ProcessedDriftMatrix). + ts: Times at which samples are collected. Includes time for x0. + xs: Initial state of the process. + A: Drift matrix (Array or thermox.ProcessedDriftMatrix). Note : If a thermox.ProcessedDriftMatrix instance is used as input, - must be transformed drift matrix, A_y, given by thermox.preprocess, - not thermox.preprocess_drift_matrix. - b: drift displacement vector. - D: diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). + must be the transformed drift matrix, A_y, given by thermox.preprocess, + not thermox.utils._preprocess_drift_matrix. + b: Drift displacement vector. + D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). Returns: Scalar log probability of given xs. diff --git a/thermox/sampler.py b/thermox/sampler.py index 78499d3..af20aa7 100644 --- a/thermox/sampler.py +++ b/thermox/sampler.py @@ -88,17 +88,21 @@ def sample( by using exact diagonalization. - Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2) + Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2), where T=len(ts). + By default, this function does the preprocessing on A and D before the evaluation. + However, the preprocessing can be done externally using thermox.preprocess + the output of which can be used as A and D here, this will skip the preprocessing. + Args: key: Jax PRNGKey. ts: Times at which samples are collected. Includes time for x0. x0: Initial state of the process. A: Drift matrix (Array or thermox.ProcessedDriftMatrix). Note : If a thermox.ProcessedDriftMatrix instance is used as input, - must be transformed drift matrix, A_y, given by thermox.preprocess, - not thermox.preprocess_drift_matrix. + must be the transformed drift matrix, A_y, given by thermox.preprocess, + not thermox.utils._preprocess_drift_matrix. b: Drift displacement vector. D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).