Skip to content

Commit

Permalink
Black with line length constraint applied.
Browse files Browse the repository at this point in the history
  • Loading branch information
rballeba committed Jul 18, 2024
1 parent d42812e commit c8107e9
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
18 changes: 12 additions & 6 deletions dect/ect.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def compute_ecc(
return segment_add_coo(ecc, index)


def compute_ect_points(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor):
def compute_ect_points(
batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor
):
"""Computes the Euler Characteristic Transform of a batch of point clouds.
Parameters
Expand All @@ -126,7 +128,9 @@ def compute_ect_points(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTenso
return compute_ecc(nh, batch.batch, lin)


def compute_ect_edges(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor):
def compute_ect_edges(
batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor
):
"""Computes the Euler Characteristic Transform of a batch of graphs.
Parameters
Expand Down Expand Up @@ -159,7 +163,9 @@ def compute_ect_edges(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor
)


def compute_ect_faces(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor):
def compute_ect_faces(
batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor
):
"""Computes the Euler Characteristic Transform of a batch of meshes.
Parameters
Expand Down Expand Up @@ -220,9 +226,9 @@ def __init__(self, config: ECTConfig, v=None):
super().__init__()
self.config = config
self.lin = nn.Parameter(
torch.linspace(-config.radius, config.radius, config.bump_steps).view(
-1, 1, 1, 1
),
torch.linspace(
-config.radius, config.radius, config.bump_steps
).view(-1, 1, 1, 1),
requires_grad=False,
)

Expand Down
14 changes: 9 additions & 5 deletions dect/wect.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,18 @@ def compute_wect(
return compute_wecc(nh, batch.batch, lin, batch.node_weights)
if wect_type == "edges":
# noinspection PyUnboundLocalVariable
return compute_wecc(nh, batch.batch, lin, batch.node_weights) - compute_wecc(
return compute_wecc(
nh, batch.batch, lin, batch.node_weights
) - compute_wecc(
eh, batch.batch[batch.edge_index[0]], lin, edge_weights
)
if wect_type == "faces":
# noinspection PyUnboundLocalVariable
return (
compute_wecc(nh, batch.batch, lin, batch.node_weights)
- compute_wecc(eh, batch.batch[batch.edge_index[0]], lin, edge_weights)
- compute_wecc(
eh, batch.batch[batch.edge_index[0]], lin, edge_weights
)
+ compute_wecc(fh, batch.batch[batch.face[0]], lin, face_weights)
)
raise ValueError(f"Invalid wect_type: {wect_type}")
Expand All @@ -106,9 +110,9 @@ def __init__(self, config: ECTConfig, v=None):
super().__init__()
self.config = config
self.lin = nn.Parameter(
torch.linspace(-config.radius, config.radius, config.bump_steps).view(
-1, 1, 1, 1
),
torch.linspace(
-config.radius, config.radius, config.bump_steps
).view(-1, 1, 1, 1),
requires_grad=False,
)

Expand Down

0 comments on commit c8107e9

Please sign in to comment.