Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

warn_if #762

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions equinox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
branched_error_if as branched_error_if,
EquinoxTracetimeError as EquinoxTracetimeError,
error_if as error_if,
warn_if as warn_if,
)
from ._eval_shape import filter_eval_shape as filter_eval_shape
from ._filters import (
Expand Down
107 changes: 107 additions & 0 deletions equinox/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,110 @@ def assert_dce(
else:
# Don't run if not JIT'ing, as without the compiler nothing will be DCE'd.
return x


# ------------------------------------


def warn_if(
x: PyTree,
pred: Bool[ArrayLike, "..."],
msg: str,
*,
category: Warning | None = None,
stacklevel: int = 1
) -> PyTree:
"""Surfaces a warning based on runtime values. Works even under JIT.

**Arguments:**

- `x`: will be returned unchanged. This is used to determine where the warning check
happens in the overall computation: it will happen after `x` is computed and
before the return value is used. `x` can be any PyTree, and it must contain at
least one array.
- `pred`: a boolean for whether to raise an warning. Can be an array of bools; an
warning will be raised if any of them are `True`. If vmap'd then an warning will be
raised if any batch element has `True`.
- `msg`: the string to display as an warning message.

**Returns:**

The original argument `x` unchanged. **If this return value is unused then the warning
check will not be performed.** (It will be removed as part of dead code
elimination.)

!!! Example

```python
@jax.jit
def f(x):
x = warn_if(x, x < 0, "x must be >= 0")
# ...use x in your computation...
return x

f(jax.numpy.array(-1))
```
"""
return branched_warn_if(x, pred, 0, [msg], category=category, stacklevel=stacklevel)


def branched_warn_if(
x: PyTree,
pred: Bool[ArrayLike, "..."],
index: Int[ArrayLike, "..."],
msgs: Sequence[str],
*,
category: Warning | None = None,
stacklevel: int = 1
) -> PyTree:
"""As [`equinox.warn_if`][], but will raise one of
several `msgs` depending on the value of `index`. If `index` is vmap'd, then the
warn message from the largest value (across the whole batch) will be used.
"""
leaves = jtu.tree_leaves((x, pred, index))
# This carefully does not perform any JAX operations if `pred` and `index` are
# a bool and an int.
# This ensures we can use `warn_if` before init_google.
if any(is_array(leaf) for leaf in leaves):
return branched_warn_if_impl_jit(x, pred, index, msgs, category=category, stacklevel=stacklevel)
else:
return branched_warn_if_impl(x, pred, index, msgs, category=category, stacklevel=stacklevel)


def warning_callback(message, *, category: Warning | None, stacklevel: int):
warnings.warn(message, category=category, stacklevel=stacklevel)


def branched_warn_if_impl(
x: PyTree,
pred: Bool[ArrayLike, "..."],
index: Int[ArrayLike, "..."],
msgs: Sequence[str],
*,
category: Warning | None = None,
stacklevel: int = 1
) -> PyTree:
with jax.ensure_compile_time_eval():
# This carefully does not perform any JAX operations if `pred` and `index` are
# a bool and an int.
# This ensures we can use `error_if` before init_google.
if not isinstance(pred, bool):
pred = unvmap_any(pred)
if not isinstance(index, int):
index = unvmap_max(index)
if not isinstance(pred, jax.core.Tracer):
if isinstance(pred, Array):
pred = pred.item()
assert type(pred) is bool
if pred:
if not isinstance(index, jax.core.Tracer):
if isinstance(index, Array):
index = index.item()
assert type(index) is int
jax.debug.callback(warning_callback, msgs[index], category=category, stacklevel=stacklevel)
return x


# filter_jit does some work to produce nicer runtime warning messages.
# We also place it here to ensure a consistent experience when using JAX in eager mode.
branched_warn_if_impl_jit = filter_jit(branched_warn_if_impl)
128 changes: 128 additions & 0 deletions tests/test_warns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.numpy as jnp
import pytest
import warnings


def _f(x):
x = eqx.warn_if(x, x < 0, "x must be non-negative")
return jax.nn.relu(x)


# Strangely, JAX raises different errors depending on context.
_warn = pytest.warns(UserWarning)


def test_basic():
jf = jax.jit(_f)
_f(1.0)
jf(1.0)
with _warn:
_f(-1.0)
with _warn:
jf(-1.0)
with _warn:
jf(-1.0)


def test_vmap():
vf = jax.vmap(_f)
jvf = jax.jit(vf)
good = jnp.array([1.0, 1.0])
bad1 = jnp.array([1.0, -1.0])
bad2 = jnp.array([-1.0, -1.0])

vf(good)
jvf(good)
with _warn:
vf(bad1)
with _warn:
vf(bad2)
with _warn:
jvf(bad1)
with _warn:
jvf(bad2)


def test_jvp():
def g(p, t):
return jax.jvp(_f, (p,), (t,))

jg = jax.jit(g)

for h in (g, jg):
h(1.0, 1.0)
h(1.0, -1.0)
with _warn:
h(-1.0, 1.0)
with _warn:
h(-1.0, -1.0)


def test_grad():
g = jax.grad(_f)
jg = jax.jit(g)

for h in (g, jg):
h(1.0)
with _warn:
h(-1.0)


def test_grad2():
@jax.jit
@jax.grad
def f(x, y, z):
x = eqxi.nondifferentiable_backward(x)
x, y = eqx.warn_if((x, y), z, "oops")
return y

f(1.0, 1.0, True)


def test_tracetime():
@jax.jit
def f(x):
return eqx.warn_if(x, True, "hi")

with pytest.warns(UserWarning):
f(1.0)


def test_assert_dce():
@jax.jit
def f(x):
x = x + 1
eqxi.assert_dce(x, msg="oh no")
return x

f(1.0)

@jax.jit
def g(x):
x = x + 1
eqxi.assert_dce(x, msg="oh no")
return x

with jax.disable_jit():
g(1.0)


# def test_traceback_runtime_eqx():
# @eqx.filter_jit
# def f(x):
# return g(x)

# @eqx.filter_jit
# def g(x):
# return eqx.warn_if(x, x > 0, "egads")

# f(jnp.array(1.0))
# except Exception as e:
# assert e.__cause__ is None
# msg = str(e).strip()
# assert msg.startswith("egads")
# assert "EQX_ON_warn" in msg
# assert msg.endswith("information.")