From 5ad1a1bfc7de90e2fcfe0bedfda77f964eaa6b46 Mon Sep 17 00:00:00 2001 From: T Coxon <97948946+tttc3@users.noreply.github.com> Date: Mon, 19 Feb 2024 19:02:44 +0000 Subject: [PATCH] Replace usages of PRNGKey with key (#2) Deprecates usage of `jax.random.PRNGKey` in favour of `jax.random.key` as per [JEP 9263](https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html) --- README.md | 2 +- docs/quickstart.md | 6 +++--- mccube/_kernels/base.py | 2 +- mccube/_kernels/random.py | 4 ++-- mccube/_solvers.py | 2 +- tests/test_kernels.py | 6 +++--- tests/test_solvers.py | 6 +++--- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 9ae96a5..473bdd3 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ from mccube import ( gaussian_wasserstein_metric, ) -key = jr.PRNGKey(42) +key = jr.key(42) n, d = 512, 10 t0 = 0.0 epochs = 512 diff --git a/docs/quickstart.md b/docs/quickstart.md index c19557d..884254e 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -54,7 +54,7 @@ from mccube import gaussian_wasserstein_metric, unpack_particles jax.config.update("jax_enable_x64", True) -key, rng_key = jr.split(jr.PRNGKey(42)) +key, rng_key = jr.split(jr.key(42)) n, d = 512, 10 t0 = 0.0 n_epochs = 1024 @@ -133,7 +133,7 @@ def inference_loop(kernel, initial_state, n_epochs, num_chains, *, key): return states -key, sampler_key = jr.split(jr.PRNGKey(42)) +key, sampler_key = jr.split(jr.key(42)) sampler = blackjax.mala(logdensity, dt0) init_state = jax.vmap(sampler.init)(y0) state = inference_loop( @@ -181,7 +181,7 @@ from mccube import ( BinaryTreePartitioningKernel, ) -key = jr.PRNGKey(42) +key = jr.key(42) gaussian_cubature = Hadamard(GaussianRegion(d)) mcc_cde = diffrax.WeaklyDiagonalControlTerm( lambda t, p, args: jnp.sqrt(2.0), diff --git a/mccube/_kernels/base.py b/mccube/_kernels/base.py index d5a4528..7845710 100644 --- a/mccube/_kernels/base.py +++ b/mccube/_kernels/base.py @@ -111,7 +111,7 @@ class PartitioningRecombinationKernel(AbstractRecombinationKernel): import jax.numpy as jnp import jax.random as jr - key = jr.PRNGKey(42) + key = jr.key(42) y0 = jnp.ones((64,8)) n, d = y0.shape diff --git a/mccube/_kernels/random.py b/mccube/_kernels/random.py index fb12f4e..52cb6c4 100644 --- a/mccube/_kernels/random.py +++ b/mccube/_kernels/random.py @@ -27,7 +27,7 @@ class MonteCarloKernel(AbstractRecombinationKernel): import jax.numpy as jnp import jax.random as jr - key = jr.PRNGKey(42) + key = jr.key(42) kernel = mccube.MonteCarloKernel({"y": 3}, key=key) y0 = {"y": jnp.ones((10,2))} result = kernel(..., y0, ...) @@ -80,7 +80,7 @@ class MonteCarloPartitioningKernel(AbstractPartitioningKernel): import jax.numpy as jnp import jax.random as jr - key = jr.PRNGKey(42) + key = jr.key(42) kernel = mccube.MonteCarloKernel(..., key=key) partitioning_kernel = mccube.MonteCarloPartitioningKernel(4, kernel) y0 = jnp.ones((12,2)) diff --git a/mccube/_solvers.py b/mccube/_solvers.py index 68d7b12..70f5801 100644 --- a/mccube/_solvers.py +++ b/mccube/_solvers.py @@ -52,7 +52,7 @@ class MCCSolver(AbstractWrappedSolver[_SolverState]): import jax.random as jr from diffrax import diffeqsolve, Euler - key, rng_key = jr.split(jr.PRNGKey(42)) + key, rng_key = jr.split(jr.key(42)) t0, t1 = 0.0, 1.0 dt0 = 0.001 particles = jnp.ones((32,8)) diff --git a/tests/test_kernels.py b/tests/test_kernels.py index dfe07d4..366af76 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -71,7 +71,7 @@ def test_monte_carlo_partitioning_kernel(): y0 = jnp.array([[1.0, 0.01], [2.0, 1.0], [3.0, 100.0], [4.0, 10000.0]]) - key = jr.PRNGKey(42) + key = jr.key(42) mc_kernel = MonteCarloKernel(None, key=key) kernel = mccube.MonteCarloPartitioningKernel(n_parts, mc_kernel) values = kernel(0.0, y0, ...) @@ -80,7 +80,7 @@ def test_monte_carlo_partitioning_kernel(): jnp.unique(values, return_counts=True), jnp.unique(y0, return_counts=True) ) - key = jr.PRNGKey(42) + key = jr.key(42) mc_kernel = MonteCarloKernel(None, weighting_function=lambda x: x, key=key) kernel = mccube.MonteCarloPartitioningKernel(n_parts, mc_kernel) values = kernel(0.0, y0, ..., weighted=True) @@ -160,7 +160,7 @@ def test_binary_tree_partitioning_kernel(mode): # n, d = 64, 2 -# key = jr.PRNGKey(42) +# key = jr.key(42) # y0 = jr.multivariate_normal(key, jnp.zeros(d), jnp.eye(d), (n,)) # weights = jnp.arange(1.0, n + 1.0) diff --git a/tests/test_solvers.py b/tests/test_solvers.py index 8943439..f3595b8 100644 --- a/tests/test_solvers.py +++ b/tests/test_solvers.py @@ -18,7 +18,7 @@ from .helpers import gaussian_formulae -key = jr.PRNGKey(42) +key = jr.key(42) init_key, rng_key = jr.split(key) t0 = 0.0 dt0 = 0.05 @@ -46,14 +46,14 @@ def test_diffrax_ula(): ode = ODETerm(lambda t, p, args: grad_logdensity(p)) cde = WeaklyDiagonalControlTerm( lambda t, p, args: jnp.sqrt(2.0), - VirtualBrownianTree(t0, t1, dt0 / 10, (k, d), key=jr.PRNGKey(42)), + VirtualBrownianTree(t0, t1, dt0 / 10, (k, d), key=jr.key(42)), ) terms = MultiTerm(ode, cde) diffeqsolve(terms, Euler(), t0, t1, dt0, y0) def test_MCCSolver_init(): - key = jr.PRNGKey(42) + key = jr.key(42) with pytest.raises(ValueError) as e, pytest.warns(UserWarning) as w: mccube.MCCSolver(EulerHeun(), mccube.MonteCarloKernel(10, key=key), 0)