diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
new file mode 100644
index 000000000..57c583dd2
--- /dev/null
+++ b/.github/workflows/CI.yml
@@ -0,0 +1,50 @@
+name: CI
+
+on: [push, pull_request]
+
+jobs:
+  CI:
+
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - python-version: "3.10"
+            pytorch-version: "1.11"
+          - python-version: "3.10"
+            pytorch-version: "1.12"
+          - python-version: "3.10"
+            pytorch-version: "1.13"
+          - python-version: "3.10"
+            pytorch-version: "2.0"
+          - python-version: "3.10"
+            pytorch-version: "2.1"
+
+          - python-version: "3.11"
+            pytorch-version: "1.13"
+          - python-version: "3.11"
+            pytorch-version: "2.0"
+          - python-version: "3.11"
+            pytorch-version: "2.1"
+
+    steps:
+    - uses: actions/checkout@v4
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v4
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install requirements
+      run: |
+        pip install pylint mypy pytest pytest-cov
+        pip install torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu
+        pip install multimethod
+    - name: Run pylint
+      run: pylint tat
+      working-directory: ${{ github.workspace }}
+    - name: Run mypy
+      run: mypy tat
+      working-directory: ${{ github.workspace }}
+    - name: Run pytest
+      run: pytest
+      working-directory: ${{ github.workspace }}
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..83b293b69
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+.coverage
+__pycache__
+env
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 000000000..3370c23ae
--- /dev/null
+++ b/README.md
@@ -0,0 +1,3 @@
+# TAT
+
+A Fermionic tensor library based on pytorch.
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 000000000..44e627766
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,31 @@
+[project]
+name = "tat"
+version = "0.4.0"
+authors = [
+    {email = "zh970205@mail.ustc.edu.cn", name = "Hao Zhang"}
+]
+description = "A Fermionic tensor library based on pytorch."
+readme = "README.md"
+requires-python = ">=3.10"
+license = {text = "GPL-3.0-or-later"}
+dependencies = [
+    "multimethod",
+    "torch",
+]
+
+[tool.pylint]
+max-line-length = 120
+generated-members = "torch.*"
+
+[tool.yapf]
+based_on_style = "google"
+column_limit = 120
+
+[tool.mypy]
+check_untyped_defs = true
+disallow_untyped_defs = true
+
+[tool.pytest.ini_options]
+pythonpath = "."
+testpaths = ["tests",]
+addopts = "--cov=tat"
diff --git a/tat/__init__.py b/tat/__init__.py
new file mode 100644
index 000000000..9f8cc4bcb
--- /dev/null
+++ b/tat/__init__.py
@@ -0,0 +1,6 @@
+"""
+The tat is a Fermionic tensor library based on pytorch.
+"""
+
+from .edge import Edge
+from .tensor import Tensor
diff --git a/tat/_utility.py b/tat/_utility.py
new file mode 100644
index 000000000..1e890d29e
--- /dev/null
+++ b/tat/_utility.py
@@ -0,0 +1,41 @@
+"""
+Some internal utility used by tat.
+""" + +import torch + +# pylint: disable=missing-function-docstring +# pylint: disable=no-else-return + + +def unsqueeze(tensor: torch.Tensor, index: int, rank: int) -> torch.Tensor: + return tensor.reshape([-1 if i == index else 1 for i in range(rank)]) + + +def neg_symmetry(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype is torch.bool: + return tensor + else: + return -tensor + + +def add_symmetry(tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor: + if tensor_1.dtype is torch.bool: + return torch.logical_xor(tensor_1, tensor_2) + else: + return torch.add(tensor_1, tensor_2) + + +def zero_symmetry(tensor: torch.Tensor) -> torch.Tensor: + # pylint: disable=singleton-comparison + if tensor.dtype is torch.bool: + return tensor == False + else: + return tensor == 0 + + +def parity(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype is torch.bool: + return tensor + else: + return tensor % 2 != 0 diff --git a/tat/compat.py b/tat/compat.py new file mode 100644 index 000000000..c3cae443c --- /dev/null +++ b/tat/compat.py @@ -0,0 +1,191 @@ +""" +This file implement a compat layer for legacy TAT interface. +""" + +from __future__ import annotations +import typing +from multimethod import multimethod +import torch +from .edge import Edge as E +from .tensor import Tensor as T + +# pylint: disable=too-few-public-methods +# pylint: disable=too-many-instance-attributes + + +class CompatSymmetry: + """ + The common Symmetry namespace + """ + + def __init__(self: CompatSymmetry, fermion: tuple[bool, ...], dtypes: tuple[torch.dtype, ...]) -> None: + self.fermion: tuple[bool, ...] = fermion + self.dtypes: tuple[torch.dtype, ...] 
= dtypes + + # pylint: disable=invalid-name + self.S: CompatScalar + self.D: CompatScalar + self.C: CompatScalar + self.Z: CompatScalar + self.float32: CompatScalar + self.float64: CompatScalar + self.float: CompatScalar + self.complex64: CompatScalar + self.complex128: CompatScalar + self.complex: CompatScalar + + self.S = self.float32 = CompatScalar(self, torch.float32) + self.D = self.float64 = self.float = CompatScalar(self, torch.float64) + self.C = self.complex64 = CompatScalar(self, torch.complex64) + self.Z = self.complex128 = self.complex = CompatScalar(self, torch.complex128) + + def _parse_segments(self: CompatSymmetry, segments: list) -> tuple[tuple[torch.Tensor, ...], int]: + # Segments may be [Sym] or [(Sym, Size)] + try: + # try [(Sym, Size)] first + return self._parse_segments_kernel(segments) + except TypeError: + # Cannot unpack is a type error, value[index] is a type error, too + # convert [Sym] to [(Sym, Size)] + return self._parse_segments_kernel([(sym, 1) for sym in segments]) + + def _parse_segments_kernel(self: CompatSymmetry, segments: list) -> tuple[tuple[torch.Tensor, ...], int]: + # [(Sym, Size)] for every element + dimension = sum(dim for _, dim in segments) + symmetry = tuple( + torch.tensor( + sum( + ([self._parse_segments_get_subsymmetry(sym, index)] * dim + for sym, dim in segments), + [], + ), # Concat all segment for this subsymmetry + dtype=sub_symmetry, + ) # Generate subsymmetry one by one + for index, sub_symmetry in enumerate(self.dtypes)) + return symmetry, dimension + + def _parse_segments_get_subsymmetry(self: CompatSymmetry, sym: object, index: int) -> object: + # Most of time, symmetry is a tuple of subsymmetry + # But if there is only ome subsymmetry, it could not be a tuple but subsymmetry itself. 
+ # pylint: disable=no-else-return + if isinstance(sym, tuple): + return sym[index] + else: + if len(self.fermion) == 1: + return sym + else: + raise TypeError(f"{sym=} is not subscriptable") + + @multimethod + def Edge(self: CompatSymmetry, dimension: int) -> E: + """ + Create edge with compat interface. + """ + # pylint: disable=invalid-name + symmetry = tuple(torch.zeros(dimension, dtype=sub_symmetry) for sub_symmetry in self.dtypes) + return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=False) + + @Edge.register + def _(self: CompatSymmetry, segments: list, arrow: bool = False) -> E: + symmetry, dimension = self._parse_segments(segments) + return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow) + + @Edge.register + def _(self: CompatSymmetry, segments_and_bool: tuple[list, bool]) -> E: + segments, arrow = segments_and_bool + symmetry, dimension = self._parse_segments(segments) + return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow) + + +class CompatScalar: + """ + The common Scalar namespace. + """ + + def __init__(self: CompatScalar, symmetry: CompatSymmetry, dtype: torch.dtype) -> None: + self.symmetry: CompatSymmetry = symmetry + self.dtype: torch.dtype = dtype + + @multimethod + def Tensor(self: CompatScalar, names: list[str], edges: list) -> T: + """ + Create tensor with compat names and edges. 
+ """ + # pylint: disable=invalid-name + return T( + tuple(names), + tuple(self.symmetry.Edge(edge) for edge in edges), + fermion=self.symmetry.fermion, + dtypes=self.symmetry.dtypes, + dtype=self.dtype, + ) + + @Tensor.register + def _(self: CompatScalar) -> T: + result = T( + (), + (), + fermion=self.symmetry.fermion, + dtypes=self.symmetry.dtypes, + dtype=self.dtype, + ) + result.data.reshape([-1])[0] = 1 + return result + + @Tensor.register + def _( + self: CompatScalar, + number: typing.Any, + names: list[str] | None = None, + edge_symmetry: list | None = None, + edge_arrow: list[bool] | None = None, + ) -> T: + # Create high rank tensor with only one element + if names is None: + names = [] + if edge_symmetry is None: + edge_symmetry = [None for _ in names] + if edge_arrow is None: + edge_arrow = [False for _ in names] + result = T( + tuple(names), + tuple( + E( + fermion=self.symmetry.fermion, + dtypes=self.symmetry.dtypes, + symmetry=tuple( + torch.tensor([self._create_size1_get_subsymmetry(symmetry, index)], dtype=dtype) + for index, dtype in enumerate(self.symmetry.dtypes)), + dimension=1, + arrow=arrow, + ) + for symmetry, arrow in zip(edge_symmetry, edge_arrow)), + fermion=self.symmetry.fermion, + dtypes=self.symmetry.dtypes, + dtype=self.dtype, + ) + result.data.reshape([-1])[0] = number + return result + + def _create_size1_get_subsymmetry(self: CompatScalar, sym: object, index: int) -> object: + # pylint: disable=no-else-return + if sym is None: + return 0 + elif isinstance(sym, tuple): + return sym[index] + else: + if len(self.symmetry.fermion) == 1: + return sym + else: + raise TypeError(f"{sym=} is not subscriptable") + + +No: CompatSymmetry = CompatSymmetry(fermion=(), dtypes=()) +Z2: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.bool,)) +U1: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.int,)) +Fermi: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.int,)) +FermiZ2: CompatSymmetry = 
CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.bool)) +FermiU1: CompatSymmetry = CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.int)) +Parity: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.bool,)) +FermiFermi: CompatSymmetry = CompatSymmetry(fermion=(True, True), dtypes=(torch.int, torch.int)) +Normal: CompatSymmetry = No diff --git a/tat/edge.py b/tat/edge.py new file mode 100644 index 000000000..7e8d9af78 --- /dev/null +++ b/tat/edge.py @@ -0,0 +1,278 @@ +""" +This file contains the definition of tensor edge. +""" + +from __future__ import annotations +import functools +import operator +import torch +from . import _utility + +# pylint: disable=too-many-arguments + + +class Edge: + """ + The edge type of tensor. + """ + + __slots__ = "_fermion", "_dtypes", "_symmetry", "_dimension", "_arrow", "_parity" + + @property + def fermion(self: Edge) -> tuple[bool, ...]: + """ + A tuple records whether every sub symmetry is fermionic. Its length is the number of sub symmetry. + """ + return self._fermion + + @property + def dtypes(self: Edge) -> tuple[torch.dtype, ...]: + """ + A tuple records the basic dtype of every sub symmetry. Its length is the number of sub symmetry. + """ + return self._dtypes + + @property + def symmetry(self: Edge) -> tuple[torch.Tensor, ...]: + """ + A tuple containing all symmetry of this edge. Its length is the number of sub symmetry. Every element of it is a + sub symmetry. + """ + return self._symmetry + + @property + def dimension(self: Edge) -> int: + """ + The dimension of this edge. + """ + return self._dimension + + @property + def arrow(self: Edge) -> bool: + """ + The arrow of this edge. + """ + return self._arrow + + @property + def parity(self: Edge) -> torch.Tensor: + """ + The parity of this edge. + """ + return self._parity + + def __init__( + self: Edge, + *, + fermion: tuple[bool, ...] | None = None, + dtypes: tuple[torch.dtype, ...] 
| None = None, + symmetry: tuple[torch.Tensor, ...] | None = None, + dimension: int | None = None, + arrow: bool | None = None, + **kwargs: torch.Tensor, + ) -> None: + """ + Create an edge with essential information. + + Examples: + - Edge(dimension=5) + - Edge(symmetry=(torch.tensor([False, False, True, True]),)) + - Edge(fermion=(False, True), symmetry=(torch.tensor([False, True]), torch.tensor([False, True])), arrow=True) + + Parameters + ---------- + fermion : tuple[bool, ...], optional + Whether each sub symmetry is fermionic symmetry, its length should be the same to symmetry. But it could be + left empty, if so, a total bosonic edge will be created. + dtypes : tuple[torch.dtype, ...], optional + The basic dtype to identify each sub symmetry, its length should be the same to symmetry, and it is nothing + but the dtypes of each tensor in the symmetry. It could be left empty, if so, it will be derived from + symmetry. + symmetry : tuple[torch.Tensor, ...], optional + The symmetry information of every sub symmetry, each of sub symmetry should be a one dimensional tensor with + the same length dimension, and their dtype should be integral type, aka, int or bool. + dimension : int, optional + The dimension of the edge, if not specified, dimension will be detected from symmetry. + arrow : bool, optional + The arrow direction of the edge, it is essential for fermionic edge, aka, an edge with fermionic sub + symmetry. 
+ """ + # Symmetry could be left empty to create no symmetry edge + if symmetry is None: + symmetry = () + + # Fermion could be empty if it is total bosonic edge + if fermion is None: + fermion = tuple(False for _ in symmetry) + + # Dtypes could be empty and derived from symmetry + if dtypes is None: + dtypes = tuple(sub_symmetry.dtype for sub_symmetry in symmetry) + # Check dtype is compatible with symmetry + assert all(sub_symmetry.dtype is sub_dtype for sub_symmetry, sub_dtype in zip(symmetry, dtypes)) + # Check dtype is valid, aka, bool or int + assert all(not (sub_symmetry.is_floating_point() or sub_symmetry.is_complex()) for sub_symmetry in symmetry) + + # The fermion, dtypes and symmetry information should have the same length + assert len(fermion) == len(dtypes) == len(symmetry) + + # If dimension not set, get dimension from symmetry + if dimension is None: + dimension = len(symmetry[0]) + # Check if the dimensions of different sub_symmetry mismatch + assert all(sub_symmetry.size() == (dimension,) for sub_symmetry in symmetry) + + if arrow is None: + # Arrow not set, it should be bosonic edge. + arrow = False + assert not any(fermion) + + self._fermion: tuple[bool, ...] = fermion + self._dtypes: tuple[torch.dtype, ...] = dtypes + self._symmetry: tuple[torch.Tensor, ...] 
= symmetry + self._dimension: int = dimension + self._arrow: bool = arrow + + self._parity: torch.Tensor + if "parity" in kwargs: + self._parity = kwargs.pop("parity") + else: + self._parity = self._generate_parity() + assert not kwargs + + def _generate_parity(self: Edge) -> torch.Tensor: + return functools.reduce( + # Reduce sub parity by xor + torch.logical_xor, + ( + # The parity of sub symmetry + _utility.parity(sub_symmetry) + # Loop all sub symmetry + for sub_symmetry, sub_fermion in zip(self.symmetry, self.fermion) + # But only reduce if it is fermion sub symmetry + if sub_fermion), + # Reduce with start as tensor filled with False + torch.zeros(self.dimension, dtype=torch.bool), + ) + + def conjugate(self: Edge) -> Edge: + """ + Get the conjugated edge. + + Returns + ------- + Edge + The conjugated edge. + """ + # The only two difference of conjguated edge is symmetry and arrow + return Edge( + fermion=self.fermion, + dtypes=self.dtypes, + symmetry=tuple( + _utility.neg_symmetry(sub_symmetry) # bool -> same, int -> neg + for sub_symmetry in self.symmetry), + dimension=self.dimension, + arrow=not self.arrow, + parity=self.parity, + ) + + def __eq__(self: Edge, other: object) -> bool: + if not isinstance(other, Edge): + return NotImplemented + return (self.dimension == other.dimension and # Compare int dimension and bool arrow first + self.arrow == other.arrow and # Since they are fast to compare + self.fermion == other.fermion and # Then the tuple of bool are compared + self.dtypes == other.dtypes and # Then the tuple of dtypes are compared + all( # All of symmetries are compared at last, since it is biggest + torch.equal(self_sub_symmetry, other_sub_symmetry) + for self_sub_symmetry, other_sub_symmetry in zip(self.symmetry, other.symmetry))) + + def __str__(self: Edge) -> str: + # pylint: disable=no-else-return + if any(self.fermion): + # Fermionic edge + return f"(dimension={self.dimension}, arrow={self.arrow}, fermion={self.fermion}, 
symmetry={self.symmetry})" + elif self.fermion: + # Bosonic edge + return f"(dimension={self.dimension}, symmetry={self.symmetry})" + else: + # Trivial edge + return f"(dimension={self.dimension})" + + def __repr__(self: Edge) -> str: + return f"Edge{self.__str__()}" + + @staticmethod + def merge_edges( + edges: tuple[Edge, ...], + fermion: tuple[bool, ...] | None = None, + dtypes: tuple[torch.dtype, ...] | None = None, + arrow: bool | None = None, + ) -> Edge: + """ + Merge several edges into one edge. + + Parameters + ---------- + edges : tuple[Edge, ...] + The edges to be merged. + fermion : tuple[bool, ...], optional + Whether each sub symmetry is fermionic, it could be left empty to derive from edges + dtypes : tuple[torch.dtype, ...], optional + The base type of sub symmetry, it could be left empty to derive from edges + arrow : bool, optional + The arrow of all the edges, it is useful if edges is empty. + + Returns + ------- + Edge + The result edge merged by edges. + """ + # If fermion not set, get it from edges + if fermion is None: + fermion = edges[0].fermion + # All edge should share the same fermion + assert all(fermion == edge.fermion for edge in edges) + # If dtypes not set, get it from edges + if dtypes is None: + dtypes = edges[0].dtypes + # All edge should share the same dtypes + assert all(dtypes == edge.dtypes for edge in edges) + # If arrow set, check it directly, if not set, set to False or get from edges + if arrow is None: + if any(fermion): + # It is fermionic edge. + arrow = edges[0].arrow + else: + # It is bosonic edge, set to False directly since it is useless. 
+ arrow = False + # All edge should share the same arrow + assert all(arrow == edge.arrow for edge in edges) + + rank = len(edges) + # Merge edge + dimension = functools.reduce(operator.mul, (edge.dimension for edge in edges), 1) + symmetry = tuple( + # Every merged sub symmetry is calculated by reduce and flatten + functools.reduce( + # The reduce operator depend on the dtype of this sub symmetry + _utility.add_symmetry, + ( + # The sub symmetry of every edge will be reshape to be reduced. + _utility.unsqueeze(edge.symmetry[sub_symmetry_index], current_index, rank) + # The sub symmetry of every edge is reduced one by one + for current_index, edge in enumerate(edges)), + # Reduce from a rank-0 tensor + torch.zeros([], dtype=sub_symmetry_dtype), + ).reshape([-1]) + # Merge every sub symmetry one by one + for sub_symmetry_index, sub_symmetry_dtype in enumerate(dtypes)) + + # parity not set here since it need recalculation + return Edge( + fermion=fermion, + dtypes=dtypes, + symmetry=symmetry, + dimension=dimension, + arrow=arrow, + ) diff --git a/tat/tensor.py b/tat/tensor.py new file mode 100644 index 000000000..003fda7d6 --- /dev/null +++ b/tat/tensor.py @@ -0,0 +1,1727 @@ +""" +This file defined the core tensor type for tat package. +""" + +from __future__ import annotations +import typing +import operator +import functools +from multimethod import multimethod +import torch +from . import _utility +from .edge import Edge + +# pylint: disable=too-many-public-methods +# pylint: disable=too-many-arguments +# pylint: disable=too-many-lines + + +class Tensor: + """ + The main tensor type, which wraps pytorch tensor and provides edge names and Fermionic functions. 
+ """ + + __slots__ = "_fermion", "_dtypes", "_names", "_edges", "_data", "_mask" + + def __str__(self: Tensor) -> str: + return f"(names={self.names}, edges={self.edges}, data={self.data})" + + def __repr__(self: Tensor) -> str: + return f"Tensor(names={self.names}, edges={self.edges})" + + @property + def fermion(self: Tensor) -> tuple[bool, ...]: + """ + A tuple records whether every sub symmetry is fermionic. Its length is the number of sub symmetry. + """ + return self._fermion + + @property + def dtypes(self: Tensor) -> tuple[torch.dtype, ...]: + """ + A tuple records the basic dtype of every sub symmetry. Its length is the number of sub symmetry. + """ + return self._dtypes + + @property + def names(self: Tensor) -> tuple[str, ...]: + """ + The edge names of this tensor. + """ + return self._names + + @property + def edges(self: Tensor) -> tuple[Edge, ...]: + """ + The edges information of this tensor. + """ + return self._edges + + @property + def data(self: Tensor) -> torch.Tensor: + """ + The content data of this tensor. + """ + return self._data + + @property + def mask(self: Tensor) -> torch.Tensor: + """ + The content data mask of this tensor. + """ + return self._mask + + @property + def rank(self: Tensor) -> int: + """ + The rank of this tensor. + """ + return len(self._names) + + @property + def dtype(self: Tensor) -> torch.dtype: + """ + The data type of the content in this tensor. + """ + return self.data.dtype + + @property + def btype(self: Tensor) -> str: + """ + The data type of the content in this tensor, represented in BLAS/LAPACK convention. + """ + if self.dtype is torch.float32: + return 'S' + if self.dtype is torch.float64: + return 'D' + if self.dtype is torch.complex64: + return 'C' + if self.dtype is torch.complex128: + return 'Z' + return '?' 
+ + @property + def is_complex(self: Tensor) -> bool: + """ + Whether it is a complex tensor + """ + return self.dtype.is_complex + + @property + def is_real(self: Tensor) -> bool: + """ + Whether it is a real tensor + """ + return self.dtype.is_floating_point + + def edge_by_name(self: Tensor, name: str) -> Edge: + """ + Get edge by the edge name of this tensor. + + Parameters + ---------- + name : str + The given edge name. + + Returns + ------- + Edge + The edge with the given edge name. + """ + assert name in self.names + return self.edges[self.names.index(name)] + + def _arithmetic_operator(self: Tensor, other: object, operate: typing.Callable) -> Tensor: + if isinstance(other, Tensor): + # If it is tensor, check same shape and transpose before calculating. + assert self.same_shape_with(other) + new_data = operate(self.data, other.transpose(self.names).data) + else: + # Otherwise treat other as a scalar, mask should be applied later. + new_data = operate(self.data, other) + new_data *= self.mask + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=new_data, + mask=self.mask, + ) + + def __add__(self: Tensor, other: object) -> Tensor: + return self._arithmetic_operator(other, torch.add) + + def __sub__(self: Tensor, other: object) -> Tensor: + return self._arithmetic_operator(other, torch.sub) + + def __mul__(self: Tensor, other: object) -> Tensor: + return self._arithmetic_operator(other, torch.mul) + + def __truediv__(self: Tensor, other: object) -> Tensor: + return self._arithmetic_operator(other, torch.div) + + def _right_arithmetic_operator(self: Tensor, other: object, operate: typing.Callable) -> Tensor: + if isinstance(other, Tensor): + # If it is tensor, check same shape and transpose before calculating. + assert self.same_shape_with(other) + new_data = operate(other.transpose(self.names).data, self.data) + else: + # Otherwise treat other as a scalar, mask should be applied later. 
+ new_data = operate(other, self.data) + new_data *= self.mask + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=new_data, + mask=self.mask, + ) + + def __radd__(self: Tensor, other: object) -> Tensor: + return self._right_arithmetic_operator(other, torch.add) + + def __rsub__(self: Tensor, other: object) -> Tensor: + return self._right_arithmetic_operator(other, torch.sub) + + def __rmul__(self: Tensor, other: object) -> Tensor: + return self._right_arithmetic_operator(other, torch.mul) + + def __rtruediv__(self: Tensor, other: object) -> Tensor: + return self._right_arithmetic_operator(other, torch.div) + + def _inplace_arithmetic_operator(self: Tensor, other: object, operate: typing.Callable) -> Tensor: + if isinstance(other, Tensor): + # If it is tensor, check same shape and transpose before calculating. + assert self.same_shape_with(other) + operate(self.data, other.transpose(self.names).data, out=self.data) + else: + # Otherwise treat other as a scalar, mask should be applied later. 
+ operate(self.data, other, out=self.data) + torch.mul(self.data, self.mask, out=self.data) + return self + + def __iadd__(self: Tensor, other: object) -> Tensor: + return self._inplace_arithmetic_operator(other, torch.add) + + def __isub__(self: Tensor, other: object) -> Tensor: + return self._inplace_arithmetic_operator(other, torch.sub) + + def __imul__(self: Tensor, other: object) -> Tensor: + return self._inplace_arithmetic_operator(other, torch.mul) + + def __itruediv__(self: Tensor, other: object) -> Tensor: + return self._inplace_arithmetic_operator(other, torch.div) + + def __pos__(self: Tensor) -> Tensor: + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=+self.data, + mask=self.mask, + ) + + def __neg__(self: Tensor) -> Tensor: + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=-self.data, + mask=self.mask, + ) + + def __float__(self: Tensor) -> float: + return float(self.data) + + def __complex__(self: Tensor) -> complex: + return complex(self.data) + + def norm(self: Tensor, order: typing.Any) -> float: + """ + Get the norm of the tensor, this function will flatten tensor first before calculate norm. + + Parameters + ---------- + order + The order of norm. + + Returns + ------- + float + The norm of the tensor. 
+ """ + # pylint: disable=not-callable + # I do not know why pylint in github action complains this + return torch.linalg.vector_norm(self.data, ord=order) + + def norm_max(self: Tensor) -> float: + "max norm" + return self.norm(+torch.inf) + + def norm_min(self: Tensor) -> float: + "min norm" + return self.norm(-torch.inf) + + def norm_num(self: Tensor) -> float: + "0-norm" + return self.norm(0) + + def norm_sum(self: Tensor) -> float: + "1-norm" + return self.norm(1) + + def norm_2(self: Tensor) -> float: + "2-norm" + return self.norm(2) + + def copy(self: Tensor) -> Tensor: + """ + Get a copy of this tensor + + Returns + ------- + Tensor + The copy of this tensor + """ + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=torch.clone(self.data), + mask=self.mask, + ) + + def __copy__(self: Tensor) -> Tensor: + return self.copy() + + def __deepcopy__(self: Tensor, _: typing.Any = None) -> Tensor: + return self.copy() + + def same_shape(self: Tensor) -> Tensor: + """ + Get a tensor with same shape to this tensor + + Returns + ------- + Tensor + A new tensor with the same shape to this tensor + """ + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=torch.zeros_like(self.data), + mask=self.mask, + ) + + def zero_(self: Tensor) -> Tensor: + """ + Set all element to zero in this tensor + + Returns + ------- + Tensor + Return this tensor itself. + """ + self.data.zero_() + return self + + def sqrt(self: Tensor) -> Tensor: + """ + Get the sqrt of the tensor. + + Returns + ------- + Tensor + The sqrt of this tensor. + """ + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=torch.sqrt(torch.abs(self.data)), + mask=self.mask, + ) + + def reciprocal(self: Tensor) -> Tensor: + """ + Get the reciprocal of the tensor. + + Returns + ------- + Tensor + The reciprocal of this tensor. 
+ """ + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=torch.where(self.data == 0, self.data, 1 / self.data), + mask=self.mask, + ) + + def range_(self: Tensor, first: typing.Any = 0, step: typing.Any = 1) -> Tensor: + """ + A useful function to generate simple data in tensor for test. + + Parameters + ---------- + first, step + Parameters to generate data. + + Returns + ------- + Tensor + Returns the tensor itself. + """ + data = torch.cumsum(self.mask.reshape([-1]), dim=0, dtype=self.data.dtype).reshape(self.data.size()) + data = (data - 1) * step + first + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=data, + mask=self.mask, + ) + + def to(self: Tensor, new_type: typing.Any) -> Tensor: + """ + Convert this tensor to other scalar type. + + Parameters + ---------- + new_type + The scalar data type of the new tensor. + """ + # pylint: disable=invalid-name + if isinstance(new_type, str): + if new_type in ["float32", "S"]: + new_type = torch.float32 + elif new_type in ["float64", "float", "D"]: + new_type = torch.float64 + elif new_type in ["complex64", "C"]: + new_type = torch.complex64 + elif new_type in ["complex128", "complex", "Z"]: + new_type = torch.complex128 + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=self.data.to(new_type), + mask=self.mask, + ) + + def __init__( + self: Tensor, + names: tuple[str, ...], + edges: tuple[Edge, ...], + *, + dtype: torch.dtype = torch.float, + fermion: tuple[bool, ...] | None = None, + dtypes: tuple[torch.dtype, ...] | None = None, + **kwargs: torch.Tensor, + ) -> None: + """ + Create a tensor with specific shape. + + Parameters + ---------- + names : tuple[str, ...] + The edge names of the tensor, which length is just the tensor rank. + edges : tuple[Edge, ...] + The detail information of each edge, which length is just the tensor rank. 
+ dtype : torch.dtype, default=torch.float + The dtype of the tensor. + fermion : tuple[bool, ...], optional + Whether each sub symmetry is fermionic, it could be left empty to derive from edges + dtypes : tuple[torch.dtype, ...], optional + The base type of sub symmetry, it could be left empty to derive from edges + """ + # Check the rank is correct in names and edges + assert len(names) == len(edges) + # Check whether there are duplicated names + assert len(set(names)) == len(names) + # If fermion not set, get it from edges + if fermion is None: + fermion = edges[0].fermion + # If dtypes not set, get it from edges + if dtypes is None: + dtypes = edges[0].dtypes + # Check if fermion is correct + assert all(edge.fermion == fermion for edge in edges) + # Check if dtypes is correct + assert all(edge.dtypes == dtypes for edge in edges) + + self._fermion: tuple[bool, ...] = fermion + self._dtypes: tuple[torch.dtype, ...] = dtypes + self._names: tuple[str, ...] = names + self._edges: tuple[Edge, ...] = edges + + self._data: torch.Tensor + if "data" in kwargs: + self._data = kwargs.pop("data") + assert self.data.size() == tuple(edge.dimension for edge in self.edges) + else: + self._data = torch.zeros( + [edge.dimension for edge in self.edges], + dtype=dtype, + ) + self._mask: torch.Tensor + if "mask" in kwargs: + self._mask = kwargs.pop("mask") + assert self.mask.size() == self.data.size() + assert self.mask.dtype is torch.bool + else: + self._mask = self._generate_mask() + assert not kwargs + + def _generate_mask(self: Tensor) -> torch.Tensor: + return functools.reduce( + # Mask is valid if all sub symmetry give valid mask. + torch.logical_and, + ( + # The mask is valid if total symmetry is False or total symmetry is 0 + _utility.zero_symmetry( + # total sub symmetry is calculated by reduce + functools.reduce( + # The reduce operator depend on the dtype of this sub symmetry + _utility.add_symmetry, + ( + # The sub symmetry of every edge will be reshape to be reduced. 
+ _utility.unsqueeze(edge.symmetry[sub_symmetry_index], current_index, self.rank) + # The sub symmetry of every edge is reduced one by one + for current_index, edge in enumerate(self.edges)), + # Reduce from a rank-0 tensor + torch.zeros([], dtype=sub_symmetry_dtype), + )) + # Calculate mask on every sub symmetry one by one + for sub_symmetry_index, sub_symmetry_dtype in enumerate(self.dtypes)), + # Reduce from all true mask. + torch.ones(size=self.data.size(), dtype=torch.bool), + ) + + @multimethod + def _prepare_position(self: Tensor, position: tuple[int, ...]) -> tuple[int, ...]: + return position + + @_prepare_position.register + def _(self: Tensor, position: tuple[slice, ...]) -> tuple[int, ...]: + index_by_name: dict[str, int] = {s.start: s.stop for s in position} + return tuple(index_by_name[name] for name in self.names) + + @_prepare_position.register + def _(self: Tensor, position: dict[str, int]) -> tuple[int, ...]: + return tuple(position[name] for name in self.names) + + def __getitem__(self: Tensor, position: tuple[int, ...] | tuple[slice, ...] | dict[str, int]) -> typing.Any: + """ + Get the element of the tensor + + Parameters + ---------- + position : tuple[int, ...] | tuple[slice, ...] | dict[str, int] + The position of the element, which could be either tuple of index directly or a map from edge name to the + index in the corresponding edge. + """ + indices: tuple[int, ...] = self._prepare_position(position) + assert len(indices) == self.rank + assert all(0 <= index < edge.dimension for edge, index in zip(self.edges, indices)) + return self.data[indices] + + def __setitem__(self: Tensor, position: tuple[int, ...] | tuple[slice, ...] | dict[str, int], + value: typing.Any) -> None: + """ + Set the element of the tensor + + Parameters + ---------- + position : tuple[int, ...] | tuple[slice, ...] 
| dict[str, int] + The position of the element, which could be either tuple of index directly or a map from edge name to the + index in the corresponding edge. + """ + indices = self._prepare_position(position) + assert len(indices) == self.rank + assert all(0 <= index < edge.dimension for edge, index in zip(self.edges, indices)) + if self.mask[indices]: + self.data[indices] = value + + def clear_symmetry(self: Tensor) -> Tensor: + """ + Clear all symmetry of this tensor. + + Returns + ------- + Tensor + The result tensor with symmetry cleared. + """ + # Mask must be generated again + # pylint: disable=no-else-return + if any(self.fermion): + return Tensor( + names=self.names, + edges=tuple( + Edge( + fermion=(True,), + dtypes=(torch.bool,), + symmetry=(edge.parity,), + dimension=edge.dimension, + arrow=edge.arrow, + parity=edge.parity, + ) for edge in self.edges), + fermion=(True,), + dtypes=(torch.bool,), + data=self.data, + ) + else: + return Tensor( + names=self.names, + edges=tuple( + Edge( + fermion=(), + dtypes=(), + symmetry=(), + dimension=edge.dimension, + arrow=edge.arrow, + parity=edge.parity, + ) for edge in self.edges), + fermion=(), + dtypes=(), + data=self.data, + ) + + def randn_(self: Tensor, mean: float = 0., std: float = 1.) -> Tensor: + """ + Fill the tensor with random number in normal distribution. + + Parameters + ---------- + mean, std : float + The parameter of normal distribution. + + Returns + ------- + Tensor + Return this tensor itself. + """ + self.data.normal_(mean, std) + torch.mul(self.data, self.mask, out=self.data) + return self + + def rand_(self: Tensor, low: float = 0., high: float = 1.) -> Tensor: + """ + Fill the tensor with random number in uniform distribution. + + Parameters + ---------- + low, high : float + The parameter of uniform distribution. + + Returns + ------- + Tensor + Return this tensor itself. 
+ """ + self.data.uniform_(low, high) + torch.mul(self.data, self.mask, out=self.data) + return self + + def same_type_with(self: Tensor, other: Tensor) -> bool: + """ + Check whether two tensor has the same type, that is to say they share the same symmetry type. + """ + return self.fermion == other.fermion and self.dtypes == other.dtypes + + def same_shape_with(self: Tensor, other: Tensor, *, allow_transpose: bool = True) -> bool: + """ + Check whether two tensor has the same shape, that is to say the only differences between them are transpose and + data difference. + """ + if not self.same_type_with(other): + return False + # pylint: disable=no-else-return + if allow_transpose: + return dict(zip(self.names, self.edges)) == dict(zip(other.names, other.edges)) + else: + return self.names == other.names and self.edges == other.edges + + def edge_rename(self: Tensor, name_map: dict[str, str]) -> Tensor: + """ + Rename edge name for this tensor. + + Parameters + ---------- + name_map : dict[str, str] + The name map to be used in renaming edge name. + + Returns + ------- + Tensor + The tensor with names renamed. + """ + return Tensor( + names=tuple(name_map.get(name, name) for name in self.names), + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=self.data, + mask=self.mask, + ) + + def transpose(self: Tensor, names: tuple[str, ...]) -> Tensor: + """ + Transpose the tensor outplace. + + Parameters + ---------- + names : tuple[str, ...] + The new edge order identified by edge names. + + Returns + ------- + Tensor + The transpsoe tensor. 
+ """ + if names == self.names: + return self + assert len(names) == len(self.names) + assert set(names) == set(self.names) + before_by_after = tuple(self.names.index(name) for name in names) + after_by_before = tuple(names.index(name) for name in self.names) + data = self.data.permute(before_by_after) + mask = self.mask.permute(before_by_after) + if any(self.fermion): + # It is fermionic tensor + parities_before_transpose = tuple( + _utility.unsqueeze(edge.parity, current_index, self.rank) + for current_index, edge in enumerate(self.edges)) + # Generate parity by xor all inverse pairs + parity = functools.reduce( + torch.logical_xor, + ( + torch.logical_and(parities_before_transpose[i], parities_before_transpose[j]) + # Loop every 0 <= i < j < rank + for j in range(self.rank) + for i in range(0, j) + if after_by_before[i] > after_by_before[j]), + torch.zeros([], dtype=torch.bool)) + # parity True -> -x + # parity False -> +x + data = torch.where(parity, -data, +data) + return Tensor( + names=names, + edges=tuple(self.edges[index] for index in before_by_after), + fermion=self.fermion, + dtypes=self.dtypes, + data=data, + mask=mask, + ) + + def reverse_edge( + self: Tensor, + reversed_names: set[str], + apply_parity: bool = False, + parity_exclude_names: set[str] | None = None, + ) -> Tensor: + """ + Reverse some edge in the tensor. + + Parameters + ---------- + reversed_names : set[str] + The edge names of those edges which will be reversed + apply_parity : bool, default=False + Whether to apply parity caused by reversing edge, since reversing edge will generate half a sign. + parity_exclude_names : set[str], optional + The name of edges in the different behavior other than default set by apply_parity. + + Returns + ------- + Tensor + The tensor with edges reversed. 
+ """ + if not any(self.fermion): + return self + if parity_exclude_names is None: + parity_exclude_names = set() + # Parity is xor of all valid reverse parity + parity = functools.reduce( + torch.logical_xor, + ( + _utility.unsqueeze(edge.parity, current_index, self.rank) + # Loop over all edge + for current_index, [name, edge] in enumerate(zip(self.names, self.edges)) + # Check if this edge is reversed and if this edge will be applied parity + if (name in reversed_names) and (apply_parity ^ (name in parity_exclude_names))), + torch.zeros([], dtype=torch.bool), + ) + data = torch.where(parity, -self.data, +self.data) + return Tensor( + names=self.names, + edges=tuple( + Edge( + fermion=edge.fermion, + dtypes=edge.dtypes, + symmetry=edge.symmetry, + dimension=edge.dimension, + arrow=not edge.arrow if self.names[current_index] in reversed_names else edge.arrow, + parity=edge.parity, + ) for current_index, edge in enumerate(self.edges)), + fermion=self.fermion, + dtypes=self.dtypes, + data=data, + mask=self.mask, + ) + + def split_edge( + self: Tensor, + split_map: dict[str, tuple[tuple[str, Edge], ...]], + apply_parity: bool = False, + parity_exclude_names: set[str] | None = None, + ) -> Tensor: + """ + Split some edges in this tensor. + + Parameters + ---------- + split_map : dict[str, tuple[tuple[str, Edge], ...]] + The edge splitting plan. + apply_parity : bool, default=False + Whether to apply parity caused by splitting edge, since splitting edge will generate half a sign. + parity_exclude_names : set[str], optional + The name of edges in the different behavior other than default set by apply_parity. + + Returns + ------- + Tensor + The tensor with edges splitted. + """ + if parity_exclude_names is None: + parity_exclude_names = set() + # Check the edge to be splitted can be merged by result edges. 
+ assert all( + self.edge_by_name(name) == Edge.merge_edges( + tuple(new_edge for _, new_edge in split_result), + fermion=self.fermion, + dtypes=self.dtypes, + arrow=self.edge_by_name(name).arrow, + ) for name, split_result in split_map.items()) + # Calculate the result components + names: tuple[str, ...] = tuple( + # Convert the list generated by reduce to tuple + functools.reduce( + # Concat list + operator.add, + # If name in split_map, use the new names list, otherwise use name itself as a length-1 list + ([new_name for new_name, _ in split_map[name]] if name in split_map else [name] for name in self.names), + # Reduce from [] to concat all list + [], + )) + edges: tuple[Edge, ...] = tuple( + # Convert the list generated by reduce to tuple + functools.reduce( + # Concat list + operator.add, + # If name in split_map, use the new edges list, otherwise use the edge itself as a length-1 list + ([new_edge + for _, new_edge in split_map[name]] if name in split_map else [edge] + for name, edge in zip(self.names, self.edges)), + # Reduce from [] to concat all list + [], + )) + new_size = [edge.dimension for edge in edges] + data = self.data.reshape(new_size) + mask = self.mask.reshape(new_size) + + # Apply parity + if any(self.fermion): + # It is femionic tensor, parity need to be applied + new_rank = len(names) + # Parity is xor of all valid splitting parity + parity = functools.reduce( + torch.logical_xor, + ( + # Apply the parity for this splitting group here + # It is need to perform a total transpose on this split group + # {sum 0<=i tuple[str, ...]: + reversed_names: list[str] = [] + for name in reversed(self.names): + found_in_merge_map: tuple[str, tuple[str, ...]] | None = next( + ((new_name, old_names) for new_name, old_names in merge_map.items() if name in old_names), None) + if found_in_merge_map is None: + # This edge will not be merged + reversed_names.append(name) + else: + new_name, old_names = found_in_merge_map + # This edge will be merged + if name == 
old_names[-1]: + # Add new edge only if it is the last edge + reversed_names.append(new_name) + # Some edge is merged from no edges, it should be considered + for new_name, old_names in merge_map.items(): + if not old_names: + reversed_names.append(new_name) + return tuple(reversed(reversed_names)) + + def merge_edge( + self: Tensor, + merge_map: dict[str, tuple[str, ...]], + apply_parity: bool = False, + parity_exclude_names: set[str] | None = None, + merge_arrow: dict[str, bool] | None = None, + names: tuple[str, ...] | None = None, + ) -> Tensor: + """ + Merge some edges in this tensor. + + Parameters + ---------- + merge_map : dict[str, tuple[str, ...]] + The edge merging plan. + apply_parity : bool, default=False + Whether to apply parity caused by merging edge, since merging edge will generate half a sign. + parity_exclude_names : set[str], optional + The name of edges in the different behavior other than default set by apply_parity. + merge_arrow : dict[str, bool], optional + For merging edge from zero edges, arrow cannot be identified automatically, it requires user set manually. + names : tuple[str, ...], optional + The edge order of the result, sometimes user may want to specify it manually. + + Returns + ------- + Tensor + The tensor with edges merged. + """ + if parity_exclude_names is None: + parity_exclude_names = set() + if merge_arrow is None: + merge_arrow = {} + # Two steps: 1. Transpose 2. Merge + if names is None: + names = self._merge_edge_get_names(merge_map) + transposed_names: tuple[str, ...] 
= tuple( + # Convert the list generated by reduce to tuple + functools.reduce( + # Concat list + operator.add, + # If name in merge_map, use the old names list, otherwise use name itself as a length-1 list + (list(merge_map[name]) if name in merge_map else [name] for name in names), + # Reduce from [] to concat all list + [], + )) + transposed_tensor = self.transpose(transposed_names) + edges = tuple( + Edge.merge_edges( + edges=tuple(transposed_tensor.edge_by_name(old_name) for old_name in merge_map[name]), + fermion=self.fermion, + dtypes=self.dtypes, + arrow=merge_arrow.get(name, None), # For merging edge from zero edge, arrow need to be set manually + ) if name in merge_map else transposed_tensor.edge_by_name(name) for name in names) + transposed_data = transposed_tensor.data + transposed_mask = transposed_tensor.mask + + # Apply parity + if any(self.fermion): + # It is femionic tensor, parity need to be applied + # Parity is xor of all valid splitting parity + parity = functools.reduce( + torch.logical_xor, + ( + # Apply the parity for this merging group here + # It is need to perform a total transpose on this merging group + # {sum 0<=i Tensor: + """ + Contract two tensors. + + Parameters + ---------- + other : Tensor + Another tensor to be contracted. + contract_pairs : set[tuple[str, str]] + The pairs of edges to be contract between two tensors. + fuse_names : set[str], optional + The set of edges to be fuses. + + Returns + ------- + Tensor + The result contracted by two tensors. + """ + # pylint: disable=too-many-locals + # Only same type tensor can be contracted. 
+ assert self.same_type_with(other) + + if fuse_names is None: + fuse_names = set() + # Fuse name should not have any symmetry + assert all( + all(_utility.zero_symmetry(sub_symmetry) + for sub_symmetry in self.edge_by_name(fuse_name).symmetry) + for fuse_name in fuse_names) + + # Alias tensor + tensor_1 = self + tensor_2 = other + + # Check if contract edge and fuse edge compatible + assert all(tensor_1.edge_by_name(name) == tensor_2.edge_by_name(name) for name in fuse_names) + assert all( + tensor_1.edge_by_name(name_1).conjugate() == tensor_2.edge_by_name(name_2) + for name_1, name_2 in contract_pairs) + + # All tensor edges merged to three part: fuse edge, contract edge, free edge + + # Contract of tensor has 5 step: + # 1. reverse arrow + # reverse all free edge and fuse edge to arrow False, without parity apply. + # reverse contract edge to two arrow: False(tensor_2) and True(tensor_1), only apply parity to one tensor. + # 2. merge edge + # merge all edge in the same part to one edge, only apply parity to contract edge of one tensor + # free edge and fuse edge will not be applied parity. + # 3. contract matrix + # call matrix multiply + # 4. split edge + # split edge merged in step 2, without apply parity + # 5. 
reverse arrow + # reverse arrow reversed in step 1, without parity apply + + # Step 1 + contract_names_1 = {name_1 for name_1, name_2 in contract_pairs} + contract_names_2 = {name_2 for name_1, name_2 in contract_pairs} + arrow_true_names_1 = {name for name, edge in zip(tensor_1.names, tensor_1.edges) if edge.arrow} + arrow_true_names_2 = {name for name, edge in zip(tensor_2.names, tensor_2.edges) if edge.arrow} + + tensor_1 = tensor_1.reverse_edge(contract_names_1 ^ arrow_true_names_1, False, + contract_names_1 - arrow_true_names_1) + tensor_2 = tensor_2.reverse_edge(arrow_true_names_2, False, set()) + + # Step 2 + free_names_1 = tuple(name for name in tensor_1.names if name not in fuse_names and name not in contract_names_1) + free_edges_1 = tuple((name, tensor_1.edge_by_name(name)) for name in free_names_1) + free_names_2 = tuple(name for name in tensor_2.names if name not in fuse_names and name not in contract_names_2) + free_edges_2 = tuple((name, tensor_2.edge_by_name(name)) for name in free_names_2) + if tensor_1.data.nelement() > tensor_2.data.nelement(): + # Tensor 1 larger, fit by tensor 1 + ordered_fuse_names = tuple(name for name in tensor_1.names if name in fuse_names) + ordered_fuse_edges = tuple((name, tensor_1.edge_by_name(name)) for name in ordered_fuse_names) + + # pylint: disable=unnecessary-comprehension + contract_names_map = {name_1: name_2 for name_1, name_2 in contract_pairs} + ordered_contract_names_1 = tuple(name for name in tensor_1.names if name in contract_names_1) + ordered_contract_names_2 = tuple(contract_names_map[name] for name in contract_names_1) + else: + # Tensor 2 larger, fit by tensor 2 + ordered_fuse_names = tuple(name for name in tensor_2.names if name in fuse_names) + ordered_fuse_edges = tuple((name, tensor_2.edge_by_name(name)) for name in ordered_fuse_names) + + contract_names_map = {name_2: name_1 for name_1, name_2 in contract_pairs} + ordered_contract_names_2 = tuple(name for name in tensor_2.names if name in 
contract_names_2) + ordered_contract_names_1 = tuple(contract_names_map[name] for name in contract_names_2) + put_contract_1_right = next( + (name in contract_names_1 for name in reversed(tensor_1.names) if name not in fuse_names), True) + put_contract_2_right = next( + (name in contract_names_2 for name in reversed(tensor_2.names) if name not in fuse_names), False) + + tensor_1 = tensor_1.merge_edge( + { + "Free_1": free_names_1, + "Contract_1": ordered_contract_names_1, + "Fuse_1": ordered_fuse_names, + }, False, {"Contract_1"}, { + "Free_1": False, + "Contract_1": True, + "Fuse_1": False, + }, ("Fuse_1", "Free_1", "Contract_1") if put_contract_1_right else ("Fuse_1", "Contract_1", "Free_1")) + tensor_2 = tensor_2.merge_edge( + { + "Free_2": free_names_2, + "Contract_2": ordered_contract_names_2, + "Fuse_2": ordered_fuse_names, + }, False, set(), { + "Free_2": False, + "Contract_2": False, + "Fuse_2": False, + }, ("Fuse_2", "Free_2", "Contract_2") if put_contract_2_right else ("Fuse_2", "Contract_2", "Free_2")) + # C[fuse, free1, free2] = A[fuse, free1 contract] B[fuse, contract free2] + assert tensor_1.edge_by_name("Fuse_1") == tensor_2.edge_by_name("Fuse_2") + assert tensor_1.edge_by_name("Contact_1").conjugate() == tensor_2.edge_by_name("Contract_2") + + # Step 3 + # The standard arrow is + # (0, False, True) (0, False, False) + # aka: (a b) (c d) (c+ b+) = (a d) + # since: EPR pair order is (False True) + # put_contract_1_right should be True + # put_contract_2_right should be False + # Every mismatch generate a sign + # Total sign is + # (!put_contract_1_right) ^ (put_contract_2_right) = put_contract_1_right == put_contract_2_right + data = torch.einsum( + "b" + ("ic" if put_contract_1_right else "ci") + ",b" + ("jc" if put_contract_2_right else "cj") + "->bij", + tensor_1.data, tensor_2.data) + if put_contract_1_right == put_contract_2_right: + data = torch.where(tensor_2.edge_by_name("Free_2").parity.reshape([1, 1, -1]), -data, +data) + tensor = 
Tensor(names=("Fuse", "Free_1", "Free_2"), + edges=(tensor_1.edge_by_name("Fuse_1"), tensor_1.edge_by_name("Free_1"), + tensor_2.edge_by_name("Free_2")), + fermion=self.fermion, + dtypes=self.dtypes, + data=data) + + # Step 4 + tensor = tensor.split_edge({ + "Fuse": ordered_fuse_edges, + "Free_1": free_edges_1, + "Free_2": free_edges_2 + }, False, set()) + + # Step 5 + tensor = tensor.reverse_edge((arrow_true_names_1 - contract_names_1) | (arrow_true_names_2 - contract_names_2), + False, set()) + + return tensor + + def _trace_group_edge( + self: Tensor, + trace_pairs: set[tuple[str, str]], + fuse_names: dict[str, tuple[str, str]], + ) -> tuple[tuple[str, ...], tuple[str, ...], tuple[str, ...], tuple[str, ...], tuple[str, ...], tuple[str, ...]]: + # pylint: disable=unnecessary-comprehension + trace_map = { + old_name_1: old_name_2 for old_name_1, old_name_2 in trace_pairs + } | { + old_name_2: old_name_1 for old_name_1, old_name_2 in trace_pairs + } + fuse_map = { + old_name_1: (old_name_2, new_name) for new_name, [old_name_1, old_name_2] in fuse_names.items() + } | { + old_name_2: (old_name_1, new_name) for new_name, [old_name_1, old_name_2] in fuse_names.items() + } + reversed_trace_names_1: list[str] = [] + reversed_trace_names_2: list[str] = [] + reversed_fuse_names_1: list[str] = [] + reversed_fuse_names_2: list[str] = [] + reversed_free_names: list[str] = [] + reversed_fuse_names_result: list[str] = [] + added_names: set[str] = set() + for name in reversed(self.names): + if name not in added_names: + trace_name: str | None = trace_map.get(name, None) + fuse_name: tuple[str, str] | None = fuse_map.get(name, None) + if trace_name is not None: + reversed_trace_names_2.append(name) + reversed_trace_names_1.append(trace_name) + added_names.add(trace_name) + elif fuse_name is not None: + reversed_fuse_names_2.append(name) + reversed_fuse_names_1.append(fuse_name[0]) + added_names.add(fuse_name[0]) + reversed_fuse_names_result.append(fuse_name[1]) + else: + 
    def trace(
        self: Tensor,
        trace_pairs: set[tuple[str, str]],
        fuse_names: dict[str, tuple[str, str]] | None = None,
    ) -> Tensor:
        """
        Trace a tensor.

        Parameters
        ----------
        trace_pairs : set[tuple[str, str]]
            The pairs of edges to be traced
        fuse_names : dict[str, tuple[str, str]], optional
            The edges to be fused, mapping the new edge name to the pair of old names.

        Returns
        -------
        Tensor
            The traced tensor.
        """
        if fuse_names is None:
            fuse_names = {}
        # Fuse names should not have any symmetry
        assert all(
            all(_utility.zero_symmetry(sub_symmetry)
                for sub_symmetry in self.edge_by_name(old_name_1).symmetry)
            for new_name, [old_name_1, old_name_2] in fuse_names.items())
        # Fuse names should share the same edges
        assert all(
            self.edge_by_name(old_name_1) == self.edge_by_name(old_name_2)
            for new_name, [old_name_1, old_name_2] in fuse_names.items())
        # Trace edges should be compatible
        assert all(
            self.edge_by_name(old_name_1).conjugate() == self.edge_by_name(old_name_2)
            for old_name_1, old_name_2 in trace_pairs)

        # Split trace pairs and fuse names to two part before main part of trace.
        [
            trace_names_1,
            trace_names_2,
            fuse_names_1,
            fuse_names_2,
            free_names,
            fuse_names_result,
        ] = self._trace_group_edge(trace_pairs, fuse_names)

        # Make alias
        tensor = self

        # Tensor edges merged to 5 parts: fuse edge 1, fuse edge 2, trace edge 1, trace edge 2, free edge
        # Trace contains 5 step:
        # 1. reverse all arrow to False except trace edge 1 to be True, only apply parity to one of two trace edge
        # 2. merge all edge to 5 part, only apply parity to one of two trace edge
        # 3. trace trivial tensor
        # 4. split edge merged in step 2, without apply parity
        # 5. reverse arrow reversed in step 1, without apply parity

        # Step 1
        arrow_true_names = {name for name, edge in zip(tensor.names, tensor.edges) if edge.arrow}
        unordered_trace_names_1 = set(trace_names_1)
        tensor = tensor.reverse_edge(unordered_trace_names_1 ^ arrow_true_names, False,
                                     unordered_trace_names_1 - arrow_true_names)

        # Step 2
        free_edges = tuple((name, tensor.edge_by_name(name)) for name in free_names)
        fuse_edges = tuple(
            (new_name, tensor.edge_by_name(old_name)) for old_name, new_name in zip(fuse_names_1, fuse_names_result))
        tensor = tensor.merge_edge(
            {
                "Trace_1": trace_names_1,
                "Trace_2": trace_names_2,
                "Fuse_1": fuse_names_1,
                "Fuse_2": fuse_names_2,
                "Free": free_names,
            }, False, {"Trace_1"}, {
                "Trace_1": True,
                "Trace_2": False,
                "Fuse_1": False,
                "Fuse_2": False,
                "Free": False,
            }, ("Trace_1", "Trace_2", "Fuse_1", "Fuse_2", "Free"))
        # B[fuse, free] = A[trace, trace, fuse, fuse, free]
        assert tensor.edges[2] == tensor.edges[3]
        assert tensor.edges[0].conjugate() == tensor.edges[1]

        # Step 3
        # As tested, the order of edges in this einsum is not important
        # ttffi->fi, fftti->fi, ffitt->fi, ttiff->if, ittff->if, ifftt->if
        # The repeated 't' sums the trace diagonal; the repeated-and-kept 'f' takes the
        # fuse diagonal without summing.
        data = torch.einsum("ttffi->fi", tensor.data)
        tensor = Tensor(names=("Fuse", "Free"),
                        edges=(tensor.edges[2], tensor.edges[4]),
                        fermion=self.fermion,
                        dtypes=self.dtypes,
                        data=data)

        # Step 4
        tensor = tensor.split_edge({
            "Fuse": fuse_edges,
            "Free": free_edges,
        }, False, set())

        # Step 5
        tensor = tensor.reverse_edge(
            # Free edge with arrow true
            {name for name in free_names if name in arrow_true_names} |
            # New edge from fused edge with arrow true
            {new_name for old_name, new_name in zip(fuse_names_1, fuse_names_result) if old_name in arrow_true_names},
            False,
            set(),
        )

        return tensor
+ + Parameters + ---------- + trivial_metric : bool, default=False + Fermionic tensor in network may result in non positive definite metric, set this trivial_metric to True to + ensure the metric to be positive, but it breaks the associative law with tensor contract. + + Returns + ------- + Tensor + The conjugated tensor. + """ + data = torch.conj(self.data) + + # Usually, only a full transpose sign is applied. + # If trivial_metric is set True, parity in edges with arrow True is also applied. + + # Apply parity + if any(self.fermion): + # It is femionic tensor, parity need to be applied + + # Parity is parity generated from a full transpose + # {sum 0<=i Edge: + # pylint: disable=invalid-name + m, n = matrix.size() + assert edge.dimension == m + argmax = torch.argmax(matrix, dim=0) + assert argmax.size == (n,) + return Edge(fermion=edge.fermion, + dtypes=edge.dtypes, + symmetry=tuple( + _utility.neg_symmetry(sub_symmetry[argmax]) + for sub_symmetry, sub_dtype in zip(edge.symmetry, edge.dtypes)), + dimension=n, + arrow=arrow) + + def svd( + self: Tensor, + free_names_u: set[str], + common_name_u: str, + common_name_v: str, + singular_name_u: str, + singular_name_v: str, + cut: int = -1, + fuse_names: set[str] | None = None, + ) -> tuple[Tensor, Tensor, Tensor]: + """ + SVD decomposition a tensor. + + Parameters + ---------- + free_names_u : set[str] + Free names in U tensor of the result of SVD. + common_name_u, common_name_v, singular_name_u, singular_name_v : str + The name of generated edges. + cut : int, default=-1 + The cut for the singular values. + fuse_names : set[str], optional + The names of fuse edges. + + Returns + ------- + tuple[Tensor, Tensor, Tensor] + U, S, V tensor, the result of SVD. 
+ """ + # pylint: disable=too-many-locals + if fuse_names is None: + fuse_names = set() + # Fuse name should not have any symmetry + assert all( + all(_utility.zero_symmetry(sub_symmetry) + for sub_symmetry in self.edge_by_name(fuse_name).symmetry) + for fuse_name in fuse_names) + + free_names_v = {name for name in self.names if name not in free_names_u and name not in fuse_names} + + assert all(name in self.names for name in free_names_u) + assert all(name in self.names for name in fuse_names) + assert fuse_names & free_names_u == set() + assert common_name_u not in free_names_u | fuse_names + assert common_name_v not in free_names_v | fuse_names + + arrow_true_names = {name for name, edge in zip(self.names, self.edges) if edge.arrow} + + tensor = self.reverse_edge(arrow_true_names, False, set()) + + ordered_fuse_names = tuple(name for name in tensor.names if name in fuse_names) + ordered_free_names_u = tuple(name for name in tensor.names if name in free_names_u) + ordered_free_names_v = tuple(name for name in tensor.names if name in free_names_v) + ordered_fuse_edges = tuple((name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in fuse_names) + ordered_free_edges_u = tuple( + (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in free_names_u) + ordered_free_edges_v = tuple( + (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in free_names_v) + + put_v_right = next((name in free_names_v for name in reversed(tensor.names) if name not in fuse_names), True) + tensor = tensor.merge_edge( + { + "SVD_F": ordered_fuse_names, + "SVD_U": ordered_free_names_u, + "SVD_V": ordered_free_names_v + }, False, set(), { + "SVD_F": False, + "SVD_U": False, + "SVD_V": False + }, ("SVD_F", "SVD_U", "SVD_V") if put_v_right else ("SVD_F", "SVD_V", "SVD_U")) + + data_1, data_s, data_2 = torch.linalg.svd(tensor.data, some=True) + if cut != -1: + data_1 = data_1[:, :, :cut] + data_s = data_s[:, :cut] + data_2 = data_2[:, :cut, :] + 
data_s = torch.diag_embed(data_s) + + fuse_edge = tensor.edges[0] + free_edge_1 = tensor.edges[1] + common_edge_1 = Tensor._guess_edge(torch.sum(torch.abs(data_1), dim=0), free_edge_1, True) + tensor_1 = Tensor( + names=("SVD_F", "SVD_U", common_name_u) if put_v_right else ("SVD_F", "SVD_V", common_name_v), + edges=(fuse_edge, free_edge_1, common_edge_1), + fermion=self.fermion, + dtypes=self.dtypes, + data=data_1, + ) + free_edge_2 = tensor.edges[2] + common_edge_2 = Tensor._guess_edge(torch.sum(torch.abs(data_2.transpose(0, 1)), dim=0), free_edge_2, False) + tensor_2 = Tensor( + names=("SVD_F", common_name_v, "SVD_V") if put_v_right else ("SVD_F", common_name_u, "SVD_U"), + edges=(fuse_edge, common_edge_2, free_edge_2), + fermion=self.fermion, + dtypes=self.dtypes, + data=data_2, + ) + assert common_edge_1.conjugate() == common_edge_2 + tensor_s = Tensor( + names=("SVD_F", singular_name_u, singular_name_v) if put_v_right else + ("SVD_F", singular_name_v, singular_name_u), + edges=(common_edge_2, common_edge_1), + fermion=self.fermion, + dtypes=self.dtypes, + data=data_s, + ) + + tensor_u = tensor_1 if put_v_right else tensor_2 + tensor_v = tensor_2 if put_v_right else tensor_1 + + tensor_u = tensor_u.split_edge({"SVD_U": ordered_free_edges_u, "SVD_F": ordered_fuse_edges}, False, set()) + tensor_v = tensor_v.split_edge({"SVD_V": ordered_free_edges_v, "SVD_F": ordered_fuse_edges}, False, set()) + tensor_s = tensor_s.split_edge({"SVD_F": ordered_fuse_edges}, False, set()) + + tensor_u = tensor_u.reverse_edge(arrow_true_names & (free_names_u | fuse_names), False, set()) + tensor_v = tensor_v.reverse_edge(arrow_true_names & (free_names_v | fuse_names), False, set()) + tensor_s = tensor_s.reverse_edge(arrow_true_names & fuse_names, False, set()) + + return tensor_u, tensor_s, tensor_v + + def qr(self: Tensor, + free_names_direction: str, + free_names: set[str], + common_name_q: str, + common_name_r: str, + fuse_names: set[str] | None = None) -> tuple[Tensor, Tensor]: + 
""" + QR decomposition on this tensor. + + Parameters + ---------- + free_names_direction : 'Q' | 'q' | 'R' | 'r' + Specify which direction the free_names will set + free_names : set[str] + The names of free edges after QR decomposition. + common_name_q, common_name_r : str + The names of edges created by QR decomposition. + fuse_names : set[str], optional + The names of fuse edges + + Returns + ------- + tuple[Tensor, Tensor] + Tensor Q and R, the result of QR decomposition. + """ + # pylint: disable=invalid-name + # pylint: disable=too-many-locals + if fuse_names is None: + fuse_names = set() + # Fuse name should not have any symmetry + assert all( + all(_utility.zero_symmetry(sub_symmetry) + for sub_symmetry in self.edge_by_name(fuse_name).symmetry) + for fuse_name in fuse_names) + + if free_names_direction in {'Q', 'q'}: + free_names_q = free_names + free_names_r = {name for name in self.names if name not in free_names and name not in fuse_names} + elif free_names_direction in {'R', 'r'}: + free_names_r = free_names + free_names_q = {name for name in self.names if name not in free_names and name not in fuse_names} + + arrow_true_names = {name for name, edge in zip(self.names, self.edges) if edge.arrow} + + tensor = self.reverse_edge(arrow_true_names, False, set()) + + ordered_fuse_names = tuple(name for name in tensor.names if name in fuse_names) + ordered_free_names_q = tuple(name for name in tensor.names if name in free_names_q) + ordered_free_names_r = tuple(name for name in tensor.names if name in free_names_r) + ordered_fuse_edges = tuple((name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in fuse_names) + ordered_free_edges_q = tuple( + (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in free_names_q) + ordered_free_edges_r = tuple( + (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in free_names_r) + + tensor = tensor.merge_edge( + { + "QR_F": ordered_fuse_names, + "QR_Q": 
ordered_free_names_q, + "QR_R": ordered_free_names_r + }, False, set(), { + "QR_F": False, + "QR_Q": False, + "QR_R": False + }, ("QR_F", "QR_Q", "QR_R")) + + data_q, data_r = torch.linalg.qr(tensor.data, some=True) + + fuse_edge = tensor.edges[0] + free_edge_q = tensor.edges[1] + common_edge_q = Tensor._guess_edge(torch.sum(torch.abs(data_q), dim=0), free_edge_q, True) + tensor_q = Tensor( + names=("QR_F", "QR_Q", common_name_q), + edges=(fuse_edge, free_edge_q, common_edge_q), + fermion=self.fermion, + dtypes=self.dtypes, + data=data_q, + ) + free_edge_r = tensor.edges[2] + common_edge_r = Tensor._guess_edge(torch.sum(torch.abs(data_r.transpose(0, 1)), dim=0), free_edge_r, False) + tensor_r = Tensor( + names=("QR_F", common_name_r, "QR_R"), + edges=(fuse_edge, common_edge_r, free_edge_r), + fermion=self.fermion, + dtypes=self.dtypes, + data=data_r, + ) + assert common_edge_q.conjugate() == common_edge_r + + tensor_q = tensor_q.split_edge({"QR_Q": ordered_free_edges_q, "QR_F": ordered_fuse_edges}, False, set()) + tensor_r = tensor_r.split_edge({"QR_R": ordered_free_edges_r, "QR_F": ordered_fuse_edges}, False, set()) + + tensor_q = tensor_q.reverse_edge(arrow_true_names & (free_names_q | fuse_names), False, set()) + tensor_r = tensor_r.reverse_edge(arrow_true_names & (free_names_r | fuse_names), False, set()) + + return tensor_q, tensor_r + + def identity(self: Tensor, pairs: set[tuple[str, str]]) -> Tensor: + """ + Get the identity tensor with same shape to this tensor. + + Parameters + ---------- + pairs : set[tuple[str, str]] + The pair of edge names to specify the relation among edges to set identity tensor. + + Returns + ------- + Tensor + The result identity tensor. + """ + # pylint: disable=too-many-locals + + # The order of edges before setting identity should be (False True) + # Merge tensor directly to two edge, set identity and split it directly. 
    def identity(self: Tensor, pairs: set[tuple[str, str]]) -> Tensor:
        """
        Get the identity tensor with same shape to this tensor.

        Parameters
        ----------
        pairs : set[tuple[str, str]]
            The pair of edge names to specify the relation among edges to set identity tensor.

        Returns
        -------
        Tensor
            The result identity tensor.
        """
        # pylint: disable=too-many-locals

        # The order of edges before setting identity should be (False True)
        # Merge tensor directly to two edge, set identity and split it directly.
        # When splitting, only apply parity to one part of edges

        # Map every edge name of a pair to its partner, in both directions.
        # pylint: disable=unnecessary-comprehension
        pairs_map = {name_1: name_2 for name_1, name_2 in pairs} | {name_2: name_1 for name_1, name_2 in pairs}
        # Scan names right-to-left: the first name of each pair met from the
        # right goes to group 2 and its partner to group 1, so the two groups
        # stay pairwise aligned after the final reversal.
        added_names: set[str] = set()
        reversed_names_1: list[str] = []
        reversed_names_2: list[str] = []
        for name in reversed(self.names):
            if name not in added_names:
                another_name = pairs_map[name]
                reversed_names_2.append(name)
                reversed_names_1.append(another_name)
                added_names.add(another_name)
        names_1 = tuple(reversed(reversed_names_1))
        names_2 = tuple(reversed(reversed_names_2))
        # unordered_names_1 = set(names_1)
        unordered_names_2 = set(names_2)

        # Names whose edge currently has arrow True.
        arrow_true_names = {name for name, edge in zip(self.names, self.edges) if edge.arrow}

        # Flip arrows where the current arrow disagrees with the target layout
        # (group 2 -> True, everything else -> False); undone at the end.
        tensor = self.reverse_edge(unordered_names_2 ^ arrow_true_names, False, unordered_names_2 - arrow_true_names)

        # Remember the individual edges so the merged pair can be split back.
        edges_1 = tuple((name, tensor.edge_by_name(name)) for name in names_1)
        edges_2 = tuple((name, tensor.edge_by_name(name)) for name in names_2)

        # Merge to a rank-2 tensor; parity is applied only to the second group.
        tensor = tensor.merge_edge({
            "Identity_1": names_1,
            "Identity_2": names_2
        }, False, {"Identity_2"}, {
            "Identity_1": False,
            "Identity_2": True
        }, ("Identity_1", "Identity_2"))

        # Overwrite the merged matrix with the identity, in place.
        torch.eye(*tensor.data.size(), out=tensor.data)

        tensor = tensor.split_edge({"Identity_1": edges_1, "Identity_2": edges_2}, False, {"Identity_2"})

        # Undo the arrow normalization applied above.
        tensor = tensor.reverse_edge(unordered_names_2 ^ arrow_true_names, False, unordered_names_2 - arrow_true_names)

        return tensor
+ """ + # pylint: disable=too-many-locals + + # The order of edges before setting exponential should be (False True) + # Merge tensor directly to two edge, set exponential and split it directly. + # When splitting, only apply parity to one part of edges + + unordered_names_1 = {name_1 for name_1, name_2 in pairs} + unordered_names_2 = {name_2 for name_1, name_2 in pairs} + if self.names and self.names[-1] in unordered_names_1: + unordered_names_1, unordered_names_2 = unordered_names_2, unordered_names_1 + # pylint: disable=unnecessary-comprehension + pairs_map = {name_1: name_2 for name_1, name_2 in pairs} | {name_2: name_1 for name_1, name_2 in pairs} + reversed_names_1: list[str] = [] + reversed_names_2: list[str] = [] + for name in reversed(self.names): + if name in unordered_names_2: + another_name = pairs_map[name] + reversed_names_2.append(name) + reversed_names_1.append(another_name) + names_1 = tuple(reversed(reversed_names_1)) + names_2 = tuple(reversed(reversed_names_2)) + + arrow_true_names = {name for name, edge in zip(self.names, self.edges) if edge.arrow} + + tensor = self.reverse_edge(unordered_names_2 ^ arrow_true_names, False, unordered_names_2 - arrow_true_names) + + edges_1 = tuple((name, tensor.edge_by_name(name)) for name in names_1) + edges_2 = tuple((name, tensor.edge_by_name(name)) for name in names_2) + + tensor = tensor.merge_edge({ + "Exponential_1": names_1, + "Exponential_2": names_2 + }, False, {"Exponential_2"}, { + "Exponential_1": False, + "Exponential_2": True + }, ("Exponential_1", "Exponential_2")) + + tensor = Tensor( + names=tensor.names, + edges=tensor.edges, + fermion=tensor.fermion, + dtypes=tensor.dtypes, + data=torch.linalg.matrix_exp(tensor.data), + mask=tensor.mask, + ) + + tensor = tensor.split_edge({"Exponential_1": edges_1, "Exponential_2": edges_2}, False, {"Exponential_2"}) + + tensor = tensor.reverse_edge(unordered_names_2 ^ arrow_true_names, False, unordered_names_2 - arrow_true_names) + + return tensor diff --git 
def test_edge_from_segments():
    """Compat segment lists of (symmetry, count) expand to flat symmetry tensors."""
    z2_actual = compat.Z2.Edge([
        (False, 2),
        (True, 3),
    ])
    z2_expected = tat.Edge(symmetry=(torch.tensor([False, False, True, True, True]),),)
    assert z2_actual == z2_expected

    fermi_actual = compat.Fermi.Edge([
        (-1, 1),
        (0, 2),
        (+1, 3),
    ], True)
    fermi_expected = tat.Edge(
        symmetry=(torch.tensor([-1, 0, 0, +1, +1, +1], dtype=torch.int),),
        arrow=True,
        fermion=(True,),
    )
    assert fermi_actual == fermi_expected

    fermi_fermi_actual = compat.FermiFermi.Edge([
        ((-1, -2), 1),
        ((0, +1), 2),
        ((+1, 0), 3),
    ], True)
    fermi_fermi_expected = tat.Edge(
        symmetry=(
            torch.tensor([-1, 0, 0, +1, +1, +1], dtype=torch.int),
            torch.tensor([-2, +1, +1, 0, 0, 0], dtype=torch.int),
        ),
        arrow=True,
        fermion=(True, True),
    )
    assert fermi_fermi_actual == fermi_fermi_expected
