forked from dfm/emcee
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
noticed issue dfm#440, undid previous commit, and supressed the warning in this instance
- Loading branch information
1 parent
aa8774f
commit 8c8ca12
Showing
1 changed file
with
188 additions
and
185 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,185 +1,188 @@ | ||
""" | ||
Unit tests of some functionality in ensemble.py when the parameters are named | ||
""" | ||
|
||
import string | ||
from unittest import TestCase | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from emcee.ensemble import EnsembleSampler, ndarray_to_list_of_dicts | ||
|
||
|
||
class TestNP2ListOfDicts(TestCase): | ||
def test_ndarray_to_list_of_dicts(self): | ||
# Try different numbers of keys | ||
for n_keys in [1, 2, 10, 26]: | ||
keys = list(string.ascii_lowercase[:n_keys]) | ||
key_set = set(keys) | ||
key_dict = {key: i for i, key in enumerate(keys)} | ||
# Try different number of walker/procs | ||
for N in [1, 2, 3, 10, 100]: | ||
x = np.random.rand(N, n_keys) | ||
|
||
LOD = ndarray_to_list_of_dicts(x, key_dict) | ||
assert len(LOD) == N, "need 1 dict per row" | ||
for i, dct in enumerate(LOD): | ||
assert dct.keys() == key_set, "keys are missing" | ||
for j, key in enumerate(keys): | ||
assert dct[key] == x[i, j], f"wrong value at {(i, j)}" | ||
|
||
|
||
class TestNamedParameters(TestCase): | ||
""" | ||
Test that a keyword-based log-probability function instead of | ||
a positional. | ||
""" | ||
|
||
# Keyword based lnpdf | ||
def lnpdf(self, pars) -> np.float64: | ||
mean = pars["mean"] | ||
var = pars["var"] | ||
if var <= 0: | ||
return -np.inf | ||
return ( | ||
-0.5 * ((mean - self.x) ** 2 / var + np.log(2 * np.pi * var)).sum() | ||
) | ||
|
||
def lnpdf_mixture(self, pars) -> np.float64: | ||
mean1 = pars["mean1"] | ||
var1 = pars["var1"] | ||
mean2 = pars["mean2"] | ||
var2 = pars["var2"] | ||
if var1 <= 0 or var2 <= 0: | ||
return -np.inf | ||
return ( | ||
-0.5 | ||
* ( | ||
(mean1 - self.x) ** 2 / var1 | ||
+ np.log(2 * np.pi * var1) | ||
+ (mean2 - self.x - 3) ** 2 / var2 | ||
+ np.log(2 * np.pi * var2) | ||
).sum() | ||
) | ||
|
||
def lnpdf_mixture_grouped(self, pars) -> np.float64: | ||
mean1, mean2 = pars["means"] | ||
var1, var2 = pars["vars"] | ||
const = pars["constant"] | ||
if var1 <= 0 or var2 <= 0: | ||
return -np.inf | ||
return ( | ||
-0.5 | ||
* ( | ||
(mean1 - self.x) ** 2 / var1 | ||
+ np.log(2 * np.pi * var1) | ||
+ (mean2 - self.x - 3) ** 2 / var2 | ||
+ np.log(2 * np.pi * var2) | ||
).sum() | ||
+ const | ||
) | ||
|
||
def setUp(self): | ||
# Draw some data from a unit Gaussian | ||
self.x = np.random.randn(100) | ||
self.names = ["mean", "var"] | ||
|
||
def test_named_parameters(self): | ||
sampler = EnsembleSampler( | ||
nwalkers=10, | ||
ndim=len(self.names), | ||
log_prob_fn=self.lnpdf, | ||
parameter_names=self.names, | ||
) | ||
assert sampler.params_are_named | ||
assert list(sampler.parameter_names.keys()) == self.names | ||
|
||
def test_asserts(self): | ||
# ndim name mismatch | ||
with pytest.raises(AssertionError): | ||
_ = EnsembleSampler( | ||
nwalkers=10, | ||
ndim=len(self.names) - 1, | ||
log_prob_fn=self.lnpdf, | ||
parameter_names=self.names, | ||
) | ||
|
||
# duplicate names | ||
with pytest.raises(AssertionError): | ||
_ = EnsembleSampler( | ||
nwalkers=10, | ||
ndim=3, | ||
log_prob_fn=self.lnpdf, | ||
parameter_names=["a", "b", "a"], | ||
) | ||
|
||
# vectorize turned on | ||
with pytest.raises(AssertionError): | ||
_ = EnsembleSampler( | ||
nwalkers=10, | ||
ndim=len(self.names), | ||
log_prob_fn=self.lnpdf, | ||
parameter_names=self.names, | ||
vectorize=True, | ||
) | ||
|
||
def test_compute_log_prob(self): | ||
# Try different numbers of walkers | ||
for N in [4, 8, 10]: | ||
sampler = EnsembleSampler( | ||
nwalkers=N, | ||
ndim=len(self.names), | ||
log_prob_fn=self.lnpdf, | ||
parameter_names=self.names, | ||
) | ||
coords = np.random.rand(N, len(self.names)) | ||
lnps, _ = sampler.compute_log_prob(coords) | ||
assert len(lnps) == N | ||
assert lnps.dtype == np.float64 | ||
|
||
def test_compute_log_prob_mixture(self): | ||
names = ["mean1", "var1", "mean2", "var2"] | ||
# Try different numbers of walkers | ||
for N in [8, 10, 20]: | ||
sampler = EnsembleSampler( | ||
nwalkers=N, | ||
ndim=len(names), | ||
log_prob_fn=self.lnpdf_mixture, | ||
parameter_names=names, | ||
) | ||
coords = np.random.rand(N, len(names)) | ||
lnps, _ = sampler.compute_log_prob(coords) | ||
assert len(lnps) == N | ||
assert lnps.dtype == np.float64 | ||
|
||
def test_compute_log_prob_mixture_grouped(self): | ||
names = {"means": [0, 1], "vars": [2, 3], "constant": 4} | ||
# Try different numbers of walkers | ||
for N in [8, 10, 20]: | ||
sampler = EnsembleSampler( | ||
nwalkers=N, | ||
ndim=5, | ||
log_prob_fn=self.lnpdf_mixture_grouped, | ||
parameter_names=names, | ||
) | ||
coords = np.random.rand(N, 5) | ||
lnps, _ = sampler.compute_log_prob(coords) | ||
assert len(lnps) == N | ||
assert lnps.dtype == np.float64 | ||
|
||
def test_run_mcmc(self): | ||
# Sort of an integration test | ||
n_walkers = 4 | ||
sampler = EnsembleSampler( | ||
nwalkers=n_walkers, | ||
ndim=len(self.names), | ||
log_prob_fn=self.lnpdf, | ||
parameter_names=self.names, | ||
) | ||
guess = np.random.rand(n_walkers, len(self.names)) | ||
n_steps = 50 | ||
results = sampler.run_mcmc(guess, n_steps) | ||
assert results.coords.shape == (n_walkers, len(self.names)) | ||
chain = sampler.get_chain() | ||
assert chain.shape == (n_steps, n_walkers, len(self.names)) | ||
""" | ||
Unit tests of some functionality in ensemble.py when the parameters are named | ||
""" | ||
|
||
import string | ||
from unittest import TestCase | ||
import warnings | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from emcee.ensemble import EnsembleSampler, ndarray_to_list_of_dicts | ||
|
||
|
||
class TestNP2ListOfDicts(TestCase): | ||
def test_ndarray_to_list_of_dicts(self): | ||
# Try different numbers of keys | ||
for n_keys in [1, 2, 10, 26]: | ||
keys = list(string.ascii_lowercase[:n_keys]) | ||
key_set = set(keys) | ||
key_dict = {key: i for i, key in enumerate(keys)} | ||
# Try different number of walker/procs | ||
for N in [1, 2, 3, 10, 100]: | ||
x = np.random.rand(N, n_keys) | ||
|
||
LOD = ndarray_to_list_of_dicts(x, key_dict) | ||
assert len(LOD) == N, "need 1 dict per row" | ||
for i, dct in enumerate(LOD): | ||
assert dct.keys() == key_set, "keys are missing" | ||
for j, key in enumerate(keys): | ||
assert dct[key] == x[i, j], f"wrong value at {(i, j)}" | ||
|
||
|
||
class TestNamedParameters(TestCase): | ||
""" | ||
Test that a keyword-based log-probability function instead of | ||
a positional. | ||
""" | ||
|
||
# Keyword based lnpdf | ||
def lnpdf(self, pars) -> np.float64: | ||
mean = pars["mean"] | ||
var = pars["var"] | ||
if var <= 0: | ||
return -np.inf | ||
return ( | ||
-0.5 * ((mean - self.x) ** 2 / var + np.log(2 * np.pi * var)).sum() | ||
) | ||
|
||
def lnpdf_mixture(self, pars) -> np.float64: | ||
mean1 = pars["mean1"] | ||
var1 = pars["var1"] | ||
mean2 = pars["mean2"] | ||
var2 = pars["var2"] | ||
if var1 <= 0 or var2 <= 0: | ||
return -np.inf | ||
return ( | ||
-0.5 | ||
* ( | ||
(mean1 - self.x) ** 2 / var1 | ||
+ np.log(2 * np.pi * var1) | ||
+ (mean2 - self.x - 3) ** 2 / var2 | ||
+ np.log(2 * np.pi * var2) | ||
).sum() | ||
) | ||
|
||
def lnpdf_mixture_grouped(self, pars) -> np.float64: | ||
mean1, mean2 = pars["means"] | ||
var1, var2 = pars["vars"] | ||
const = pars["constant"] | ||
if var1 <= 0 or var2 <= 0: | ||
return -np.inf | ||
return ( | ||
-0.5 | ||
* ( | ||
(mean1 - self.x) ** 2 / var1 | ||
+ np.log(2 * np.pi * var1) | ||
+ (mean2 - self.x - 3) ** 2 / var2 | ||
+ np.log(2 * np.pi * var2) | ||
).sum() | ||
+ const | ||
) | ||
|
||
def setUp(self): | ||
# Draw some data from a unit Gaussian | ||
self.x = np.random.randn(100) | ||
self.names = ["mean", "var"] | ||
|
||
def test_named_parameters(self): | ||
sampler = EnsembleSampler( | ||
nwalkers=10, | ||
ndim=len(self.names), | ||
log_prob_fn=self.lnpdf, | ||
parameter_names=self.names, | ||
) | ||
assert sampler.params_are_named | ||
assert list(sampler.parameter_names.keys()) == self.names | ||
|
||
def test_asserts(self): | ||
# ndim name mismatch | ||
with pytest.raises(AssertionError): | ||
_ = EnsembleSampler( | ||
nwalkers=10, | ||
ndim=len(self.names) - 1, | ||
log_prob_fn=self.lnpdf, | ||
parameter_names=self.names, | ||
) | ||
|
||
# duplicate names | ||
with pytest.raises(AssertionError): | ||
_ = EnsembleSampler( | ||
nwalkers=10, | ||
ndim=3, | ||
log_prob_fn=self.lnpdf, | ||
parameter_names=["a", "b", "a"], | ||
) | ||
|
||
# vectorize turned on | ||
with pytest.raises(AssertionError): | ||
_ = EnsembleSampler( | ||
nwalkers=10, | ||
ndim=len(self.names), | ||
log_prob_fn=self.lnpdf, | ||
parameter_names=self.names, | ||
vectorize=True, | ||
) | ||
|
||
def test_compute_log_prob(self): | ||
# Try different numbers of walkers | ||
for N in [4, 8, 10]: | ||
sampler = EnsembleSampler( | ||
nwalkers=N, | ||
ndim=len(self.names), | ||
log_prob_fn=self.lnpdf, | ||
parameter_names=self.names, | ||
) | ||
coords = np.random.rand(N, len(self.names)) | ||
lnps, _ = sampler.compute_log_prob(coords) | ||
assert len(lnps) == N | ||
assert lnps.dtype == np.float64 | ||
|
||
def test_compute_log_prob_mixture(self): | ||
names = ["mean1", "var1", "mean2", "var2"] | ||
# Try different numbers of walkers | ||
for N in [8, 10, 20]: | ||
sampler = EnsembleSampler( | ||
nwalkers=N, | ||
ndim=len(names), | ||
log_prob_fn=self.lnpdf_mixture, | ||
parameter_names=names, | ||
) | ||
coords = np.random.rand(N, len(names)) | ||
lnps, _ = sampler.compute_log_prob(coords) | ||
assert len(lnps) == N | ||
assert lnps.dtype == np.float64 | ||
|
||
def test_compute_log_prob_mixture_grouped(self): | ||
names = {"means": [0, 1], "vars": [2, 3], "constant": 4} | ||
# Try different numbers of walkers | ||
for N in [8, 10, 20]: | ||
sampler = EnsembleSampler( | ||
nwalkers=N, | ||
ndim=5, | ||
log_prob_fn=self.lnpdf_mixture_grouped, | ||
parameter_names=names, | ||
) | ||
coords = np.random.rand(N, 5) | ||
lnps, _ = sampler.compute_log_prob(coords) | ||
assert len(lnps) == N | ||
assert lnps.dtype == np.float64 | ||
|
||
def test_run_mcmc(self): | ||
# Sort of an integration test | ||
n_walkers = 4 | ||
sampler = EnsembleSampler( | ||
nwalkers=n_walkers, | ||
ndim=len(self.names), | ||
log_prob_fn=self.lnpdf, | ||
parameter_names=self.names, | ||
) | ||
guess = np.random.rand(n_walkers, len(self.names)) | ||
n_steps = 50 | ||
results = sampler.run_mcmc(guess, n_steps) | ||
assert results.coords.shape == (n_walkers, len(self.names)) | ||
with warnings.catch_warnings(): | ||
warnings.simplefilter('ignore', category=DeprecationWarning) | ||
chain = sampler.chain | ||
assert chain.shape == (n_walkers, n_steps, len(self.names)) |