diff --git a/tests/test_utils.py b/tests/test_utils.py index b10dbfe..25f4673 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -14,7 +14,7 @@ def test_handle_matrix_inputs_arrays(): a, d = preprocess(A, D) - A_star, D_star = handle_matrix_inputs(A, D) + A_star, D_star = preprocess(A, D) assert isinstance(A_star, ProcessedDriftMatrix) assert isinstance(D_star, ProcessedDiffusionMatrix) diff --git a/thermox/prob.py b/thermox/prob.py index 7dc08f4..9f28418 100644 --- a/thermox/prob.py +++ b/thermox/prob.py @@ -107,7 +107,7 @@ def log_prob( A: Drift matrix (Array or thermox.ProcessedDriftMatrix). Note : If a thermox.ProcessedDriftMatrix instance is used as input, must be the transformed drift matrix, A_y, given by thermox.preprocess, - not thermox.utils._preprocess_drift_matrix. + not thermox.utils.preprocess_drift_matrix. b: Drift displacement vector. D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). diff --git a/thermox/sampler.py b/thermox/sampler.py index af20aa7..c4faa96 100644 --- a/thermox/sampler.py +++ b/thermox/sampler.py @@ -102,7 +102,7 @@ def sample( A: Drift matrix (Array or thermox.ProcessedDriftMatrix). Note : If a thermox.ProcessedDriftMatrix instance is used as input, must be the transformed drift matrix, A_y, given by thermox.preprocess, - not thermox.utils._preprocess_drift_matrix. + not thermox.utils.preprocess_drift_matrix. b: Drift displacement vector. D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). diff --git a/thermox/utils.py b/thermox/utils.py index eb89b4c..ffb4c08 100644 --- a/thermox/utils.py +++ b/thermox/utils.py @@ -99,8 +99,8 @@ def handle_matrix_inputs( Helper function for sample and log_prob functions. Args: - A: Drift matrix. - D: Diffusion matrix. + A: Drift matrix (Array or thermox.ProcessedDriftMatrix). + D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix). Returns: ProcessedDriftMatrix containing eigendecomposition of A_y and (A_y+A_y^T)/2.