From a10c8450f532e653bb759e96601fb485081dfe63 Mon Sep 17 00:00:00 2001 From: nstarman Date: Mon, 17 Jun 2024 17:11:36 -0400 Subject: [PATCH] WIP Signed-off-by: nstarman --- equinox/__init__.py | 1 + equinox/_errors.py | 107 ++++++++++++++++++++++++++++++++++++ tests/test_warns.py | 128 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 236 insertions(+) create mode 100644 tests/test_warns.py diff --git a/equinox/__init__.py b/equinox/__init__.py index 1e5b37b8..671aff69 100644 --- a/equinox/__init__.py +++ b/equinox/__init__.py @@ -25,6 +25,7 @@ branched_error_if as branched_error_if, EquinoxTracetimeError as EquinoxTracetimeError, error_if as error_if, + warn_if as warn_if, ) from ._eval_shape import filter_eval_shape as filter_eval_shape from ._filters import ( diff --git a/equinox/_errors.py b/equinox/_errors.py index 0ffb4cf0..2d5ac2a0 100644 --- a/equinox/_errors.py +++ b/equinox/_errors.py @@ -326,3 +326,110 @@ def assert_dce( else: # Don't run if not JIT'ing, as without the compiler nothing will be DCE'd. return x + + +# ------------------------------------ + + +def warn_if( + x: PyTree, + pred: Bool[ArrayLike, "..."], + msg: str, + *, + category: Warning | None = None, + stacklevel: int = 1 +) -> PyTree: + """Surfaces a warning based on runtime values. Works even under JIT. + + **Arguments:** + + - `x`: will be returned unchanged. This is used to determine where the warning check + happens in the overall computation: it will happen after `x` is computed and + before the return value is used. `x` can be any PyTree, and it must contain at + least one array. + - `pred`: a boolean for whether to raise an warning. Can be an array of bools; an + warning will be raised if any of them are `True`. If vmap'd then an warning will be + raised if any batch element has `True`. + - `msg`: the string to display as an warning message. + + **Returns:** + + The original argument `x` unchanged. **If this return value is unused then the warning + check will not be performed.** (It will be removed as part of dead code + elimination.) + + !!! Example + + ```python + @jax.jit + def f(x): + x = warn_if(x, x < 0, "x must be >= 0") + # ...use x in your computation... + return x + + f(jax.numpy.array(-1)) + ``` + """ + return branched_warn_if(x, pred, 0, [msg], category=category, stacklevel=stacklevel) + + +def branched_warn_if( + x: PyTree, + pred: Bool[ArrayLike, "..."], + index: Int[ArrayLike, "..."], + msgs: Sequence[str], + *, + category: Warning | None = None, + stacklevel: int = 1 +) -> PyTree: + """As [`equinox.warn_if`][], but will raise one of + several `msgs` depending on the value of `index`. If `index` is vmap'd, then the + warn message from the largest value (across the whole batch) will be used. + """ + leaves = jtu.tree_leaves((x, pred, index)) + # This carefully does not perform any JAX operations if `pred` and `index` are + # a bool and an int. + # This ensures we can use `warn_if` before init_google. + if any(is_array(leaf) for leaf in leaves): + return branched_warn_if_impl_jit(x, pred, index, msgs, category=category, stacklevel=stacklevel) + else: + return branched_warn_if_impl(x, pred, index, msgs, category=category, stacklevel=stacklevel) + + +def warning_callback(message, *, category: Warning | None, stacklevel: int): + warnings.warn(message, category=category, stacklevel=stacklevel) + + +def branched_warn_if_impl( + x: PyTree, + pred: Bool[ArrayLike, "..."], + index: Int[ArrayLike, "..."], + msgs: Sequence[str], + *, + category: Warning | None = None, + stacklevel: int = 1 +) -> PyTree: + with jax.ensure_compile_time_eval(): + # This carefully does not perform any JAX operations if `pred` and `index` are + # a bool and an int. + # This ensures we can use `error_if` before init_google. + if not isinstance(pred, bool): + pred = unvmap_any(pred) + if not isinstance(index, int): + index = unvmap_max(index) + if not isinstance(pred, jax.core.Tracer): + if isinstance(pred, Array): + pred = pred.item() + assert type(pred) is bool + if pred: + if not isinstance(index, jax.core.Tracer): + if isinstance(index, Array): + index = index.item() + assert type(index) is int + jax.debug.callback(warning_callback, msgs[index], category=category, stacklevel=stacklevel) + return x + + +# filter_jit does some work to produce nicer runtime warning messages. +# We also place it here to ensure a consistent experience when using JAX in eager mode. +branched_warn_if_impl_jit = filter_jit(branched_warn_if_impl) \ No newline at end of file diff --git a/tests/test_warns.py b/tests/test_warns.py new file mode 100644 index 00000000..a3dd7e8a --- /dev/null +++ b/tests/test_warns.py @@ -0,0 +1,128 @@ +import equinox as eqx +import equinox.internal as eqxi +import jax +import jax.numpy as jnp +import pytest +import warnings + + +def _f(x): + x = eqx.warn_if(x, x < 0, "x must be non-negative") + return jax.nn.relu(x) + + +# Strangely, JAX raises different errors depending on context. +_warn = pytest.warns(UserWarning) + + +def test_basic(): + jf = jax.jit(_f) + _f(1.0) + jf(1.0) + with _warn: + _f(-1.0) + with _warn: + jf(-1.0) + with _warn: + jf(-1.0) + + +def test_vmap(): + vf = jax.vmap(_f) + jvf = jax.jit(vf) + good = jnp.array([1.0, 1.0]) + bad1 = jnp.array([1.0, -1.0]) + bad2 = jnp.array([-1.0, -1.0]) + + vf(good) + jvf(good) + with _warn: + vf(bad1) + with _warn: + vf(bad2) + with _warn: + jvf(bad1) + with _warn: + jvf(bad2) + + +def test_jvp(): + def g(p, t): + return jax.jvp(_f, (p,), (t,)) + + jg = jax.jit(g) + + for h in (g, jg): + h(1.0, 1.0) + h(1.0, -1.0) + with _warn: + h(-1.0, 1.0) + with _warn: + h(-1.0, -1.0) + + +def test_grad(): + g = jax.grad(_f) + jg = jax.jit(g) + + for h in (g, jg): + h(1.0) + with _warn: + h(-1.0) + + +def test_grad2(): + @jax.jit + @jax.grad + def f(x, y, z): + x = eqxi.nondifferentiable_backward(x) + x, y = eqx.warn_if((x, y), z, "oops") + return y + + f(1.0, 1.0, True) + + +def test_tracetime(): + @jax.jit + def f(x): + return eqx.warn_if(x, True, "hi") + + with pytest.warns(UserWarning): + f(1.0) + + +def test_assert_dce(): + @jax.jit + def f(x): + x = x + 1 + eqxi.assert_dce(x, msg="oh no") + return x + + f(1.0) + + @jax.jit + def g(x): + x = x + 1 + eqxi.assert_dce(x, msg="oh no") + return x + + with jax.disable_jit(): + g(1.0) + + +# def test_traceback_runtime_eqx(): +# @eqx.filter_jit +# def f(x): +# return g(x) + +# @eqx.filter_jit +# def g(x): +# return eqx.warn_if(x, x > 0, "egads") + +# f(jnp.array(1.0)) +# except Exception as e: +# assert e.__cause__ is None +# msg = str(e).strip() +# assert msg.startswith("egads") +# assert "EQX_ON_warn" in msg +# assert msg.endswith("information.")