Skip to content

Commit

Permalink
feat: improving Lre __repr__()
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed Jun 30, 2024
1 parent 685bce1 commit 011d936
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
9 changes: 9 additions & 0 deletions linear_relational/Lre.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ def calculate_subject_activation(
vec = vec / vec.norm()
return vec

def __repr__(self) -> str:
return f"InvertedLre({self.relation}, rank {self.rank}, layers {self.subject_layer} -> {self.object_layer}, {self.object_aggregation})"


class LowRankLre(nn.Module):
"""Low-rank approximation of a LRE"""
Expand Down Expand Up @@ -140,6 +143,9 @@ def calculate_object_activation(
vec = vec / vec.norm()
return vec

def __repr__(self) -> str:
return f"LowRankLre({self.relation}, rank {self.rank}, layers {self.subject_layer} -> {self.object_layer}, {self.object_aggregation})"


class Lre(nn.Module):
"""Linear Relational Embedding"""
Expand Down Expand Up @@ -211,3 +217,6 @@ def _low_rank_svd(
low_rank_v: torch.Tensor = v[:, :rank].to(self.weight.dtype)
low_rank_s: torch.Tensor = s[:rank].to(self.weight.dtype)
return low_rank_u, low_rank_s, low_rank_v

def __repr__(self) -> str:
return f"Lre({self.relation}, layers {self.subject_layer} -> {self.object_layer}, {self.object_aggregation})"
4 changes: 4 additions & 0 deletions tests/test_Lre.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def test_Lre_invert() -> None:
bias=bias,
weight=torch.eye(3),
)
assert lre.__repr__() == "Lre(test, layers 5 -> 10, mean)"
inv_lre = lre.invert(rank=2)
assert inv_lre.relation == "test"
assert inv_lre.subject_layer == 5
Expand All @@ -23,6 +24,7 @@ def test_Lre_invert() -> None:
assert inv_lre.s.shape == (2,)
assert inv_lre.v.shape == (3, 2)
assert inv_lre.rank == 2
assert inv_lre.__repr__() == "InvertedLre(test, rank 2, layers 5 -> 10, mean)"


def test_Lre_to_low_rank() -> None:
Expand All @@ -36,6 +38,7 @@ def test_Lre_to_low_rank() -> None:
weight=torch.eye(3),
)
low_rank_lre = lre.to_low_rank(rank=2)
assert lre.__repr__() == "Lre(test, layers 5 -> 10, mean)"
assert low_rank_lre.relation == "test"
assert low_rank_lre.subject_layer == 5
assert low_rank_lre.object_layer == 10
Expand All @@ -45,6 +48,7 @@ def test_Lre_to_low_rank() -> None:
assert low_rank_lre.s.shape == (2,)
assert low_rank_lre.v.shape == (3, 2)
assert low_rank_lre.rank == 2
assert low_rank_lre.__repr__() == "LowRankLre(test, rank 2, layers 5 -> 10, mean)"


def test_LowRankLre_calculate_object_activation_unnormalized() -> None:
Expand Down

0 comments on commit 011d936

Please sign in to comment.