BatchNorm training instability fix #675

Open · wants to merge 6 commits into main · Changes from 2 commits
62 changes: 50 additions & 12 deletions equinox/nn/_batch_norm.py
@@ -1,10 +1,11 @@
import warnings
from collections.abc import Hashable, Sequence
from typing import Optional, Union
from typing import Literal, Optional, Union

import jax
import jax.lax as lax
import jax.numpy as jnp
from jaxtyping import Array, Bool, Float, PRNGKeyArray
from jaxtyping import Array, Float, Int, PRNGKeyArray

from .._misc import default_floating_dtype
from .._module import field
@@ -44,24 +45,29 @@ class BatchNorm(StatefulLayer, strict=True):

weight: Optional[Float[Array, "input_size"]]
bias: Optional[Float[Array, "input_size"]]
first_time_index: StateIndex[Bool[Array, ""]]
count_index: StateIndex[Int[Array, ""]]
state_index: StateIndex[
tuple[Float[Array, "input_size"], Float[Array, "input_size"]]
]
zero_frac_index: StateIndex[Float[Array, ""]]
axis_name: Union[Hashable, Sequence[Hashable]]
inference: bool
input_size: int = field(static=True)
approach: Literal["batch", "ema"] = field(static=True)
eps: float = field(static=True)
channelwise_affine: bool = field(static=True)
momentum: float = field(static=True)
warmup_period: int = field(static=True)

def __init__(
self,
input_size: int,
axis_name: Union[Hashable, Sequence[Hashable]],
approach: Optional[Literal["batch", "ema"]] = None,
eps: float = 1e-5,
channelwise_affine: bool = True,
momentum: float = 0.99,
warmup_period: int = 1000,
inference: bool = False,
dtype=None,
):
@@ -71,11 +77,17 @@ def __init__(
- `axis_name`: The name of the batch axis to compute statistics over, as passed
to `axis_name` in `jax.vmap` or `jax.pmap`. Can also be a sequence (e.g. a
tuple or a list) of names, to compute statistics over multiple named axes.
- `approach`: The approach to use for the running statistics. If `approach=None`,
a warning will be raised and `approach` will default to `"batch"`. During
training, `"batch"` normalises using only the current batch statistics, while
`"ema"` normalises using the running statistics.
Owner:
So continuing from my previous comment -- probably the default should be ema if approach=None.

- `eps`: Value added to the denominator for numerical stability.
- `channelwise_affine`: Whether the module has learnable channel-wise affine
parameters.
- `momentum`: The rate at which to update the running statistics. Should be a
value between 0 and 1 exclusive.
- `warmup_period`: The number of steps over which to warm up the running
statistics. Only used when `approach=\"ema\"`.
- `inference`: If `False` then the batch means and variances will be calculated
and used to update the running statistics. If `True` then the running
statistics are directly used for normalisation. This may be toggled with
@@ -86,26 +98,37 @@ def __init__(
64-bit mode.
"""

if approach is None:
warnings.warn('BatchNorm approach is None, defaults to approach="batch"')
approach = "batch"

valid_approaches = {"batch", "ema"}
if approach not in valid_approaches:
raise ValueError(f"approach must be one of {valid_approaches}")
self.approach = approach

if channelwise_affine:
self.weight = jnp.ones((input_size,))
self.bias = jnp.zeros((input_size,))
else:
self.weight = None
self.bias = None
self.first_time_index = StateIndex(jnp.array(True))
self.count_index = StateIndex(jnp.array(0, dtype=jnp.int32))
if dtype is None:
dtype = default_floating_dtype()
init_buffers = (
jnp.empty((input_size,), dtype=dtype),
jnp.empty((input_size,), dtype=dtype),
jnp.zeros((input_size,), dtype=dtype),
jnp.zeros((input_size,), dtype=dtype),
)
self.state_index = StateIndex(init_buffers)
self.zero_frac_index = StateIndex(jnp.array(1.0, dtype=dtype))
self.inference = inference
self.axis_name = axis_name
self.input_size = input_size
self.eps = eps
self.channelwise_affine = channelwise_affine
self.momentum = momentum
self.warmup_period = max(1, warmup_period)
Owner:
Why the max? Perhaps it would be better to just error out on values that are too small?

Author:
warmup_period=0 seemed natural for "off". Changed it to just check and error out.


@jax.named_scope("eqx.nn.BatchNorm")
def __call__(
Owner:
It's not completely obvious to me that the ema implementation, with default arguments, reproduces the previous behaviour. (For example, we have warmup_period=1000 by default?)

Can you add some comments explaining what each approach corresponds to?

Author:
ema with warmup_period=1 approximately reproduces previous behavior. As I noted the start is different because of how the running statistics are initially populated. With warmup_period=1 there's no interpolation between the batch and running stats - the running stats are always used as with the previous behavior. I can give an exact replication with an extra approach if necessary.

Added some to the documentation

Owner:
I think an exact replication is probably important for the default behaviour, just because I'd like to be sure that we're bit-for-bit backward compatible.

Author:
Makes sense. It was different enough that I added it as "ema_compatibility", and I changed the warning to rather strongly recommend against using it. I haven't found a use case where I wouldn't expect to see the instability (at least with a larger learning rate), but that could very much be due to a lack of imagination on my part. That part can definitely change if needed.

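For concreteness, a sketch of the "ema" update as it reads in this diff (notation is illustrative rather than taken from the code: $m$ is `momentum`, $T$ is `warmup_period`, $b_t$ the batch statistics at step $t$, $r_t$ the running buffer, $z_t = m^t$ the fraction of EMA weight still sitting on the zero initialisation, and $w_t = \min(t, T)/T$ the warmup fraction):

$$r_t = (1 - m)\,b_t + m\,r_{t-1}, \qquad \hat r_t = z_t\,b_t + r_t, \qquad \text{stats used} = (1 - w_t)\,b_t + w_t\,\hat r_t.$$

With $T = 1$, $w_t = 1$ from the first step onwards, so only $\hat r_t$ is used, which is why warmup_period=1 approximately (but not exactly) reproduces the old running-statistics behaviour.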
@@ -143,7 +166,10 @@ def __call__(
if inference is None:
inference = self.inference
if inference:
zero_frac = state.get(self.zero_frac_index)
running_mean, running_var = state.get(self.state_index)
norm_mean = running_mean / jnp.maximum(1.0 - zero_frac, self.eps)
norm_var = running_var / jnp.maximum(1.0 - zero_frac, self.eps)
else:

def _stats(y):
@@ -154,23 +180,35 @@ def _stats(y):
var = jnp.maximum(0.0, var)
return mean, var

first_time = state.get(self.first_time_index)
state = state.set(self.first_time_index, jnp.array(False))
momentum = self.momentum
zero_frac = state.get(self.zero_frac_index)
zero_frac *= momentum
Owner:
Stylistic nit: I tend not to use the in-place operations in JAX code. This (a) fits the functional style a bit better, and (b) emphasises that we're definitely falling back to the zero_frac = zero_frac * momentum interpretation of the syntax. (Gosh, why does Python have two different meanings for the same syntax?)

Author:
Makes sense, done.

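As a small aside for readers (illustrative only, not part of the diff): JAX arrays are immutable, so the augmented-assignment form never mutates anything in place; it simply rebinds the name to a new array.

import jax.numpy as jnp

x = jnp.ones(3)
y = x          # y and x refer to the same (immutable) array
x *= 2.0       # no in-place update happens; this behaves like x = x * 2.0,
               # rebinding x to a freshly created array
print(x)       # [2. 2. 2.]
print(y)       # [1. 1. 1.] -- the original array is untouched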
state = state.set(self.zero_frac_index, zero_frac)

batch_mean, batch_var = jax.vmap(_stats)(x)
running_mean, running_var = state.get(self.state_index)
momentum = self.momentum
running_mean = (1 - momentum) * batch_mean + momentum * running_mean
running_var = (1 - momentum) * batch_var + momentum * running_var
Owner:
These don't appear to be used on the batch branch. I think the lines here can be reorganised to keep each approach only using the things it needs.

Author:
These are used by the batch branch when we're in inference mode, so they still need to be computed and stored.

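A tiny illustration (not part of the diff) of why the inference branch divides the running buffers by 1 - zero_frac: the buffers start at zero, and zero_frac = momentum**t tracks how much EMA weight still sits on that zero initialisation, so the division debiases the estimate, much like Adam's bias correction.

import jax.numpy as jnp

momentum = 0.99
true_mean = jnp.array(5.0)     # pretend every batch has this mean

running = jnp.array(0.0)       # zero-initialised buffer, as in the diff
zero_frac = jnp.array(1.0)
for _ in range(10):
    zero_frac = zero_frac * momentum
    running = (1 - momentum) * true_mean + momentum * running

print(running)                    # ~0.48: heavily biased towards the zero init
print(running / (1 - zero_frac))  # ~5.0: the debiased estimate used at inference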
running_mean = lax.select(first_time, batch_mean, running_mean)
running_var = lax.select(first_time, batch_var, running_var)
state = state.set(self.state_index, (running_mean, running_var))

if self.approach == "ema":
warmup_count = state.get(self.count_index)
warmup_count = jnp.minimum(warmup_count + 1, self.warmup_period)
state = state.set(self.count_index, warmup_count)

warmup_frac = warmup_count / self.warmup_period
norm_mean = zero_frac * batch_mean + running_mean
norm_mean = (1.0 - warmup_frac) * batch_mean + warmup_frac * norm_mean
norm_var = zero_frac * batch_var + running_var
norm_var = (1.0 - warmup_frac) * batch_var + warmup_frac * norm_var
Owner:
I'm definitely going to have to sit down and grok what's going on here more carefully! As above it would be good to have some comments / docstrings / references / etc. describing what each approach is meant to do.

(C.f. something like the MultiheadAttention docstring for an example on how to use LaTeX if it'd be helpful.)

Author:
Added some commentary and tried making it a bit cleaner.

But overall batch mode should follow the cited paper. Ema follows the prior behavior but changes the initialization of the running stats and adds interpolation so it can be stable while training.

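To summarise the two training-time branches in this diff (a simplified sketch; the function and argument names are illustrative and the vmap/state plumbing is omitted):

def training_stats(approach, batch_mean, batch_var,
                   running_mean, running_var, zero_frac, warmup_frac):
    """Which statistics are used to normalise a training batch."""
    if approach == "batch":
        # Normalise with the current batch statistics only; the running EMA is
        # still updated in the state, but it is consulted only at inference time.
        return batch_mean, batch_var
    else:  # "ema"
        # Fill the zero-initialised part of the EMA with the batch statistics,
        # then blend from batch statistics towards the EMA as warmup progresses.
        ema_mean = zero_frac * batch_mean + running_mean
        ema_var = zero_frac * batch_var + running_var
        mean = (1 - warmup_frac) * batch_mean + warmup_frac * ema_mean
        var = (1 - warmup_frac) * batch_var + warmup_frac * ema_var
        return mean, var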
else:
norm_mean, norm_var = batch_mean, batch_var

def _norm(y, m, v, w, b):
out = (y - m) / jnp.sqrt(v + self.eps)
if self.channelwise_affine:
out = out * w + b
return out

out = jax.vmap(_norm)(x, running_mean, running_var, self.weight, self.bias)
out = jax.vmap(_norm)(x, norm_mean, norm_var, self.weight, self.bias)
return out, state
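For reference, here is the new constructor argument in use, following the call pattern from the tests below (a sketch only; defaults such as warmup_period may still change during review):

import equinox as eqx
import jax
import jax.random as jrandom

bn = eqx.nn.BatchNorm(5, axis_name="batch", approach="ema", warmup_period=1000)
state = eqx.nn.State(bn)
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))

x = jrandom.uniform(jrandom.PRNGKey(0), (10, 5))
out, state = vbn(x, state)  # (10, 5) output plus the updated running statistics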
39 changes: 31 additions & 8 deletions tests/test_nn.py
@@ -124,7 +124,7 @@ def test_sequential(getkey):
[
eqx.nn.Linear(2, 4, key=getkey()),
eqx.nn.Linear(4, 1, key=getkey()),
eqx.nn.BatchNorm(1, axis_name="batch"),
eqx.nn.BatchNorm(1, axis_name="batch", approach="batch"),
eqx.nn.Linear(1, 3, key=getkey()),
]
)
@@ -158,7 +158,7 @@ def make():
inner_seq = eqx.nn.Sequential(
[
eqx.nn.Linear(2, 4, key=getkey()),
eqx.nn.BatchNorm(4, axis_name="batch")
eqx.nn.BatchNorm(4, axis_name="batch", approach="batch")
if inner_stateful
else eqx.nn.Identity(),
eqx.nn.Linear(4, 3, key=getkey()),
@@ -168,7 +168,7 @@
[
eqx.nn.Linear(5, 2, key=getkey()),
inner_seq,
eqx.nn.BatchNorm(3, axis_name="batch")
eqx.nn.BatchNorm(3, axis_name="batch", approach="batch")
if outer_stateful
else eqx.nn.Identity(),
eqx.nn.Linear(3, 6, key=getkey()),
@@ -825,18 +825,35 @@ def test_batch_norm(getkey):
x2 = jrandom.uniform(getkey(), (10, 5, 6))
x3 = jrandom.uniform(getkey(), (10, 5, 7, 8))

# Test that it works with a single vmap'd axis_name
# Test that it warns with no approach - defaulting to batch
with pytest.warns(UserWarning):
bn = eqx.nn.BatchNorm(5, "batch")
assert bn.approach == "batch"

bn = eqx.nn.BatchNorm(5, "batch")
# Test initialization
bn_momentum = 0.99
bn = eqx.nn.BatchNorm(5, "batch", approach="ema", momentum=bn_momentum)
state = eqx.nn.State(bn)
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
running_mean, running_var = state.get(bn.state_index)
zero_frac = state.get(bn.zero_frac_index)
warmup_count = state.get(bn.count_index)
assert jnp.array_equal(running_mean, jnp.zeros(running_mean.shape))
assert jnp.array_equal(running_var, jnp.zeros(running_var.shape))
assert jnp.array_equal(zero_frac, jnp.array(1.0))
assert jnp.array_equal(warmup_count, jnp.array(0))

for x in (x1, x2, x3):
# Test that it works with a single vmap'd axis_name
for i, x in enumerate([x1, x2, x3]):
out, state = vbn(x, state)
assert out.shape == x.shape
running_mean, running_var = state.get(bn.state_index)
zero_frac = state.get(bn.zero_frac_index)
warmup_count = state.get(bn.count_index)
assert running_mean.shape == (5,)
assert running_var.shape == (5,)
assert jnp.array_equal(warmup_count, jnp.array(i + 1))
assert jnp.allclose(zero_frac, jnp.array(bn_momentum ** (i + 1)))

# Test that it fails without any vmap'd axis_name

@@ -861,7 +878,7 @@ def test_batch_norm(getkey):

# Test that it handles multiple axis_names

vvbn = eqx.nn.BatchNorm(6, ("batch1", "batch2"))
vvbn = eqx.nn.BatchNorm(6, ("batch1", "batch2"), approach="ema")
vvstate = eqx.nn.State(vvbn)
for axis_name in ("batch1", "batch2"):
vvbn = jax.vmap(
@@ -876,7 +893,7 @@
# Test that it normalises

x1alt = jrandom.normal(jrandom.PRNGKey(5678), (10, 5)) # avoid flakey test
bn = eqx.nn.BatchNorm(5, "batch", channelwise_affine=False)
bn = eqx.nn.BatchNorm(5, "batch", channelwise_affine=False, approach="ema")
state = eqx.nn.State(bn)
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, state = vbn(x1alt, state)
Expand All @@ -890,6 +907,8 @@ def test_batch_norm(getkey):
running_mean, running_var = state.get(bn.state_index)
out, state = vbn(3 * x1 + 10, state)
running_mean2, running_var2 = state.get(bn.state_index)
zero_frac2 = state.get(bn.zero_frac_index)
warmup_count2 = state.get(bn.count_index)
assert not jnp.allclose(running_mean, running_mean2)
assert not jnp.allclose(running_var, running_var2)

@@ -899,8 +918,12 @@
vibn = jax.vmap(ibn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, state = vibn(4 * x1 + 20, state)
running_mean3, running_var3 = state.get(bn.state_index)
zero_frac3 = state.get(bn.zero_frac_index)
warmup_count3 = state.get(bn.count_index)
assert jnp.array_equal(running_mean2, running_mean3)
assert jnp.array_equal(running_var2, running_var3)
assert jnp.array_equal(zero_frac2, zero_frac3)
assert jnp.array_equal(warmup_count2, warmup_count3)

# Test that we can differentiate through it

4 changes: 2 additions & 2 deletions tests/test_stateful.py
@@ -7,7 +7,7 @@


def test_delete_init_state():
model = eqx.nn.BatchNorm(3, "batch")
model = eqx.nn.BatchNorm(3, "batch", approach="batch")
eqx.nn.State(model)
model2 = eqx.nn.delete_init_state(model)

@@ -17,7 +17,7 @@ def test_delete_init_state():

leaves = [x for x in jtu.tree_leaves(model) if eqx.is_array(x)]
leaves2 = [x for x in jtu.tree_leaves(model2) if eqx.is_array(x)]
assert len(leaves) == len(leaves2) + 3
assert len(leaves) == len(leaves2) + 4


def test_double_state():