From e0c081f0316a18506b05160d06484a6a985166b4 Mon Sep 17 00:00:00 2001 From: ErnstRoell Date: Wed, 3 Jul 2024 10:34:59 +0200 Subject: [PATCH] Proper initialization of learnable direction vector. --- dect/ect.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dect/ect.py b/dect/ect.py index a5d7e29..cdf91be 100644 --- a/dect/ect.py +++ b/dect/ect.py @@ -216,8 +216,12 @@ def __init__(self, config: ECTConfig, v=None): if config.fixed: self.v = nn.Parameter(v, requires_grad=False) else: - self.v = nn.Parameter(v, requires_grad=True) + self.v = nn.Parameter(torch.zeros_like(v)) geotorch.constraints.sphere(self, "v") + # Since geotorch randomizes the vector during initialization, we + # assign the values after registering it with spherical constraints. + # See Geotorch documentation for examples. + self.v = nn.Parameter(v, requires_grad=True) if config.ect_type == "points": self.compute_ect = compute_ect_points