-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement prototype for torch based fermionic library.
Some function not implemented or defined [ ] merge_edge [ ] split_edge [ ] contract [ ] trace [ ] identity [ ] exponential [ ] conjugate [ ] svd [ ] qr
- Loading branch information
0 parents
commit 9d1b9f4
Showing
11 changed files
with
1,573 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
name: CI | ||
|
||
on: [push, pull_request] | ||
|
||
jobs: | ||
CI: | ||
|
||
runs-on: ubuntu-latest | ||
strategy: | ||
matrix: | ||
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Display Python version | ||
run: python -c "import sys; print(sys.version)" | ||
- name: Install CI tools | ||
run: pip install pylint mypy pytest pytest-cov | ||
- name: Install requirements | ||
run: pip install . | ||
working-directory: ${{ runner.workspace }} | ||
- name: Run pytest | ||
run: pytest --cov=tat | ||
working-directory: ${{ runner.workspace }} | ||
- name: Run mypy | ||
run: mypy tat | ||
working-directory: ${{ runner.workspace }} | ||
- name: Run pylint | ||
run: pylint tat | ||
working-directory: ${{ runner.workspace }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
.coverage | ||
__pycache__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
* TAT | ||
|
||
A Fermionic tensor library based on pytorch. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
[tool.pylint] | ||
max-line-length = 120 | ||
generated-members = 'torch.*' | ||
|
||
[tool.yapf] | ||
based_on_style = 'google' | ||
column_limit = 120 | ||
|
||
[tool.mypy] | ||
check_untyped_defs = true | ||
disallow_untyped_defs = true | ||
|
||
[project] | ||
name = 'tat' | ||
version = '0.4.0' | ||
description = "A Fermionic tensor library based on pytorch." | ||
requires-python = ">=3.7" | ||
authors = [ | ||
{email = "zh970205@mail.ustc.edu.cn"}, | ||
{name = "Hao Zhang"} | ||
] | ||
dependencies = [ | ||
'multimethod', | ||
'torch', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
""" | ||
The tat is a Fermionic tensor library based on pytorch. | ||
""" | ||
|
||
from .edge import Edge | ||
from .tensor import Tensor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
""" | ||
This file implement a compat layer for legacy TAT interface. | ||
""" | ||
|
||
from __future__ import annotations | ||
import typing | ||
from multimethod import multimethod | ||
import torch | ||
from .edge import Edge as E | ||
from .tensor import Tensor as T | ||
|
||
# pylint: disable=too-few-public-methods | ||
# pylint: disable=too-many-instance-attributes | ||
|
||
|
||
class CompatSymmetry: | ||
""" | ||
The common Symmetry namespace | ||
""" | ||
|
||
def __init__(self: CompatSymmetry, fermion: tuple[bool, ...], dtypes: tuple[torch.dtype, ...]) -> None: | ||
self.fermion: tuple[bool, ...] = fermion | ||
self.dtypes: tuple[torch.dtype, ...] = dtypes | ||
|
||
# pylint: disable=invalid-name | ||
self.S: CompatScalar | ||
self.D: CompatScalar | ||
self.C: CompatScalar | ||
self.Z: CompatScalar | ||
self.float32: CompatScalar | ||
self.float64: CompatScalar | ||
self.float: CompatScalar | ||
self.complex64: CompatScalar | ||
self.complex128: CompatScalar | ||
self.complex: CompatScalar | ||
|
||
self.S = self.float32 = CompatScalar(self, torch.float32) | ||
self.D = self.float64 = self.float = CompatScalar(self, torch.float64) | ||
self.C = self.complex64 = CompatScalar(self, torch.complex64) | ||
self.Z = self.complex128 = self.complex = CompatScalar(self, torch.complex128) | ||
|
||
def _parse_segments(self: CompatSymmetry, segments: list) -> tuple[tuple[torch.Tensor, ...], int]: | ||
# Segments may be [Sym] or [(Sym, Size)] | ||
try: | ||
# try [(Sym, Size)] first | ||
return self._parse_segments_kernel(segments) | ||
except TypeError: | ||
# Cannot unpack is a type error, value[index] is a type error, too | ||
# convert [Sym] to [(Sym, Size)] | ||
return self._parse_segments_kernel([(sym, 1) for sym in segments]) | ||
|
||
def _parse_segments_kernel(self: CompatSymmetry, segments: list) -> tuple[tuple[torch.Tensor, ...], int]: | ||
# [(Sym, Size)] for every element | ||
dimension = sum(dim for _, dim in segments) | ||
symmetry = tuple( | ||
torch.tensor( | ||
sum( | ||
([self._parse_segments_get_subsymmetry(sym, index)] * dim | ||
for sym, dim in segments), | ||
[], | ||
), # Concat all segment for this subsymmetry | ||
dtype=sub_symmetry, | ||
) # Generate subsymmetry one by one | ||
for index, sub_symmetry in enumerate(self.dtypes)) | ||
return symmetry, dimension | ||
|
||
def _parse_segments_get_subsymmetry(self: CompatSymmetry, sym: object, index: int) -> object: | ||
# Most of time, symmetry is a tuple of subsymmetry | ||
# But if there is only ome subsymmetry, it could not be a tuple but subsymmetry itself. | ||
# pylint: disable=no-else-return | ||
if isinstance(sym, tuple): | ||
return sym[index] | ||
else: | ||
if len(self.fermion) == 1: | ||
return sym | ||
else: | ||
raise TypeError(f"{sym=} is not subscriptable") | ||
|
||
@multimethod | ||
def Edge(self: CompatSymmetry, dimension: int) -> E: | ||
""" | ||
Create edge with compat interface. | ||
""" | ||
# pylint: disable=invalid-name | ||
symmetry = tuple(torch.zeros(dimension, dtype=sub_symmetry) for sub_symmetry in self.dtypes) | ||
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=False) | ||
|
||
@Edge.register | ||
def _(self: CompatSymmetry, segments: list, arrow: bool = False) -> E: | ||
symmetry, dimension = self._parse_segments(segments) | ||
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow) | ||
|
||
@Edge.register | ||
def _(self: CompatSymmetry, segments_and_bool: tuple[list, bool]) -> E: | ||
segments, arrow = segments_and_bool | ||
symmetry, dimension = self._parse_segments(segments) | ||
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow) | ||
|
||
|
||
class CompatScalar: | ||
""" | ||
The common Scalar namespace. | ||
""" | ||
|
||
def __init__(self: CompatScalar, symmetry: CompatSymmetry, dtype: torch.dtype) -> None: | ||
self.symmetry: CompatSymmetry = symmetry | ||
self.dtype: torch.dtype = dtype | ||
|
||
@multimethod | ||
def Tensor(self: CompatScalar, names: list[str], edges: list) -> T: | ||
""" | ||
Create tensor with compat names and edges. | ||
""" | ||
# pylint: disable=invalid-name | ||
return T( | ||
tuple(names), | ||
tuple(self.symmetry.Edge(edge) for edge in edges), | ||
fermion=self.symmetry.fermion, | ||
dtypes=self.symmetry.dtypes, | ||
dtype=self.dtype, | ||
) | ||
|
||
@Tensor.register | ||
def _(self: CompatScalar) -> T: | ||
result = T( | ||
(), | ||
(), | ||
fermion=self.symmetry.fermion, | ||
dtypes=self.symmetry.dtypes, | ||
dtype=self.dtype, | ||
) | ||
result.data.reshape([-1])[0] = 1 | ||
return result | ||
|
||
@Tensor.register | ||
def _( | ||
self: CompatScalar, | ||
number: typing.Any, | ||
names: list[str] | None = None, | ||
edge_symmetry: list | None = None, | ||
edge_arrow: list[bool] | None = None, | ||
) -> T: | ||
# Create high rank tensor with only one element | ||
if names is None: | ||
names = [] | ||
if edge_symmetry is None: | ||
edge_symmetry = [None for _ in names] | ||
if edge_arrow is None: | ||
edge_arrow = [False for _ in names] | ||
result = T( | ||
tuple(names), | ||
tuple( | ||
E( | ||
fermion=self.symmetry.fermion, | ||
dtypes=self.symmetry.dtypes, | ||
symmetry=tuple( | ||
torch.tensor([self._create_size1_get_subsymmetry(symmetry, index)], dtype=dtype) | ||
for index, dtype in enumerate(self.symmetry.dtypes)), | ||
dimension=1, | ||
arrow=arrow, | ||
) | ||
for symmetry, arrow in zip(edge_symmetry, edge_arrow)), | ||
fermion=self.symmetry.fermion, | ||
dtypes=self.symmetry.dtypes, | ||
dtype=self.dtype, | ||
) | ||
result.data.reshape([-1])[0] = number | ||
return result | ||
|
||
def _create_size1_get_subsymmetry(self: CompatScalar, sym: object, index: int) -> object: | ||
# pylint: disable=no-else-return | ||
if sym is None: | ||
return 0 | ||
elif isinstance(sym, tuple): | ||
return sym[index] | ||
else: | ||
if len(self.symmetry.fermion) == 1: | ||
return sym | ||
else: | ||
raise TypeError(f"{sym=} is not subscriptable") | ||
|
||
|
||
No: CompatSymmetry = CompatSymmetry(fermion=(), dtypes=()) | ||
Z2: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.bool,)) | ||
U1: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.int,)) | ||
Fermi: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.int,)) | ||
FermiZ2: CompatSymmetry = CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.bool)) | ||
FermiU1: CompatSymmetry = CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.int)) | ||
Parity: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.bool,)) | ||
FermiFermi: CompatSymmetry = CompatSymmetry(fermion=(True, True), dtypes=(torch.int, torch.int)) | ||
Normal: CompatSymmetry = No |
Oops, something went wrong.