Skip to content

Commit

Permalink
Proper initialization of learnable direction vector.
Browse files Browse the repository at this point in the history
  • Loading branch information
ErnstRoell committed Jul 3, 2024
1 parent d58569b commit e0c081f
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion dect/ect.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,12 @@ def __init__(self, config: ECTConfig, v=None):
if config.fixed:
self.v = nn.Parameter(v, requires_grad=False)
else:
self.v = nn.Parameter(v, requires_grad=True)
self.v = nn.Parameter(torch.zeros_like(v))
geotorch.constraints.sphere(self, "v")
# Since geotorch randomizes the vector during initialization, we
# assign the values after registering it with spherical constraints.
# See Geotorch documentation for examples.
self.v = nn.Parameter(v, requires_grad=True)

if config.ect_type == "points":
self.compute_ect = compute_ect_points
Expand Down

0 comments on commit e0c081f

Please sign in to comment.