-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement prototype for torch based fermionic library.
- Loading branch information
0 parents
commit 1987c4b
Showing
12 changed files
with
2,565 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
name: CI | ||
|
||
on: [push, pull_request] | ||
|
||
jobs: | ||
CI: | ||
|
||
runs-on: ubuntu-latest | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
include: | ||
- python-version: "3.10" | ||
pytorch-version: "1.11" | ||
- python-version: "3.10" | ||
pytorch-version: "1.12" | ||
- python-version: "3.10" | ||
pytorch-version: "1.13" | ||
- python-version: "3.10" | ||
pytorch-version: "2.0" | ||
- python-version: "3.10" | ||
pytorch-version: "2.1" | ||
|
||
- python-version: "3.11" | ||
pytorch-version: "1.13" | ||
- python-version: "3.11" | ||
pytorch-version: "2.0" | ||
- python-version: "3.11" | ||
pytorch-version: "2.1" | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install requirements | ||
run: | | ||
pip install pylint mypy pytest pytest-cov | ||
pip install torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu | ||
pip install multimethod | ||
- name: Run pylint | ||
run: pylint tat | ||
working-directory: ${{ github.workspace }} | ||
- name: Run mypy | ||
run: mypy tat | ||
working-directory: ${{ github.workspace }} | ||
- name: Run pytest | ||
run: pytest | ||
working-directory: ${{ github.workspace }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
.coverage | ||
__pycache__ | ||
env |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# TAT | ||
|
||
A Fermionic tensor library based on pytorch. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
[project] | ||
name = "tat" | ||
version = "0.4.0" | ||
authors = [ | ||
{email = "zh970205@mail.ustc.edu.cn", name = "Hao Zhang"} | ||
] | ||
description = "A Fermionic tensor library based on pytorch." | ||
readme = "README.md" | ||
requires-python = ">=3.10" | ||
license = {text = "GPL-3.0-or-later"} | ||
dependencies = [ | ||
"multimethod", | ||
"torch", | ||
] | ||
|
||
[tool.pylint] | ||
max-line-length = 120 | ||
generated-members = "torch.*" | ||
|
||
[tool.yapf] | ||
based_on_style = "google" | ||
column_limit = 120 | ||
|
||
[tool.mypy] | ||
check_untyped_defs = true | ||
disallow_untyped_defs = true | ||
|
||
[tool.pytest.ini_options] | ||
pythonpath = "." | ||
testpaths = ["tests",] | ||
addopts = "--cov=tat" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
""" | ||
The tat is a Fermionic tensor library based on pytorch. | ||
""" | ||
|
||
from .edge import Edge | ||
from .tensor import Tensor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
""" | ||
Some internal utility used by tat. | ||
""" | ||
|
||
import torch | ||
|
||
# pylint: disable=missing-function-docstring | ||
# pylint: disable=no-else-return | ||
|
||
|
||
def unsqueeze(tensor: torch.Tensor, index: int, rank: int) -> torch.Tensor: | ||
return tensor.reshape([-1 if i == index else 1 for i in range(rank)]) | ||
|
||
|
||
def neg_symmetry(tensor: torch.Tensor) -> torch.Tensor: | ||
if tensor.dtype is torch.bool: | ||
return tensor | ||
else: | ||
return -tensor | ||
|
||
|
||
def add_symmetry(tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor: | ||
if tensor_1.dtype is torch.bool: | ||
return torch.logical_xor(tensor_1, tensor_2) | ||
else: | ||
return torch.add(tensor_1, tensor_2) | ||
|
||
|
||
def zero_symmetry(tensor: torch.Tensor) -> torch.Tensor: | ||
# pylint: disable=singleton-comparison | ||
if tensor.dtype is torch.bool: | ||
return tensor == False | ||
else: | ||
return tensor == 0 | ||
|
||
|
||
def parity(tensor: torch.Tensor) -> torch.Tensor: | ||
if tensor.dtype is torch.bool: | ||
return tensor | ||
else: | ||
return tensor % 2 != 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
""" | ||
This file implement a compat layer for legacy TAT interface. | ||
""" | ||
|
||
from __future__ import annotations | ||
import typing | ||
from multimethod import multimethod | ||
import torch | ||
from .edge import Edge as E | ||
from .tensor import Tensor as T | ||
|
||
# pylint: disable=too-few-public-methods | ||
# pylint: disable=too-many-instance-attributes | ||
|
||
|
||
class CompatSymmetry: | ||
""" | ||
The common Symmetry namespace | ||
""" | ||
|
||
def __init__(self: CompatSymmetry, fermion: tuple[bool, ...], dtypes: tuple[torch.dtype, ...]) -> None: | ||
self.fermion: tuple[bool, ...] = fermion | ||
self.dtypes: tuple[torch.dtype, ...] = dtypes | ||
|
||
# pylint: disable=invalid-name | ||
self.S: CompatScalar | ||
self.D: CompatScalar | ||
self.C: CompatScalar | ||
self.Z: CompatScalar | ||
self.float32: CompatScalar | ||
self.float64: CompatScalar | ||
self.float: CompatScalar | ||
self.complex64: CompatScalar | ||
self.complex128: CompatScalar | ||
self.complex: CompatScalar | ||
|
||
self.S = self.float32 = CompatScalar(self, torch.float32) | ||
self.D = self.float64 = self.float = CompatScalar(self, torch.float64) | ||
self.C = self.complex64 = CompatScalar(self, torch.complex64) | ||
self.Z = self.complex128 = self.complex = CompatScalar(self, torch.complex128) | ||
|
||
def _parse_segments(self: CompatSymmetry, segments: list) -> tuple[tuple[torch.Tensor, ...], int]: | ||
# Segments may be [Sym] or [(Sym, Size)] | ||
try: | ||
# try [(Sym, Size)] first | ||
return self._parse_segments_kernel(segments) | ||
except TypeError: | ||
# Cannot unpack is a type error, value[index] is a type error, too | ||
# convert [Sym] to [(Sym, Size)] | ||
return self._parse_segments_kernel([(sym, 1) for sym in segments]) | ||
|
||
def _parse_segments_kernel(self: CompatSymmetry, segments: list) -> tuple[tuple[torch.Tensor, ...], int]: | ||
# [(Sym, Size)] for every element | ||
dimension = sum(dim for _, dim in segments) | ||
symmetry = tuple( | ||
torch.tensor( | ||
sum( | ||
([self._parse_segments_get_subsymmetry(sym, index)] * dim | ||
for sym, dim in segments), | ||
[], | ||
), # Concat all segment for this subsymmetry | ||
dtype=sub_symmetry, | ||
) # Generate subsymmetry one by one | ||
for index, sub_symmetry in enumerate(self.dtypes)) | ||
return symmetry, dimension | ||
|
||
def _parse_segments_get_subsymmetry(self: CompatSymmetry, sym: object, index: int) -> object: | ||
# Most of time, symmetry is a tuple of subsymmetry | ||
# But if there is only ome subsymmetry, it could not be a tuple but subsymmetry itself. | ||
# pylint: disable=no-else-return | ||
if isinstance(sym, tuple): | ||
return sym[index] | ||
else: | ||
if len(self.fermion) == 1: | ||
return sym | ||
else: | ||
raise TypeError(f"{sym=} is not subscriptable") | ||
|
||
@multimethod | ||
def Edge(self: CompatSymmetry, dimension: int) -> E: | ||
""" | ||
Create edge with compat interface. | ||
""" | ||
# pylint: disable=invalid-name | ||
symmetry = tuple(torch.zeros(dimension, dtype=sub_symmetry) for sub_symmetry in self.dtypes) | ||
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=False) | ||
|
||
@Edge.register | ||
def _(self: CompatSymmetry, segments: list, arrow: bool = False) -> E: | ||
symmetry, dimension = self._parse_segments(segments) | ||
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow) | ||
|
||
@Edge.register | ||
def _(self: CompatSymmetry, segments_and_bool: tuple[list, bool]) -> E: | ||
segments, arrow = segments_and_bool | ||
symmetry, dimension = self._parse_segments(segments) | ||
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow) | ||
|
||
|
||
class CompatScalar: | ||
""" | ||
The common Scalar namespace. | ||
""" | ||
|
||
def __init__(self: CompatScalar, symmetry: CompatSymmetry, dtype: torch.dtype) -> None: | ||
self.symmetry: CompatSymmetry = symmetry | ||
self.dtype: torch.dtype = dtype | ||
|
||
@multimethod | ||
def Tensor(self: CompatScalar, names: list[str], edges: list) -> T: | ||
""" | ||
Create tensor with compat names and edges. | ||
""" | ||
# pylint: disable=invalid-name | ||
return T( | ||
tuple(names), | ||
tuple(self.symmetry.Edge(edge) for edge in edges), | ||
fermion=self.symmetry.fermion, | ||
dtypes=self.symmetry.dtypes, | ||
dtype=self.dtype, | ||
) | ||
|
||
@Tensor.register | ||
def _(self: CompatScalar) -> T: | ||
result = T( | ||
(), | ||
(), | ||
fermion=self.symmetry.fermion, | ||
dtypes=self.symmetry.dtypes, | ||
dtype=self.dtype, | ||
) | ||
result.data.reshape([-1])[0] = 1 | ||
return result | ||
|
||
@Tensor.register | ||
def _( | ||
self: CompatScalar, | ||
number: typing.Any, | ||
names: list[str] | None = None, | ||
edge_symmetry: list | None = None, | ||
edge_arrow: list[bool] | None = None, | ||
) -> T: | ||
# Create high rank tensor with only one element | ||
if names is None: | ||
names = [] | ||
if edge_symmetry is None: | ||
edge_symmetry = [None for _ in names] | ||
if edge_arrow is None: | ||
edge_arrow = [False for _ in names] | ||
result = T( | ||
tuple(names), | ||
tuple( | ||
E( | ||
fermion=self.symmetry.fermion, | ||
dtypes=self.symmetry.dtypes, | ||
symmetry=tuple( | ||
torch.tensor([self._create_size1_get_subsymmetry(symmetry, index)], dtype=dtype) | ||
for index, dtype in enumerate(self.symmetry.dtypes)), | ||
dimension=1, | ||
arrow=arrow, | ||
) | ||
for symmetry, arrow in zip(edge_symmetry, edge_arrow)), | ||
fermion=self.symmetry.fermion, | ||
dtypes=self.symmetry.dtypes, | ||
dtype=self.dtype, | ||
) | ||
result.data.reshape([-1])[0] = number | ||
return result | ||
|
||
def _create_size1_get_subsymmetry(self: CompatScalar, sym: object, index: int) -> object: | ||
# pylint: disable=no-else-return | ||
if sym is None: | ||
return 0 | ||
elif isinstance(sym, tuple): | ||
return sym[index] | ||
else: | ||
if len(self.symmetry.fermion) == 1: | ||
return sym | ||
else: | ||
raise TypeError(f"{sym=} is not subscriptable") | ||
|
||
|
||
No: CompatSymmetry = CompatSymmetry(fermion=(), dtypes=()) | ||
Z2: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.bool,)) | ||
U1: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.int,)) | ||
Fermi: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.int,)) | ||
FermiZ2: CompatSymmetry = CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.bool)) | ||
FermiU1: CompatSymmetry = CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.int)) | ||
Parity: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.bool,)) | ||
FermiFermi: CompatSymmetry = CompatSymmetry(fermion=(True, True), dtypes=(torch.int, torch.int)) | ||
Normal: CompatSymmetry = No |
Oops, something went wrong.