diff --git a/e2cnn/group/representation.py b/e2cnn/group/representation.py
index 12c65077..26628f0c 100644
--- a/e2cnn/group/representation.py
+++ b/e2cnn/group/representation.py
@@ -673,7 +673,7 @@ def build_regular_representation(group: e2cnn.group.Group) -> Tuple[List[e2cnn.g
 
     for e in group.elements:
         # print(index[e], e)
-        r = np.zeros((size, size), dtype=np.float)
+        r = np.zeros((size, size), dtype=float)
 
         for g in group.elements:
             eg = group.combine(e, g)
@@ -712,7 +712,7 @@ def build_regular_representation(group: e2cnn.group.Group) -> Tuple[List[e2cnn.g
 
     P = directsum(irreps, name="irreps")
 
-    v = np.zeros((size, 1), dtype=np.float)
+    v = np.zeros((size, 1), dtype=float)
 
     p = 0
     for irr, m in multiplicities:
diff --git a/e2cnn/kernels/basis.py b/e2cnn/kernels/basis.py
index 706fd1d5..0220bf36 100644
--- a/e2cnn/kernels/basis.py
+++ b/e2cnn/kernels/basis.py
@@ -172,7 +172,7 @@ def __eq__(self, other):
             return False
 
     def __hash__(self):
-        return hash(self.radii.tostring()) + hash(self.sigma.tostring())
+        return hash(self.radii.tobytes()) + hash(self.sigma.tobytes())


 class PolarBasis(KernelBasis):
diff --git a/e2cnn/kernels/irreps_basis.py b/e2cnn/kernels/irreps_basis.py
index 0e05265f..f4af163a 100644
--- a/e2cnn/kernels/irreps_basis.py
+++ b/e2cnn/kernels/irreps_basis.py
@@ -442,7 +442,7 @@ def __eq__(self, other):
             return np.allclose(self.mu, other.mu) and np.allclose(self.gamma, other.gamma)
 
     def __hash__(self):
-        return hash(self.in_irrep) + hash(self.out_irrep) + hash(self.mu.tostring()) + hash(self.gamma.tostring())
+        return hash(self.in_irrep) + hash(self.out_irrep) + hash(self.mu.tobytes()) + hash(self.gamma.tobytes())


 class R2ContinuousRotationsSolution(IrrepBasis):
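The hunks above replace two deprecated NumPy idioms: `np.float` was only an alias for the builtin `float`, and `ndarray.tostring()` is a deprecated alias of `ndarray.tobytes()`; both have been removed or scheduled for removal in recent NumPy releases. A minimal sketch of the hashing idiom these `__hash__` methods rely on (the array values here are made up for illustration, not taken from the library):

    import numpy as np

    radii = np.array([0., 1., 2.])
    sigma = np.array([.6, .6, .6])

    # tobytes() returns the array's raw buffer as an immutable bytes object,
    # which can be hashed; tostring() returned the same bytes but is deprecated.
    basis_hash = hash(radii.tobytes()) + hash(sigma.tobytes())

    # np.float was just the builtin float, so the dtype is equivalent:
    r = np.zeros((4, 4), dtype=float)
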
diff --git a/e2cnn/nn/modules/batchnormalization/gnorm.py b/e2cnn/nn/modules/batchnormalization/gnorm.py
index 37855009..5dfbc52c 100644
--- a/e2cnn/nn/modules/batchnormalization/gnorm.py
+++ b/e2cnn/nn/modules/batchnormalization/gnorm.py
@@ -59,7 +59,7 @@ def __init__(self,
         # number of fields of each type
         self._nfields = defaultdict(int)
 
-        # indices of the channeles corresponding to fields belonging to each group
+        # indices of the channels corresponding to fields belonging to each group
         _indices = defaultdict(lambda: [])
 
         # whether each group of fields is contiguous or not
@@ -127,7 +127,7 @@ def __init__(self,
 
             name = r.name
 
-            self._trivial_idxs[name] = trivials
+            self._trivial_idxs[name] = torch.tensor(trivials, dtype=torch.long)
             self._irreps_sizes[name] = [(s, idxs) for s, idxs in irreps.items()]
             self._sizes.append((name, r.size))
 
@@ -138,14 +138,14 @@ def __init__(self,
 
                 self.register_buffer(f'vars_aggregator_{name}', aggregator)
                 self.register_buffer(f'vars_propagator_{name}', propagator)
 
-            running_var = torch.ones((1, self._nfields[r.name], len(r.irreps), 1, 1), dtype=torch.float)
-            running_mean = torch.zeros((1, self._nfields[r.name], len(trivials), 1, 1), dtype=torch.float)
+            running_var = torch.ones((self._nfields[r.name], len(r.irreps)), dtype=torch.float)
+            running_mean = torch.zeros((self._nfields[r.name], len(trivials)), dtype=torch.float)
             self.register_buffer(f'{name}_running_var', running_var)
             self.register_buffer(f'{name}_running_mean', running_mean)
 
             if self.affine:
-                weight = Parameter(torch.ones((1, self._nfields[r.name], len(r.irreps), 1, 1)), requires_grad=True)
-                bias = Parameter(torch.zeros((1, self._nfields[r.name], len(trivials), 1, 1)), requires_grad=True)
+                weight = Parameter(torch.ones((self._nfields[r.name], len(r.irreps))), requires_grad=True)
+                bias = Parameter(torch.zeros((self._nfields[r.name], len(trivials))), requires_grad=True)
                 self.register_parameter(f'{name}_weight', weight)
                 self.register_parameter(f'{name}_bias', bias)
@@ -254,12 +254,12 @@ def forward(self, input: GeometricTensor) -> GeometricTensor:
 
             if hasattr(self, f"{name}_change_of_basis"):
                 cob = getattr(self, f"{name}_change_of_basis")
-                slice = torch.einsum("ds,bcsxy->bcdxy", (cob, normalized))
+                normalized = torch.einsum("ds,bcsxy->bcdxy", (cob, normalized))
 
             if not self._contiguous[name]:
-                output[:, indices, ...] = slice.view(b, -1, h, w)
+                output[:, indices, ...] = normalized.view(b, -1, h, w)
             else:
-                output[:, indices[0]:indices[1], ...] = slice.view(b, -1, h, w)
+                output[:, indices[0]:indices[1], ...] = normalized.view(b, -1, h, w)
 
             # if self._contiguous[name]:
             #     slice2 = output[:, indices[0]:indices[1], ...]
@@ -289,24 +289,24 @@ def _compute_statistics(self, t: torch.Tensor, name: str):
 
         b, c, s, x, y = t.shape
 
-        l = len(trivial_idxs)
+        l = trivial_idxs.numel()
 
        # number of samples in the tensor used to estimate the statistics
        N = b * x * y
 
        # compute the mean of the trivial fields
-        trivial_means = t[:, :, trivial_idxs, ...].view(b, c, l, x, y).sum(dim=(0, 3, 4), keepdim=True).detach() / N
+        trivial_means = t[:, :, trivial_idxs, ...].view(b, c, l, x, y).sum(dim=(0, 3, 4), keepdim=False).detach() / N
 
        # compute the mean of squares of all channels
-        vars = (t ** 2).view(b, c, s, x, y).sum(dim=(0, 3, 4), keepdim=True).detach() / N
+        vars = (t ** 2).view(b, c, s, x, y).sum(dim=(0, 3, 4), keepdim=False).detach() / N
 
        # For the non-trivial fields the mean of the fields is 0, so we can compute the variance as the mean of the
        # norms squared.
        # For trivial channels, we need to subtract the squared mean
-        vars[:, :, trivial_idxs, ...] -= trivial_means**2
+        vars[:, trivial_idxs] -= trivial_means**2
 
        # aggregate the squared means of the channels which belong to the same irrep
-        vars = torch.einsum("io,bcixy->bcoxy", (vars_aggregator, vars))
+        vars = torch.einsum("io,ci->co", (vars_aggregator, vars))
 
        # Correct the estimation of the variance with Bessel's correction
        correction = N/(N-1) if N > 1 else 1.
@@ -321,21 +321,25 @@ def _scale(self, t: torch.Tensor, scales: torch.Tensor, name: str, out: torch.Te
 
        vars_aggregator = getattr(self, f"vars_propagator_{name}")
 
+        ndims = len(t.shape[3:])
+        scale_shape = (1, scales.shape[0], vars_aggregator.shape[0]) + (1,)*ndims
 
        # scale all fields
-        out[...] = t * torch.einsum("oi,bcixy->bcoxy", (vars_aggregator, scales))
-
-        # assert torch.allclose(t, out)
+        out[...] = t * torch.einsum("oi,ci->co", (vars_aggregator, scales)).reshape(scale_shape)
 
        return out
 
    def _shift(self, t: torch.Tensor, trivial_bias: torch.Tensor, name: str, out: torch.Tensor = None):
 
        if out is None:
-            out = torch.zeros_like(t)
-
+            out = t.clone()
+        else:
+            out[:] = t
+
        trivial_idxs = self._trivial_idxs[name]
+        bias_shape = (1,) + trivial_bias.shape + (1,)*(len(t.shape) - 3)
+
        # add bias to the trivial fields
-        out[:, :, trivial_idxs, ...] = t[:, :, trivial_idxs, ...] + trivial_bias
-
+        out[:, :, trivial_idxs, ...] += trivial_bias.view(bias_shape)
+
        return out
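In the batch-normalization module above, the running statistics and the affine `weight`/`bias` are now stored as plain `(n_fields, n_irreps)` (or `(n_fields, n_trivials)`) tensors and are only reshaped to a broadcastable shape inside `_scale` and `_shift`. A rough sketch of that broadcasting idiom, with made-up sizes rather than the module's real attributes:

    import torch

    b, c, s, x, y = 8, 3, 5, 7, 7      # batch, fields, channels per field, spatial size
    t = torch.randn(b, c, s, x, y)     # the reshaped feature map used in forward()

    scales = torch.rand(c, s)          # one scale per channel of each field
    ndims = len(t.shape[3:])           # number of trailing spatial dimensions
    scale_shape = (1, c, s) + (1,) * ndims

    # view the per-field scales as (1, c, s, 1, 1) so they broadcast over the
    # batch and spatial dimensions of t
    out = t * scales.reshape(scale_shape)
    assert out.shape == t.shape
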
diff --git a/e2cnn/nn/modules/batchnormalization/inner.py b/e2cnn/nn/modules/batchnormalization/inner.py
index e3d30eb2..145d36d8 100644
--- a/e2cnn/nn/modules/batchnormalization/inner.py
+++ b/e2cnn/nn/modules/batchnormalization/inner.py
@@ -100,6 +100,16 @@ def __init__(self,
                 track_running_stats=self.track_running_stats
             )
             self.add_module('batch_norm_[{}]'.format(s), _batchnorm)
+
+    def reset_running_stats(self):
+        for s, contiguous in self._contiguous.items():
+            batchnorm = getattr(self, f'batch_norm_[{s}]')
+            batchnorm.reset_running_stats()
+
+    def reset_parameters(self):
+        for s, contiguous in self._contiguous.items():
+            batchnorm = getattr(self, f'batch_norm_[{s}]')
+            batchnorm.reset_parameters()
 
     def forward(self, input: GeometricTensor) -> GeometricTensor:
         r"""
diff --git a/e2cnn/nn/modules/r2_conv/r2convolution.py b/e2cnn/nn/modules/r2_conv/r2convolution.py
index 8597ab20..c07880d6 100644
--- a/e2cnn/nn/modules/r2_conv/r2convolution.py
+++ b/e2cnn/nn/modules/r2_conv/r2convolution.py
@@ -345,6 +345,7 @@ def forward(self, input: GeometricTensor):
                             _filter,
                             stride=self.stride,
                             dilation=self.dilation,
+                            padding=(0,0),
                             groups=self.groups,
                             bias=_bias)
 
@@ -507,12 +508,30 @@ def export(self):
         _filter = self.filter
         _bias = self.expanded_bias
 
+        if self.padding_mode not in ['zeros']:
+            x, y = torch.__version__.split('.')[:2]
+            if (int(x), int(y)) < (1, 5):
+                if self.padding_mode == 'circular':
+                    raise ImportError(
+                        "'{}' padding mode was unreliable in older `torch` versions, so conversion is only supported from version 1.5, but version {} is installed.".format(
+                            self.padding_mode, torch.__version__
+                        )
+                    )
+
+                else:
+                    raise ImportError(
+                        "`torch` supports the '{}' padding mode only from version 1.5, but version {} is installed.".format(
+                            self.padding_mode, torch.__version__
+                        )
+                    )
+
         # build the PyTorch Conv2d module
         has_bias = self.bias is not None
         conv = torch.nn.Conv2d(self.in_type.size,
                                self.out_type.size,
                                self.kernel_size,
                                padding=self.padding,
+                               padding_mode=self.padding_mode,
                                stride=self.stride,
                                dilation=self.dilation,
                                groups=self.groups,
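With `padding_mode` now forwarded to `torch.nn.Conv2d` in `export()` (while `forward()` pads the input itself and calls the convolution with `padding=(0,0)`), an exported module should reproduce the equivariant one also for non-`'zeros'` modes, provided `torch` >= 1.5. A small usage sketch along the lines of the export tests below; the group, field types, and sizes are arbitrary choices for illustration:

    import torch
    from e2cnn import gspaces, nn

    gs = gspaces.Rot2dOnR2(N=8)
    f_in = nn.FieldType(gs, [gs.trivial_repr] * 3)
    f_out = nn.FieldType(gs, [gs.regular_repr] * 2)

    conv = nn.R2Conv(f_in, f_out, kernel_size=3, padding=1,
                     padding_mode='reflect', bias=True)
    conv.eval()                  # export() requires evaluation mode

    torch_conv = conv.export()   # a plain torch.nn.Conv2d with padding_mode='reflect'

    x = torch.randn(4, f_in.size, 16, 16)
    y1 = conv(nn.GeometricTensor(x, f_in)).tensor
    y2 = torch_conv(x)
    assert torch.allclose(y1, y2, atol=1e-5)
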
diff --git a/test/nn/test_batchnorm.py b/test/nn/test_batchnorm.py
index e0962448..b756485f 100644
--- a/test/nn/test_batchnorm.py
+++ b/test/nn/test_batchnorm.py
@@ -56,6 +56,23 @@ def test_dihedral_induced_norm(self):
 
         self.check_bn(bn)
 
+    def test_dihedral_inner_norm(self):
+        N = 8
+        g = FlipRot2dOnR2(N)
+
+        g.fibergroup._build_quotient_representations()
+
+        reprs = []
+        for r in g.representations.values():
+            if 'pointwise' in r.supported_nonlinearities:
+                reprs.append(r)
+
+        r = FieldType(g, reprs)
+
+        bn = InnerBatchNorm(r, affine=False, momentum=1.)
+
+        self.check_bn(bn)
+
     def check_bn(self, bn: EquivariantModule):
 
         x = 10*torch.randn(300, bn.in_type.size, 1, 1) + 20
diff --git a/test/nn/test_export.py b/test/nn/test_export.py
index a3c99e9e..1b9b6a3d 100644
--- a/test/nn/test_export.py
+++ b/test/nn/test_export.py
@@ -22,28 +22,30 @@ def test_R2Conv(self):
                 for st in [1, 2]:
                     for d in [1, 3]:
                         for gr in [1, 3]:
-                            for i in range(2):
+                            for mode in ['zeros', 'reflect', 'replicate', 'circular']:
+                                for i in range(2):
 
-                                c_in = 1 + np.random.randint(4)
-                                c_out = 1 + np.random.randint(4)
-
-                                c_in *= gr
-                                c_out *= gr
-
-                                f_in = FieldType(gs, [gs.regular_repr]*c_in)
-                                f_out = FieldType(gs, [gs.regular_repr]*c_out)
-
-                                conv = R2Conv(
-                                    f_in, f_out,
-                                    kernel_size=ks,
-                                    padding=pd,
-                                    stride=st,
-                                    dilation=d,
-                                    groups=gr,
-                                    bias=True,
-                                )
-
-                                self.check_exported(conv)
+                                    c_in = 1 + np.random.randint(4)
+                                    c_out = 1 + np.random.randint(4)
+
+                                    c_in *= gr
+                                    c_out *= gr
+
+                                    f_in = FieldType(gs, [gs.regular_repr]*c_in)
+                                    f_out = FieldType(gs, [gs.regular_repr]*c_out)
+
+                                    conv = R2Conv(
+                                        f_in, f_out,
+                                        kernel_size=ks,
+                                        padding=pd,
+                                        padding_mode=mode,
+                                        stride=st,
+                                        dilation=d,
+                                        groups=gr,
+                                        bias=True,
+                                    )
+
+                                    self.check_exported(conv)
 
     def test_R2Conv_mix(self):