jaxGPs is a lightweight framework for performing Gaussian Process Regression with the help of JAX.
Its main purpose is to provide a thin but useful abstraction that makes JAX-based inference comfortable to use while remaining easy to customize and experiment with.
To achieve this, two of the major design goals were to depend on no library other than JAX (for example, objax) while still providing an object-based interface to the Gaussian Process Regression components. As a result, jaxGPs makes heavy use of JAX pytrees and closures, and it uses Python descriptors for parameter tracking.
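To make the pytree, closure, and descriptor design concrete, here is a minimal sketch of how the three ideas can fit together. It is not jaxGPs' internal code: `Parameter`, `SquaredExponentialSketch`, and `make_loss` are hypothetical names. The descriptor records each parameter's name on the class, the class registers itself as a pytree whose leaves are those parameters, and a closure over the data turns the kernel into a loss that `jax.grad` can differentiate directly.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class Parameter:
    """Hypothetical descriptor: stores a per-instance value and records its name on the class."""

    def __set_name__(self, owner, name):
        self.name = name
        # Append this attribute's name to the owner's ordered tuple of parameter names.
        owner._param_names = getattr(owner, "_param_names", ()) + (name,)

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        return obj.__dict__["_params"][self.name]

    def __set__(self, obj, value):
        obj.__dict__.setdefault("_params", {})[self.name] = value


@tree_util.register_pytree_node_class
class SquaredExponentialSketch:
    lengthscale = Parameter()
    variance = Parameter()

    def __init__(self, lengthscale=1.0, variance=1.0):
        self.lengthscale = jnp.asarray(lengthscale)
        self.variance = jnp.asarray(variance)

    def __call__(self, x1, x2):
        # Pairwise squared distances between rows of x1 (n, d) and x2 (m, d).
        sqdist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
        return self.variance * jnp.exp(-0.5 * sqdist / self.lengthscale**2)

    def parameters(self):
        return {name: getattr(self, name) for name in self._param_names}

    # Pytree protocol: the tracked parameters are the leaves; no auxiliary data.
    def tree_flatten(self):
        return tuple(getattr(self, n) for n in self._param_names), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__ so JAX can rebuild the object from tracers or other leaf values.
        obj = object.__new__(cls)
        for name, value in zip(cls._param_names, children):
            setattr(obj, name, value)
        return obj


def make_loss(x):
    # Closure over the data: the returned function takes only the kernel object.
    def loss(kernel):
        return jnp.sum(kernel(x, x))
    return loss


x = jnp.linspace(0.0, 1.0, 5)[:, None]
kernel = SquaredExponentialSketch(lengthscale=0.5, variance=2.0)
loss = make_loss(x)
grads = jax.grad(loss)(kernel)  # same class, holding d(loss)/d(parameter)
```

Because the gradient comes back as an instance of the same class, it can be read through the same parameter names, which is one way an object-based interface and pure JAX transformations can coexist without a framework such as objax.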
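```python
from pprint import pprint
import jax.numpy as jnp
from jaxGPs import GPR, ExponentialQuadratic

# get_toy_data() is not imported above: substitute your own training
# inputs x, presumably shaped (N, 1) like atx below, and targets y.
x, y = get_toy_data()
# GP regression model with an exponential-quadratic kernel
gpr = GPR(ExponentialQuadratic())
gpr.update_data(x, y)
gpr.fit_scipy()  # fit the hyperparameters
pprint(gpr.parameters())
# 200 evenly spaced test inputs over the training range, as a column array
atx = jnp.linspace(x.min(), x.max(), 200)[:, jnp.newaxis]
mu, cov = gpr.predict_f(atx)  # predictive mean and covariance at atx
```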
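The example above hides the usual exact GP computations. As a rough sketch under stated assumptions (Gaussian observation noise, a squared-exponential kernel, and the hypothetical helper names `se_kernel`, `make_nlml`, and `predict_f`, none of which are jaxGPs' API), the negative log marginal likelihood that a method like `fit_scipy` would typically minimize, and the posterior mean and covariance that a method like `predict_f` returns, can be written in plain JAX:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular


def se_kernel(x1, x2, lengthscale, variance):
    # Squared-exponential covariance between rows of x1 (n, d) and x2 (m, d).
    sqdist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sqdist / lengthscale**2)


def make_nlml(x, y):
    # Closure over the training data: the returned function depends only on params.
    n = x.shape[0]

    def nlml(params):
        K = se_kernel(x, x, params["lengthscale"], params["variance"])
        L = jnp.linalg.cholesky(K + params["noise"] * jnp.eye(n))
        alpha = cho_solve((L, True), y)  # (K + noise * I)^{-1} y
        # 0.5 y^T K^{-1} y + 0.5 log|K| + (n / 2) log(2 pi), using log|K| = 2 sum(log diag L)
        return (0.5 * jnp.sum(y * alpha)
                + jnp.sum(jnp.log(jnp.diag(L)))
                + 0.5 * n * jnp.log(2.0 * jnp.pi))

    return nlml


def predict_f(params, x, y, xnew):
    # Exact GP posterior over the latent function f at xnew.
    n = x.shape[0]
    K = se_kernel(x, x, params["lengthscale"], params["variance"])
    L = jnp.linalg.cholesky(K + params["noise"] * jnp.eye(n))
    Ks = se_kernel(x, xnew, params["lengthscale"], params["variance"])      # (n, m)
    Kss = se_kernel(xnew, xnew, params["lengthscale"], params["variance"])  # (m, m)
    mu = Ks.T @ cho_solve((L, True), y)      # posterior mean, (m, 1)
    v = solve_triangular(L, Ks, lower=True)  # L^{-1} Ks
    cov = Kss - v.T @ v                      # posterior covariance, (m, m)
    return mu, cov


x = jnp.linspace(0.0, 1.0, 6)[:, None]
y = jnp.sin(3.0 * x)
params = {"lengthscale": 0.3, "variance": 1.0, "noise": 0.1}
nlml = make_nlml(x, y)
grads = jax.grad(nlml)(params)  # dict of gradients with the same keys as params
mu, cov = predict_f(params, x, y, jnp.linspace(0.0, 1.0, 7)[:, None])
```

Keeping the hyperparameters in a dict means `jax.grad` returns gradients with the same structure, which can be flattened and handed to an optimizer such as `scipy.optimize.minimize`.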
Using a Python virtual environment, and assuming the jaxGPs repository has been cloned and cd'ed into:

```shell
python -m venv venv
source venv/bin/activate
pip install .
```
- See more examples with toy data in applications/toy_data_example.py.
Kernels:
- GaussMarkov
  - also known as White Noise, Brownian or Wiener
- Exponential
  - also known as Matérn-1/2, but with slightly different scaling
- ExponentialQuadratic
  - also known as Gaussian, RBF, SquaredExponential,…
- HavExponential
- HavExponentialQuad

Mean functions:
- ConstantMean

Models:
- GPR
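For reference, the textbook forms of three of the kernels listed above, plus a constant mean, fit in a few lines of JAX. This is a sketch under stated assumptions, not jaxGPs' code: the function names are hypothetical, the list notes that jaxGPs' Exponential uses a slightly different scaling than the Matérn-1/2 form shown here, GaussMarkov is assumed to be the Wiener (Brownian motion) covariance for 1-D inputs, and the Hav* kernels are left out because their definitions are not given above.

```python
import jax.numpy as jnp


def sqdist(x1, x2):
    # Pairwise squared Euclidean distances between rows of x1 (n, d) and x2 (m, d).
    return jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)


def exponential_quadratic(x1, x2, lengthscale=1.0, variance=1.0):
    # Squared exponential / RBF: k(r) = variance * exp(-r^2 / (2 * lengthscale^2))
    return variance * jnp.exp(-0.5 * sqdist(x1, x2) / lengthscale**2)


def exponential(x1, x2, lengthscale=1.0, variance=1.0):
    # Matern-1/2: k(r) = variance * exp(-r / lengthscale).
    # The gradient of sqrt is undefined at r = 0, so a trainable version
    # would need to guard that case.
    r = jnp.sqrt(sqdist(x1, x2))
    return variance * jnp.exp(-r / lengthscale)


def wiener(x1, x2, variance=1.0):
    # Brownian motion covariance for scalar inputs t >= 0: k(t, t') = variance * min(t, t')
    return variance * jnp.minimum(x1[:, None, 0], x2[None, :, 0])


def constant_mean(x, c=0.0):
    # Constant mean function: the same value c for every input row, shape (n, 1).
    return jnp.full((x.shape[0], 1), c)


x = jnp.array([[0.0], [0.5], [2.0]])
K_exp = exponential(x, x, lengthscale=1.0, variance=2.0)
K_se = exponential_quadratic(x, x)
K_w = wiener(x, x)
```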