fix gnorm, check torch version in R2Conv.export() for padding_mode
test export
test inner batch norm
inner batchnorm: method to reset stats and params
minor changes to solve deprecation warnings
Gabri95 committed Mar 17, 2021
2 parents 05a581d + b838c2a commit 0a755fa
Showing 8 changed files with 98 additions and 46 deletions.
4 changes: 2 additions & 2 deletions e2cnn/group/representation.py
@@ -673,7 +673,7 @@ def build_regular_representation(group: e2cnn.group.Group) -> Tuple[List[e2cnn.g

for e in group.elements:
# print(index[e], e)
r = np.zeros((size, size), dtype=np.float)
r = np.zeros((size, size), dtype=float)
for g in group.elements:

eg = group.combine(e, g)
@@ -712,7 +712,7 @@ def build_regular_representation(group: e2cnn.group.Group) -> Tuple[List[e2cnn.g

P = directsum(irreps, name="irreps")

v = np.zeros((size, 1), dtype=np.float)
v = np.zeros((size, 1), dtype=float)

p = 0
for irr, m in multiplicities:
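Side note: np.float is a deprecated alias of the builtin float (NumPy emits a DeprecationWarning since 1.20 and later versions remove the alias), so the two replacements above keep the default float64 dtype while silencing the warning; this is part of the "minor changes to solve deprecation warnings" item of the commit message. A minimal illustration:

import numpy as np

# dtype=float behaves exactly like the removed dtype=np.float: both mean float64
r = np.zeros((3, 3), dtype=float)
assert r.dtype == np.float64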
2 changes: 1 addition & 1 deletion e2cnn/kernels/basis.py
@@ -172,7 +172,7 @@ def __eq__(self, other):
return False

def __hash__(self):
return hash(self.radii.tostring()) + hash(self.sigma.tostring())
return hash(self.radii.tobytes()) + hash(self.sigma.tobytes())


class PolarBasis(KernelBasis):
2 changes: 1 addition & 1 deletion e2cnn/kernels/irreps_basis.py
@@ -442,7 +442,7 @@ def __eq__(self, other):
return np.allclose(self.mu, other.mu) and np.allclose(self.gamma, other.gamma)

def __hash__(self):
return hash(self.in_irrep) + hash(self.out_irrep) + hash(self.mu.tostring()) + hash(self.gamma.tostring())
return hash(self.in_irrep) + hash(self.out_irrep) + hash(self.mu.tobytes()) + hash(self.gamma.tobytes())


class R2ContinuousRotationsSolution(IrrepBasis):
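Similarly, ndarray.tostring() is a deprecated alias of ndarray.tobytes() (deprecated in NumPy 1.19), and both return the raw buffer of the array, so the content-based hashes in the two hunks above are unchanged. A tiny sketch of the idea, with an illustrative array rather than the library's actual attributes:

import numpy as np

radii = np.linspace(0., 2., 5)

# tobytes() returns the array's raw memory, so arrays built with the same
# contents and dtype produce equal, hashable byte strings
assert hash(radii.tobytes()) == hash(np.linspace(0., 2., 5).tobytes())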
46 changes: 25 additions & 21 deletions e2cnn/nn/modules/batchnormalization/gnorm.py
@@ -59,7 +59,7 @@ def __init__(self,
# number of fields of each type
self._nfields = defaultdict(int)

# indices of the channeles corresponding to fields belonging to each group
# indices of the channels corresponding to fields belonging to each group
_indices = defaultdict(lambda: [])

# whether each group of fields is contiguous or not
@@ -127,7 +127,7 @@ def __init__(self,

name = r.name

self._trivial_idxs[name] = trivials
self._trivial_idxs[name] = torch.tensor(trivials, dtype=torch.long)
self._irreps_sizes[name] = [(s, idxs) for s, idxs in irreps.items()]
self._sizes.append((name, r.size))

@@ -138,14 +138,14 @@
self.register_buffer(f'vars_aggregator_{name}', aggregator)
self.register_buffer(f'vars_propagator_{name}', propagator)

running_var = torch.ones((1, self._nfields[r.name], len(r.irreps), 1, 1), dtype=torch.float)
running_mean = torch.zeros((1, self._nfields[r.name], len(trivials), 1, 1), dtype=torch.float)
running_var = torch.ones((self._nfields[r.name], len(r.irreps)), dtype=torch.float)
running_mean = torch.zeros((self._nfields[r.name], len(trivials)), dtype=torch.float)
self.register_buffer(f'{name}_running_var', running_var)
self.register_buffer(f'{name}_running_mean', running_mean)

if self.affine:
weight = Parameter(torch.ones((1, self._nfields[r.name], len(r.irreps), 1, 1)), requires_grad=True)
bias = Parameter(torch.zeros((1, self._nfields[r.name], len(trivials), 1, 1)), requires_grad=True)
weight = Parameter(torch.ones((self._nfields[r.name], len(r.irreps))), requires_grad=True)
bias = Parameter(torch.zeros((self._nfields[r.name], len(trivials))), requires_grad=True)
self.register_parameter(f'{name}_weight', weight)
self.register_parameter(f'{name}_bias', bias)

@@ -254,12 +254,12 @@ def forward(self, input: GeometricTensor) -> GeometricTensor:

if hasattr(self, f"{name}_change_of_basis"):
cob = getattr(self, f"{name}_change_of_basis")
slice = torch.einsum("ds,bcsxy->bcdxy", (cob, normalized))
normalized = torch.einsum("ds,bcsxy->bcdxy", (cob, normalized))

if not self._contiguous[name]:
output[:, indices, ...] = slice.view(b, -1, h, w)
output[:, indices, ...] = normalized.view(b, -1, h, w)
else:
output[:, indices[0]:indices[1], ...] = slice.view(b, -1, h, w)
output[:, indices[0]:indices[1], ...] = normalized.view(b, -1, h, w)

# if self._contiguous[name]:
# slice2 = output[:, indices[0]:indices[1], ...]
@@ -289,24 +289,24 @@ def _compute_statistics(self, t: torch.Tensor, name: str):

b, c, s, x, y = t.shape

l = len(trivial_idxs)
l = trivial_idxs.numel()

# number of samples in the tensor used to estimate the statistics
N = b * x * y

# compute the mean of the trivial fields
trivial_means = t[:, :, trivial_idxs, ...].view(b, c, l, x, y).sum(dim=(0, 3, 4), keepdim=True).detach() / N
trivial_means = t[:, :, trivial_idxs, ...].view(b, c, l, x, y).sum(dim=(0, 3, 4), keepdim=False).detach() / N

# compute the mean of squares of all channels
vars = (t ** 2).view(b, c, s, x, y).sum(dim=(0, 3, 4), keepdim=True).detach() / N
vars = (t ** 2).view(b, c, s, x, y).sum(dim=(0, 3, 4), keepdim=False).detach() / N

# For the non-trivial fields the mean of the fields is 0, so we can compute the variance as the mean of the
# norms squared.
# For trivial channels, we need to subtract the squared mean
vars[:, :, trivial_idxs, ...] -= trivial_means**2
vars[:, trivial_idxs] -= trivial_means**2

# aggregate the squared means of the channels which belong to the same irrep
vars = torch.einsum("io,bcixy->bcoxy", (vars_aggregator, vars))
vars = torch.einsum("io,ci->co", (vars_aggregator, vars))

# Correct the estimation of the variance with Bessel's correction
correction = N/(N-1) if N > 1 else 1.
@@ -321,21 +321,25 @@ def _scale(self, t: torch.Tensor, scales: torch.Tensor, name: str, out: torch.Te

vars_aggregator = getattr(self, f"vars_propagator_{name}")

ndims = len(t.shape[3:])
scale_shape = (1, scales.shape[0], vars_aggregator.shape[0]) + (1,)*ndims
# scale all fields
out[...] = t * torch.einsum("oi,bcixy->bcoxy", (vars_aggregator, scales))

# assert torch.allclose(t, out)
out[...] = t * torch.einsum("oi,ci->co", (vars_aggregator, scales)).reshape(scale_shape)

return out

def _shift(self, t: torch.Tensor, trivial_bias: torch.Tensor, name: str, out: torch.Tensor = None):

if out is None:
out = torch.zeros_like(t)

out = t.clone()
else:
out[:] = t

trivial_idxs = self._trivial_idxs[name]

bias_shape = (1,) + trivial_bias.shape + (1,)*(len(t.shape) - 3)

# add bias to the trivial fields
out[:, :, trivial_idxs, ...] = t[:, :, trivial_idxs, ...] + trivial_bias
out[:, :, trivial_idxs, ...] += trivial_bias.view(bias_shape)

return out
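This is the heart of the gnorm fix: the running statistics are now kept as 2-D buffers of shape (num_fields, num_irreps) and (num_fields, num_trivial_irreps) instead of 5-D buffers with singleton batch and spatial axes, _compute_statistics returns means and variances in those shapes, and the einsums in _compute_statistics and _scale drop the batch and spatial indices accordingly. A rough sketch of the new per-irrep aggregation with toy shapes (the names, values and aggregation weights are illustrative, not the module's actual buffers):

import torch

num_fields, field_size, num_irreps = 3, 5, 2   # toy sizes

# per-channel second moments for one field type, shape (num_fields, field_size),
# i.e. what the new 'vars' looks like after summing over batch and spatial dims
vars_per_channel = torch.rand(num_fields, field_size)

# toy aggregation matrix in the spirit of the 'vars_aggregator_<name>' buffer:
# it maps the channels of a field to the irreps they belong to
aggregator = torch.zeros(field_size, num_irreps)
aggregator[:3, 0] = 1. / 3.   # first irrep spans 3 channels
aggregator[3:, 1] = 1. / 2.   # second irrep spans 2 channels

# "io,ci->co": pool the per-channel moments of each irrep into a single value,
# yielding one variance per (field, irrep) pair, the new running_var shape
vars_per_irrep = torch.einsum("io,ci->co", aggregator, vars_per_channel)
assert vars_per_irrep.shape == (num_fields, num_irreps)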
10 changes: 10 additions & 0 deletions e2cnn/nn/modules/batchnormalization/inner.py
@@ -100,6 +100,16 @@ def __init__(self,
track_running_stats=self.track_running_stats
)
self.add_module('batch_norm_[{}]'.format(s), _batchnorm)

def reset_running_stats(self):
for s, contiguous in self._contiguous.items():
batchnorm = getattr(self, f'batch_norm_[{s}]')
batchnorm.reset_running_stats()

def reset_parameters(self):
for s, contiguous in self._contiguous.items():
batchnorm = getattr(self, f'batch_norm_[{s}]')
batchnorm.reset_parameters()

def forward(self, input: GeometricTensor) -> GeometricTensor:
r"""
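The two new methods delegate to the wrapped torch.nn.BatchNorm submodules (one per group of fields), which is what the commit note "inner batchnorm: method to reset stats and params" refers to. A hedged usage sketch (the field type below is just an example):

import torch
from e2cnn import gspaces, nn

gs = gspaces.Rot2dOnR2(8)
ft = nn.FieldType(gs, [gs.regular_repr] * 2)
bn = nn.InnerBatchNorm(ft)

# accumulate some running statistics
bn.train()
bn(nn.GeometricTensor(torch.randn(4, ft.size, 9, 9), ft))

# discard them again (reset_parameters also re-initializes the affine weights)
bn.reset_running_stats()
bn.reset_parameters()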
19 changes: 19 additions & 0 deletions e2cnn/nn/modules/r2_conv/r2convolution.py
@@ -345,6 +345,7 @@ def forward(self, input: GeometricTensor):
_filter,
stride=self.stride,
dilation=self.dilation,
padding=(0,0),
groups=self.groups,
bias=_bias)

@@ -507,12 +508,30 @@ def export(self):
_filter = self.filter
_bias = self.expanded_bias

if self.padding_mode not in ['zeros']:
x, y = torch.__version__.split('.')[:2]
if int(x) < 1 or (int(x) == 1 and int(y) < 5):
if self.padding_mode == 'circular':
raise ImportError(
"'{}' padding mode had some issues in old `torch` versions. Therefore, we only support conversion from version 1.5 but only version {} is installed.".format(
self.padding_mode, torch.__version__
)
)

else:
raise ImportError(
"`torch` supports '{}' padding mode only from version 1.5 but only version {} is installed.".format(
self.padding_mode, torch.__version__
)
)

# build the PyTorch Conv2d module
has_bias = self.bias is not None
conv = torch.nn.Conv2d(self.in_type.size,
self.out_type.size,
self.kernel_size,
padding=self.padding,
padding_mode=self.padding_mode,
stride=self.stride,
dilation=self.dilation,
groups=self.groups,
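For context, the padding=(0,0) added to the forward() call above makes explicit that conv2d itself applies no padding in that branch (for non-'zeros' modes the input is padded beforehand), and export() now refuses to build a plain torch.nn.Conv2d on torch older than 1.5, where the non-zero padding modes (in particular 'circular') were missing or buggy. A hedged sketch of an equivalent, slightly more defensive version check; the helper name is made up, and it strips suffixes such as '1.8.0a0' or '1.7.1+cu110':

import re
import torch

def _torch_at_least(min_major: int = 1, min_minor: int = 5) -> bool:
    # hypothetical helper: keep only the leading digits of the first two
    # version components so that dev/local builds still parse
    major, minor = (int(re.match(r'\d+', p).group()) for p in torch.__version__.split('.')[:2])
    return (major, minor) >= (min_major, min_minor)

# comparing (major, minor) as a tuple avoids rejecting e.g. torch 2.x,
# which a bare check on the minor version alone would do
assert isinstance(_torch_at_least(), bool)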
17 changes: 17 additions & 0 deletions test/nn/test_batchnorm.py
@@ -56,6 +56,23 @@ def test_dihedral_induced_norm(self):

self.check_bn(bn)

def test_dihedral_inner_norm(self):
N = 8
g = FlipRot2dOnR2(N)

g.fibergroup._build_quotient_representations()

reprs = []
for r in g.representations.values():
if 'pointwise' in r.supported_nonlinearities:
reprs.append(r)

r = FieldType(g, reprs)

bn = InnerBatchNorm(r, affine=False, momentum=1.)

self.check_bn(bn)

def check_bn(self, bn: EquivariantModule):

x = 10*torch.randn(300, bn.in_type.size, 1, 1) + 20
44 changes: 23 additions & 21 deletions test/nn/test_export.py
@@ -22,28 +22,30 @@ def test_R2Conv(self):
for st in [1, 2]:
for d in [1, 3]:
for gr in [1, 3]:
for i in range(2):
for mode in ['zeros', 'reflect', 'replicate', 'circular']:
for i in range(2):

c_in = 1 + np.random.randint(4)
c_out = 1 + np.random.randint(4)

c_in *= gr
c_out *= gr

f_in = FieldType(gs, [gs.regular_repr]*c_in)
f_out = FieldType(gs, [gs.regular_repr]*c_out)

conv = R2Conv(
f_in, f_out,
kernel_size=ks,
padding=pd,
stride=st,
dilation=d,
groups=gr,
bias=True,
)

self.check_exported(conv)
c_in = 1 + np.random.randint(4)
c_out = 1 + np.random.randint(4)

c_in *= gr
c_out *= gr

f_in = FieldType(gs, [gs.regular_repr]*c_in)
f_out = FieldType(gs, [gs.regular_repr]*c_out)

conv = R2Conv(
f_in, f_out,
kernel_size=ks,
padding=pd,
padding_mode=mode,
stride=st,
dilation=d,
groups=gr,
bias=True,
)

self.check_exported(conv)

def test_R2Conv_mix(self):

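check_exported itself is not part of this hunk; conceptually, the widened loop means every combination of kernel size, padding, padding_mode, stride, dilation and groups must survive the round trip through export(). A rough stand-in for such a check, under the assumption that it compares the equivariant module with its exported torch.nn.Conv2d on random input (this is not the test suite's actual helper):

import torch
from e2cnn import nn

def check_exported_like(conv: nn.R2Conv, atol: float = 1e-5):
    # hypothetical helper, not the real check_exported
    conv.eval()                      # freeze the expanded filter and bias
    exported = conv.export()         # plain torch.nn.Conv2d

    x = torch.randn(3, conv.in_type.size, 33, 33)
    y_eq = conv(nn.GeometricTensor(x, conv.in_type)).tensor
    y_ex = exported(x)
    assert torch.allclose(y_eq, y_ex, atol=atol)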