def test_tensor():
    """A compat-style tensor matches a tat.Tensor built from explicit edges."""
    # Compat creation: one (symmetry, parity) pair per index of each edge.
    a = compat.FermiZ2.D.Tensor(["i", "j"], [
        [(-1, False), (-1, True), (0, True), (0, False)],
        [(+1, True), (+1, False), (0, False), (0, True)],
    ])
    edge_i = tat.Edge(
        fermion=(True, False),
        symmetry=(
            torch.tensor([-1, -1, 0, 0], dtype=torch.int),
            torch.tensor([False, True, True, False]),
        ),
        arrow=False,
    )
    edge_j = tat.Edge(
        fermion=(True, False),
        symmetry=(
            torch.tensor([+1, +1, 0, 0], dtype=torch.int),
            torch.tensor([True, False, False, True]),
        ),
        arrow=False,
    )
    b = tat.Tensor(("i", "j"), (edge_i, edge_j))
    assert a.same_shape_with(b, allow_transpose=False)
def test_create_randn_tensor():
    """randn_ fills only the symmetry-allowed entries; forbidden blocks stay zero."""
    edge_pairs = (
        (
            tat.Edge(symmetry=(torch.tensor([False, True]),)),
            tat.Edge(symmetry=(torch.tensor([False, True]),)),
        ),
        (
            tat.Edge(symmetry=(torch.tensor([False, False]), torch.tensor([0, -1]))),
            tat.Edge(symmetry=(torch.tensor([False, False]), torch.tensor([0, +1]))),
        ),
    )
    for edges in edge_pairs:
        t = tat.Tensor(("i", "j"), edges, dtype=torch.float16).randn_()
        assert t.dtype == torch.float16
        # Diagonal blocks are symmetry-allowed and must be filled.
        assert t[0, 0] != 0
        assert t[1, 1] != 0
        # Off-diagonal blocks violate the symmetry and must stay zero.
        assert t[0, 1] == 0
        assert t[1, 0] == 0
def test_repr():
    """repr shows only the fields that differ from the trivial edge."""
    # Full fermionic edge: dimension, arrow, fermion and symmetry all shown.
    a = Edge(fermion=(False, True), symmetry=(torch.tensor([False, True]), torch.tensor([0, 1])), arrow=True)
    assert repr(a) == "Edge(dimension=2, arrow=True, fermion=(False, True), symmetry=(tensor([False, True]), tensor([0, 1])))"
    # Bosonic edge: fermion and arrow are omitted.
    b = Edge(symmetry=(torch.tensor([False, True]), torch.tensor([0, 1])))
    assert repr(b) == "Edge(dimension=2, symmetry=(tensor([False, True]), tensor([0, 1])))"
    # Trivial edge: only the dimension remains.
    c = Edge(dimension=4)
    assert repr(c) == "Edge(dimension=4)"