diff --git a/tat/compat.py b/tat/compat.py index 96f3b947c..92639e2d8 100644 --- a/tat/compat.py +++ b/tat/compat.py @@ -267,17 +267,9 @@ def parity(int_parity: int) -> bool: # Segment index -def _get_index_for_position(position: tuple[typing.Any, int], edge: E) -> int: - sym, index = position - if not isinstance(sym, tuple): - sym = (sym,) - return next(total_index for total_index in range(edge.dimension) if all( - sub_sym == sub_symmetry[total_index] for sub_sym, sub_symmetry in zip(sym, edge.symmetry))) + index - - @T._prepare_position.register # pylint: disable=protected-access,no-member def _(self: T, position: dict[str, tuple[typing.Any, int]]) -> tuple[int, ...]: - return tuple(_get_index_for_position(position[name], edge) for name, edge in zip(self.names, self.edges)) + return tuple(index_by_point(edge, position[name]) for name, edge in zip(self.names, self.edges)) # Function renames @@ -323,6 +315,28 @@ def exponential(self: T, pairs: set[tuple[str, str]], step: int | None = None) - return origin_exponential(self, pairs) +# Edge point conversion + + +@_compat_function(E) +def index_by_point(self: E, point: tuple[typing.Any, int]) -> int: + "Get index by point on an edge" + sym, sub_index = point + if not isinstance(sym, tuple): + sym = (sym,) + return next(total_index for total_index in range(self.dimension) if all( + sub_sym == sub_symmetry[total_index] for sub_sym, sub_symmetry in zip(sym, self.symmetry))) + sub_index + + +@_compat_function(E) +def point_by_index(self: E, index: int) -> tuple[typing.Any, int]: + "Get point by index on an edge" + sym = tuple(sub_symmetry[index] for sub_symmetry in self.symmetry) + sub_index = sum( + 1 for i in range(index) if all(sub_sym == sub_symmetry[i] for sub_sym, sub_symmetry in zip(sym, self.symmetry))) + return sym, sub_index + + # Random utility