BatchNorm training instability fix #675

Open · wants to merge 6 commits into main
159 changes: 144 additions & 15 deletions equinox/nn/_batch_norm.py
@@ -1,10 +1,11 @@
import warnings
from collections.abc import Hashable, Sequence
from typing import Optional, Union
from typing import Literal, Optional, Union

import jax
import jax.lax as lax
import jax.numpy as jnp
from jaxtyping import Array, Bool, Float, PRNGKeyArray
from jaxtyping import Array, Float, Int, PRNGKeyArray

from .._misc import default_floating_dtype
from .._module import field
@@ -40,28 +41,92 @@ class BatchNorm(StatefulLayer, strict=True):
statistics updated. During inference then just the running statistics are used.
Whether the model is in training or inference mode should be toggled using
[`equinox.nn.inference_mode`][].

With `approach = "batch"` during training the batch mean and variance are used
for normalization. For inference the exponential running mean and ubiased
variance are used for normalization in accordance with the cited paper below.
Let `m` be momentum:

$\text{TrainStats}_t = \text{BatchStats}_t$

$\text{InferenceStats}_t = \frac{\left(1.0 - m\right)\sum_{i=0}^{t}m^{t-i}
\text{BatchStats}_i}{\text{max} \left(1.0 - m^{t+1}, \varepsilon \right)}$

With `approach = "ema"` exponential running means and variances are kept. During
training the batch statistics are used to fill in the running statistics until
they are populated. In addition a linear iterpolation is used between the batch
and running statistics over the `warmup_period`. During inference the running
statistics are used for normalization:



$\text{WarmupFrac}_t = \text{min} \left(1.0, \frac{t}{\text{WarmupPeriod}} \right)$

$\text{TrainStats}_t = (1.0 - \text{WarmupFrac}_t) \cdot \text{BatchStats}_t +
\text{WarmupFrac}_t \cdot \left(1.0 - m\right)\sum_{i=0}^{t}m^{t-i}\text{BatchStats}_i$

$\text{InferenceStats}_t = \frac{\left(1.0 - m\right)\sum_{i=0}^{t}m^{t-i}
\text{BatchStats}_i}{\text{max} \left(1.0 - m^{t+1}, \varepsilon \right)}$


$\text{Note: } \frac{(1.0 - m)\sum_{i=0}^{t}m^{t-i}}{1.0 - m^{t+1}} =
\frac{(1.0 - m)\sum_{i=0}^{t}m^{i}}{1.0 - m^{t+1}}$
$= \frac{(1.0 - m)\frac{1.0 - m^{t+1}}{1.0 - m}}{1.0 - m^{t+1}} = 1$

`approach = "ema_compatibility"` reproduces the original equinox BatchNorm
behavior. It often results in training instabilities and `approach = "batch"`
or `"ema"` is recommended.

??? cite

[Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/abs/1502.03167)

```bibtex
@article{DBLP:journals/corr/IoffeS15,
author = {Sergey Ioffe and
Christian Szegedy},
title = {Batch Normalization: Accelerating Deep Network Training
by Reducing Internal Covariate Shift},
journal = {CoRR},
volume = {abs/1502.03167},
year = {2015},
url = {http://arxiv.org/abs/1502.03167},
eprinttype = {arXiv},
eprint = {1502.03167},
timestamp = {Mon, 13 Aug 2018 16:47:06 +0200},
biburl = {https://dblp.org/rec/journals/corr/IoffeS15.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```

""" # noqa: E501

weight: Optional[Float[Array, "input_size"]]
bias: Optional[Float[Array, "input_size"]]
first_time_index: StateIndex[Bool[Array, ""]]
count_index: StateIndex[Int[Array, ""]]
state_index: StateIndex[
tuple[Float[Array, "input_size"], Float[Array, "input_size"]]
]
zero_frac_index: StateIndex[Float[Array, ""]]
axis_name: Union[Hashable, Sequence[Hashable]]
inference: bool
input_size: int = field(static=True)
approach: Literal["batch", "ema", "ema_compatibility"] = field(static=True)
eps: float = field(static=True)
channelwise_affine: bool = field(static=True)
momentum: float = field(static=True)
warmup_period: int = field(static=True)

def __init__(
self,
input_size: int,
axis_name: Union[Hashable, Sequence[Hashable]],
approach: Optional[Literal["batch", "ema", "ema_compatibility"]] = None,
eps: float = 1e-5,
channelwise_affine: bool = True,
momentum: float = 0.99,
warmup_period: int = 1,
inference: bool = False,
dtype=None,
):
@@ -71,11 +136,17 @@ def __init__(
- `axis_name`: The name of the batch axis to compute statistics over, as passed
to `axis_name` in `jax.vmap` or `jax.pmap`. Can also be a sequence (e.g. a
tuple or a list) of names, to compute statistics over multiple named axes.
- `approach`: The approach to use for the running statistics. If `approach=None`,
a warning will be raised and the approach will default to `"ema_compatibility"`.
During training, `"batch"` only uses batch statistics, while `"ema"` and
`"ema_compatibility"` use the running statistics.
- `eps`: Value added to the denominator for numerical stability.
- `channelwise_affine`: Whether the module has learnable channel-wise affine
parameters.
- `momentum`: The rate at which to update the running statistics. Should be a
value between 0 and 1 exclusive.
- `warmup_period`: The interpolation period between batch and running
statistics. Only used when `approach=\"ema\"`.
- `inference`: If `False` then the batch means and variances will be calculated
and used to update the running statistics. If `True` then the running
statistics are directly used for normalisation. This may be toggled with
@@ -86,26 +157,46 @@ def __init__(
64-bit mode.
"""

if approach is None:
warnings.warn(
"BatchNorm approach is None, defaults to "
'approach="ema_compatibility". This is not recommended as '
'it can lead to training instability. Use "batch" or '
'alternatively "ema" with appropriately selected warmup '
"instead."
)
approach = "ema_compatibility"

valid_approaches = {"batch", "ema", "ema_compatibility"}
if approach not in valid_approaches:
raise ValueError(f"approach must be one of {valid_approaches}")
self.approach = approach

if warmup_period < 1:
raise ValueError("warmup_period must be >= 1")

if channelwise_affine:
self.weight = jnp.ones((input_size,))
self.bias = jnp.zeros((input_size,))
else:
self.weight = None
self.bias = None
self.first_time_index = StateIndex(jnp.array(True))
self.count_index = StateIndex(jnp.array(0, dtype=jnp.int32))
if dtype is None:
dtype = default_floating_dtype()
init_buffers = (
jnp.empty((input_size,), dtype=dtype),
jnp.empty((input_size,), dtype=dtype),
jnp.zeros((input_size,), dtype=dtype),
jnp.zeros((input_size,), dtype=dtype),
)
self.state_index = StateIndex(init_buffers)
self.zero_frac_index = StateIndex(jnp.array(1.0, dtype=dtype))
self.inference = inference
self.axis_name = axis_name
self.input_size = input_size
self.eps = eps
self.channelwise_affine = channelwise_affine
self.momentum = momentum
self.warmup_period = warmup_period

@jax.named_scope("eqx.nn.BatchNorm")
def __call__(
Owner:
It's not completely obvious to me that the ema implementation, with default arguments, reproduces the previous behaviour. (For example, we have warmup_period=1000 by default?)

Can you add some comments explaining what each approach corresponds to?

Author:
ema with warmup_period=1 approximately reproduces the previous behavior. As I noted, the start is different because of how the running statistics are initially populated. With warmup_period=1 there's no interpolation between the batch and running stats - the running stats are always used, as with the previous behavior. I can give an exact replication with an extra approach if necessary.

Added some to the documentation
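
A small numerical sketch of that point (made-up values, not from the diff): with `warmup_period=1` the warmup fraction is 1 from the first step, so normalisation always uses the filled-in running statistics.

```python
import jax.numpy as jnp

momentum, warmup_period = 0.99, 1
zero_frac, running = jnp.array(1.0), jnp.array(0.0)
batch_mean = jnp.array(2.0)  # an arbitrary per-step batch statistic

for step in range(3):
    zero_frac = zero_frac * momentum
    running = (1 - momentum) * batch_mean + momentum * running
    warmup_frac = min(step + 1, warmup_period) / warmup_period  # always 1.0
    filled = zero_frac * batch_mean + running
    norm = (1 - warmup_frac) * batch_mean + warmup_frac * filled
    assert norm == filled  # no interpolation: the running stats are always used
```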

Owner:
I think an exact replication is probably important for the default behaviour, just because I'd like to be sure that we're bit-for-bit backward compatible.

Author:
Makes sense - it was different enough that I added it as "ema_compatibility". I changed the warning to rather strongly recommend against using "ema_compatibility". I haven't found a use case where I wouldn't expect to see the instability (at least with a larger learning rate), but that could very much be due to a lack of imagination on my part. That part can definitely change if needed.
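
For reference, usage after this change looks roughly like the updated tests further down (a sketch only, not authoritative API docs):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

bn = eqx.nn.BatchNorm(5, axis_name="batch", approach="batch")
state = eqx.nn.State(bn)
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))

x = jnp.ones((10, 5))
out, state = vbn(x, state)  # training mode: normalise with batch statistics

bn_inf = eqx.nn.inference_mode(bn, value=True)
vbn_inf = jax.vmap(bn_inf, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, state = vbn_inf(x, state)  # inference mode: debiased running statistics
```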

@@ -143,7 +234,11 @@ def __call__(
if inference is None:
inference = self.inference
if inference:
# renormalize running stats to account for the zeroed part
zero_frac = state.get(self.zero_frac_index)
running_mean, running_var = state.get(self.state_index)
norm_mean = running_mean / jnp.maximum(1.0 - zero_frac, self.eps)
norm_var = running_var / jnp.maximum(1.0 - zero_frac, self.eps)
else:

def _stats(y):
@@ -154,16 +249,50 @@ def _stats(y):
var = jnp.maximum(0.0, var)
return mean, var

first_time = state.get(self.first_time_index)
state = state.set(self.first_time_index, jnp.array(False))

momentum = self.momentum
batch_mean, batch_var = jax.vmap(_stats)(x)
zero_frac = state.get(self.zero_frac_index)
running_mean, running_var = state.get(self.state_index)
momentum = self.momentum
running_mean = (1 - momentum) * batch_mean + momentum * running_mean
running_var = (1 - momentum) * batch_var + momentum * running_var
running_mean = lax.select(first_time, batch_mean, running_mean)
running_var = lax.select(first_time, batch_var, running_var)

if self.approach == "ema":
zero_frac = zero_frac * momentum
running_mean = (1 - momentum) * batch_mean + momentum * running_mean
running_var = (1 - momentum) * batch_var + momentum * running_var
warmup_count = state.get(self.count_index)
warmup_count = jnp.minimum(warmup_count + 1, self.warmup_period)
state = state.set(self.count_index, warmup_count)

# fill in unpopulated part of running stats with batch stats
warmup_frac = warmup_count / self.warmup_period
norm_mean = zero_frac * batch_mean + running_mean
norm_var = zero_frac * batch_var + running_var

# apply warmup interpolation between batch and running statistics
norm_mean = (1.0 - warmup_frac) * batch_mean + warmup_frac * norm_mean
norm_var = (1.0 - warmup_frac) * batch_var + warmup_frac * norm_var

elif self.approach == "ema_compatibility":
running_mean = (1 - momentum) * batch_mean + momentum * running_mean
running_var = (1 - momentum) * batch_var + momentum * running_var
running_mean = lax.select(zero_frac == 1.0, batch_mean, running_mean)
running_var = lax.select(zero_frac == 1.0, batch_var, running_var)
norm_mean, norm_var = running_mean, running_var
zero_frac = 0.0 * zero_frac

else:
zero_frac = zero_frac * momentum
running_mean = (1 - momentum) * batch_mean + momentum * running_mean
# calculate unbiased variance for saving
axis_size = jax.lax.psum(jnp.array(1.0), self.axis_name)
Author:
I'm using this to get the length of the "batch" axis - but not sure it's the best / correct way

Owner:
I think this is the correct way! IIRC psum(1) is actually special-cased for this purpose.
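
For example (a hypothetical standalone snippet): under `jax.vmap` with a named axis, `lax.psum` of the constant 1 returns the length of that axis.

```python
import jax
import jax.numpy as jnp
from jax import lax

def batch_size(x):
    # Summing the constant 1.0 over the named axis gives the axis length.
    return lax.psum(jnp.array(1.0), axis_name="batch")

out = jax.vmap(batch_size, axis_name="batch")(jnp.zeros((10, 5)))
print(out)  # [10. 10. ... 10.] - one copy of the axis size per batch element
```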

debias_coef = (axis_size) / jnp.maximum(axis_size - 1, self.eps)
running_var = (
1 - momentum
) * debias_coef * batch_var + momentum * running_var
Author:
I neglected to use unbiased variance, so I corrected that here.
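
The correction is the usual N / (N - 1) rescaling between the biased variance computed by `jnp.var` (ddof=0) and the unbiased estimate; a quick illustrative check:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 4.0, 7.0])
n = x.shape[0]
biased = jnp.var(x)             # ddof=0: what the batch statistics compute
unbiased = jnp.var(x, ddof=1)   # ddof=1: what the running variance should track
assert jnp.allclose(unbiased, biased * n / (n - 1))
```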


# just use batch statistics when not in inference mode
norm_mean, norm_var = batch_mean, batch_var

state = state.set(self.zero_frac_index, zero_frac)
state = state.set(self.state_index, (running_mean, running_var))

def _norm(y, m, v, w, b):
@@ -172,5 +301,5 @@ def _norm(y, m, v, w, b):
out = out * w + b
return out

out = jax.vmap(_norm)(x, running_mean, running_var, self.weight, self.bias)
out = jax.vmap(_norm)(x, norm_mean, norm_var, self.weight, self.bias)
return out, state
69 changes: 58 additions & 11 deletions tests/test_nn.py
@@ -124,7 +124,7 @@ def test_sequential(getkey):
[
eqx.nn.Linear(2, 4, key=getkey()),
eqx.nn.Linear(4, 1, key=getkey()),
eqx.nn.BatchNorm(1, axis_name="batch"),
eqx.nn.BatchNorm(1, axis_name="batch", approach="batch"),
eqx.nn.Linear(1, 3, key=getkey()),
]
)
@@ -158,7 +158,7 @@ def make():
inner_seq = eqx.nn.Sequential(
[
eqx.nn.Linear(2, 4, key=getkey()),
eqx.nn.BatchNorm(4, axis_name="batch")
eqx.nn.BatchNorm(4, axis_name="batch", approach="batch")
if inner_stateful
else eqx.nn.Identity(),
eqx.nn.Linear(4, 3, key=getkey()),
@@ -168,7 +168,7 @@
[
eqx.nn.Linear(5, 2, key=getkey()),
inner_seq,
eqx.nn.BatchNorm(3, axis_name="batch")
eqx.nn.BatchNorm(3, axis_name="batch", approach="batch")
if outer_stateful
else eqx.nn.Identity(),
eqx.nn.Linear(3, 6, key=getkey()),
@@ -825,18 +825,40 @@ def test_batch_norm(getkey):
x2 = jrandom.uniform(getkey(), (10, 5, 6))
x3 = jrandom.uniform(getkey(), (10, 5, 7, 8))

# Test that it works with a single vmap'd axis_name
# Test that it warns with no approach - defaulting to ema_compatibility
with pytest.warns(UserWarning):
bn = eqx.nn.BatchNorm(5, "batch")
assert bn.approach == "ema_compatibility"

with pytest.raises(ValueError):
bn = eqx.nn.BatchNorm(5, "batch", approach="ema", warmup_period=0)

bn = eqx.nn.BatchNorm(5, "batch")
# Test initialization
bn_momentum = 0.99
bn = eqx.nn.BatchNorm(
5, "batch", approach="ema", warmup_period=10, momentum=bn_momentum
)
state = eqx.nn.State(bn)
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
running_mean, running_var = state.get(bn.state_index)
zero_frac = state.get(bn.zero_frac_index)
warmup_count = state.get(bn.count_index)
assert jnp.array_equal(running_mean, jnp.zeros(running_mean.shape))
assert jnp.array_equal(running_var, jnp.zeros(running_var.shape))
assert jnp.array_equal(zero_frac, jnp.array(1.0))
assert jnp.array_equal(warmup_count, jnp.array(0))

for x in (x1, x2, x3):
# Test that it works with a single vmap'd axis_name
for i, x in enumerate([x1, x2, x3]):
out, state = vbn(x, state)
assert out.shape == x.shape
running_mean, running_var = state.get(bn.state_index)
zero_frac = state.get(bn.zero_frac_index)
warmup_count = state.get(bn.count_index)
assert running_mean.shape == (5,)
assert running_var.shape == (5,)
assert jnp.array_equal(warmup_count, jnp.array(i + 1))
assert jnp.allclose(zero_frac, jnp.array(bn_momentum ** (i + 1)))

# Test that it fails without any vmap'd axis_name

@@ -861,7 +883,7 @@ def test_batch_norm(getkey):

# Test that it handles multiple axis_names

vvbn = eqx.nn.BatchNorm(6, ("batch1", "batch2"))
vvbn = eqx.nn.BatchNorm(6, ("batch1", "batch2"), approach="ema")
vvstate = eqx.nn.State(vvbn)
for axis_name in ("batch1", "batch2"):
vvbn = jax.vmap(
Expand All @@ -873,34 +895,59 @@ def test_batch_norm(getkey):
assert running_mean.shape == (6,)
assert running_var.shape == (6,)

# Test that it normalises

# Test that approach=ema normalises
x1alt = jrandom.normal(jrandom.PRNGKey(5678), (10, 5)) # avoid flaky test
bn = eqx.nn.BatchNorm(5, "batch", channelwise_affine=False)
bn = eqx.nn.BatchNorm(5, "batch", channelwise_affine=False, approach="ema")
state = eqx.nn.State(bn)
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, state = vbn(x1alt, state)
true_out = (x1alt - jnp.mean(x1alt, axis=0, keepdims=True)) / jnp.sqrt(
jnp.var(x1alt, axis=0, keepdims=True) + 1e-5
)
assert jnp.allclose(out, true_out)

# Test that approach=batch normalises in training mode
bn = eqx.nn.BatchNorm(
5, "batch", channelwise_affine=False, approach="batch", momentum=0.9
)
state = eqx.nn.State(bn)
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, state = vbn(x1alt, state)
true_out = (x1alt - jnp.mean(x1alt, axis=0, keepdims=True)) / jnp.sqrt(
jnp.var(x1alt, axis=0, keepdims=True) + 1e-5
)
assert jnp.allclose(out, true_out)
# Test that approach=batch normalises in inference mode
bn_inf = eqx.nn.inference_mode(bn, value=True)
vbn_inf = jax.vmap(bn_inf, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, state = vbn_inf(x1alt, state)
debias_coef = x1alt.shape[0] / (x1alt.shape[0] - 1)
true_out = (x1alt - jnp.mean(x1alt, axis=0, keepdims=True)) / jnp.sqrt(
debias_coef * jnp.var(x1alt, axis=0, keepdims=True) + 1e-5
)
assert jnp.allclose(out, true_out)

# Test that the statistics update during training
out, state = vbn(x1, state)
running_mean, running_var = state.get(bn.state_index)
out, state = vbn(3 * x1 + 10, state)
running_mean2, running_var2 = state.get(bn.state_index)
zero_frac2 = state.get(bn.zero_frac_index)
warmup_count2 = state.get(bn.count_index)
assert not jnp.allclose(running_mean, running_mean2)
assert not jnp.allclose(running_var, running_var2)

# Test that the statistics don't update at inference

ibn = eqx.nn.inference_mode(bn, value=True)
vibn = jax.vmap(ibn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, state = vibn(4 * x1 + 20, state)
running_mean3, running_var3 = state.get(bn.state_index)
zero_frac3 = state.get(bn.zero_frac_index)
warmup_count3 = state.get(bn.count_index)
assert jnp.array_equal(running_mean2, running_mean3)
assert jnp.array_equal(running_var2, running_var3)
assert jnp.array_equal(zero_frac2, zero_frac3)
assert jnp.array_equal(warmup_count2, warmup_count3)

# Test that we can differentiate through it
