-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement prototype for torch based fermionic library.
- Loading branch information
0 parents
commit 8b3ea14
Showing
13 changed files
with
3,487 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
name: CI | ||
|
||
on: [push, pull_request] | ||
|
||
jobs: | ||
CI: | ||
|
||
runs-on: ubuntu-latest | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
include: | ||
- python-version: "3.10" | ||
pytorch-version: "1.12" | ||
- python-version: "3.10" | ||
pytorch-version: "1.13" | ||
- python-version: "3.10" | ||
pytorch-version: "2.0" | ||
- python-version: "3.10" | ||
pytorch-version: "2.1" | ||
|
||
- python-version: "3.11" | ||
pytorch-version: "1.13" | ||
- python-version: "3.11" | ||
pytorch-version: "2.0" | ||
- python-version: "3.11" | ||
pytorch-version: "2.1" | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install requirements | ||
run: | | ||
pip install pylint==2.17 mypy==1.6 pytest==7.4 pytest-cov==4.1 | ||
pip install torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu | ||
pip install multimethod | ||
- name: Run pylint | ||
run: pylint tat | ||
working-directory: ${{ github.workspace }} | ||
- name: Run mypy | ||
run: mypy tat | ||
working-directory: ${{ github.workspace }} | ||
- name: Run pytest | ||
run: pytest | ||
working-directory: ${{ github.workspace }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
.coverage | ||
.mypy_cache | ||
__pycache__ | ||
env |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# TAT | ||
|
||
A Fermionic tensor library based on pytorch. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
[project] | ||
name = "tat" | ||
version = "0.4.0" | ||
authors = [ | ||
{email = "zh970205@mail.ustc.edu.cn", name = "Hao Zhang"} | ||
] | ||
description = "A Fermionic tensor library based on pytorch." | ||
readme = "README.md" | ||
requires-python = ">=3.10" | ||
license = {text = "GPL-3.0-or-later"} | ||
dependencies = [ | ||
"multimethod>=1.9", | ||
"torch>=1.12", | ||
] | ||
|
||
[tool.pylint] | ||
max-line-length = 120 | ||
generated-members = "torch.*" | ||
|
||
[tool.yapf] | ||
based_on_style = "google" | ||
column_limit = 120 | ||
|
||
[tool.mypy] | ||
check_untyped_defs = true | ||
disallow_untyped_defs = true | ||
|
||
[tool.pytest.ini_options] | ||
pythonpath = "." | ||
testpaths = ["tests",] | ||
addopts = "--cov=tat" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
""" | ||
The tat is a Fermionic tensor library based on pytorch. | ||
""" | ||
|
||
from .edge import Edge | ||
from .tensor import Tensor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
""" | ||
Some internal utility used by tat. | ||
""" | ||
|
||
import torch | ||
|
||
# pylint: disable=missing-function-docstring | ||
# pylint: disable=no-else-return | ||
|
||
|
||
def unsqueeze(tensor: torch.Tensor, index: int, rank: int) -> torch.Tensor: | ||
return tensor.reshape([-1 if i == index else 1 for i in range(rank)]) | ||
|
||
|
||
def neg_symmetry(tensor: torch.Tensor) -> torch.Tensor: | ||
if tensor.dtype is torch.bool: | ||
return tensor | ||
else: | ||
return -tensor | ||
|
||
|
||
def add_symmetry(tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor: | ||
if tensor_1.dtype is torch.bool: | ||
return torch.logical_xor(tensor_1, tensor_2) | ||
else: | ||
return torch.add(tensor_1, tensor_2) | ||
|
||
|
||
def zero_symmetry(tensor: torch.Tensor) -> torch.Tensor: | ||
# pylint: disable=singleton-comparison | ||
if tensor.dtype is torch.bool: | ||
return tensor == False | ||
else: | ||
return tensor == 0 | ||
|
||
|
||
def parity(tensor: torch.Tensor) -> torch.Tensor: | ||
if tensor.dtype is torch.bool: | ||
return tensor | ||
else: | ||
return tensor % 2 != 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,235 @@ | ||
""" | ||
This file implements a compat layer for legacy TAT interface. | ||
""" | ||
|
||
from __future__ import annotations | ||
import typing | ||
from multimethod import multimethod | ||
import torch | ||
from .edge import Edge as E | ||
from .tensor import Tensor as T | ||
|
||
# pylint: disable=too-few-public-methods | ||
# pylint: disable=too-many-instance-attributes | ||
|
||
|
||
class CompatSymmetry: | ||
""" | ||
The common Symmetry namespace. | ||
""" | ||
|
||
def __init__(self: CompatSymmetry, fermion: tuple[bool, ...], dtypes: tuple[torch.dtype, ...]) -> None: | ||
# This create fake module like TAT.No, TAT.Z2 or similar things, it need to specify the symmetry attributes. | ||
# symmetry is set by two attributes: fermion and dtypes. | ||
self.fermion: tuple[bool, ...] = fermion | ||
self.dtypes: tuple[torch.dtype, ...] = dtypes | ||
|
||
# pylint: disable=invalid-name | ||
self.S: CompatScalar | ||
self.D: CompatScalar | ||
self.C: CompatScalar | ||
self.Z: CompatScalar | ||
self.float32: CompatScalar | ||
self.float64: CompatScalar | ||
self.float: CompatScalar | ||
self.complex64: CompatScalar | ||
self.complex128: CompatScalar | ||
self.complex: CompatScalar | ||
|
||
# In old TAT, something like TAT.No.D is a sub module for tensor with specific scalar type. | ||
# In this compat library, it is implemented by another fake module: CompatScalar. | ||
self.S = self.float32 = CompatScalar(self, torch.float32) | ||
self.D = self.float64 = self.float = CompatScalar(self, torch.float64) | ||
self.C = self.complex64 = CompatScalar(self, torch.complex64) | ||
self.Z = self.complex128 = self.complex = CompatScalar(self, torch.complex128) | ||
|
||
def _parse_segments(self: CompatSymmetry, segments: list) -> tuple[tuple[torch.Tensor, ...], int]: | ||
# In TAT, user could use [Sym] or [(Sym, Size)] to set segments of a edge, where [(Sym, Size)] is nothing but | ||
# the symmetry and size of every blocks. While [Sym] acts like [(Sym, 1)], so try to treat input as | ||
# [(Sym, Size)] First, if error raised, convert it from [Sym] to [(Sym, 1)] and try again. | ||
try: | ||
# try [(Sym, Size)] first | ||
return self._parse_segments_kernel(segments) | ||
except TypeError: | ||
# Cannot unpack is a type error, value[index] is a type error, too. So only catch TypeError here. | ||
# convert [Sym] to [(Sym, Size)] | ||
return self._parse_segments_kernel([(sym, 1) for sym in segments]) | ||
# This function return the symmetry tuple and dimension | ||
|
||
def _parse_segments_kernel(self: CompatSymmetry, | ||
segments: list[tuple[typing.Any, int]]) -> tuple[tuple[torch.Tensor, ...], int]: | ||
# [(Sym, Size)] for every element | ||
dimension = sum(dim for _, dim in segments) | ||
symmetry = tuple( | ||
torch.tensor( | ||
# tat.Edge need torch.Tensor as its symmetry, convert it to torch.Tensor with specific dtype. | ||
sum( | ||
# Concat all segment for this subsymmetry from an empty list | ||
# Every segment is just sym[index] * dim, sometimes sym may be sub symmetry itself directly instead | ||
# of tuple of sub symmetry, so call an utility function _parse_segments_get_subsymmetry here. | ||
([self._parse_segments_get_subsymmetry(sym, index)] * dim | ||
for sym, dim in segments), | ||
[], | ||
), | ||
dtype=sub_symmetry, | ||
) | ||
# Generate subsymmetry one by one | ||
for index, sub_symmetry in enumerate(self.dtypes)) | ||
return symmetry, dimension | ||
|
||
def _parse_segments_get_subsymmetry(self: CompatSymmetry, sym: object, index: int) -> object: | ||
# Most of time, symmetry is a tuple of sub symmetry | ||
# But if there is only one sub symmetry in the symmetry, it could not be a tuple but subsymmetry itself. | ||
# pylint: disable=no-else-return | ||
if isinstance(sym, tuple): | ||
# If it is tuple, there is no need to do any other check | ||
return sym[index] | ||
else: | ||
# If it is not tuple, it should be sub symmetry directly, so this symmetry only should own single sub | ||
# symmetry, check it. | ||
if len(self.fermion) == 1: | ||
return sym | ||
else: | ||
raise TypeError(f"{sym=} is not subscriptable") | ||
|
||
@multimethod | ||
def Edge(self: CompatSymmetry, dimension: int) -> E: | ||
""" | ||
Create edge with compat interface. | ||
It may be created by | ||
1. Edge(dimension) create trivial symmetry with specified dimension. | ||
2. Edge(segments, arrow) create with given segments and arrow. | ||
3. Edge(segments_arrow_tuple) create with a tuple of segments and arrow. | ||
""" | ||
# pylint: disable=invalid-name | ||
# Generate a trivial symmetry tuple. In this tuple, every sub symmetry is a torch.zeros tensor with specific | ||
# dtype and the same dimension. | ||
symmetry = tuple(torch.zeros(dimension, dtype=sub_symmetry) for sub_symmetry in self.dtypes) | ||
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=False) | ||
|
||
@Edge.register | ||
def _(self: CompatSymmetry, segments: list, arrow: bool = False) -> E: | ||
symmetry, dimension = self._parse_segments(segments) | ||
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow) | ||
|
||
@Edge.register | ||
def _(self: CompatSymmetry, segments_and_bool: tuple[list, bool]) -> E: | ||
segments, arrow = segments_and_bool | ||
symmetry, dimension = self._parse_segments(segments) | ||
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow) | ||
|
||
|
||
class CompatScalar: | ||
""" | ||
The common Scalar namespace. | ||
""" | ||
|
||
def __init__(self: CompatScalar, symmetry: CompatSymmetry, dtype: torch.dtype) -> None: | ||
# This is fake module like TAT.No.D, TAT.Fermi.complex, so it records the parent symmetry information and its | ||
# own dtype. | ||
self.symmetry: CompatSymmetry = symmetry | ||
self.dtype: torch.dtype = dtype | ||
|
||
@multimethod | ||
def Tensor(self: CompatScalar, names: list[str], edges: list) -> T: | ||
""" | ||
Create tensor with compat names and edges. | ||
It may be create by | ||
1. Tensor(names, edges) The most used interface. | ||
2. Tensor() Create a rank-0 tensor, fill with number 1. | ||
3. Tensor(number, names=[], edge_symmetry=[], edge_arrow=[]) Create a size-1 tensor, with specified edge, and | ||
filled with specified number. | ||
""" | ||
# pylint: disable=invalid-name | ||
return T( | ||
tuple(names), | ||
tuple(self.symmetry.Edge(edge) for edge in edges), | ||
fermion=self.symmetry.fermion, | ||
dtypes=self.symmetry.dtypes, | ||
dtype=self.dtype, | ||
) | ||
|
||
@Tensor.register | ||
def _(self: CompatScalar) -> T: | ||
result = T( | ||
(), | ||
(), | ||
fermion=self.symmetry.fermion, | ||
dtypes=self.symmetry.dtypes, | ||
dtype=self.dtype, | ||
) | ||
# To set element of rank-0 pytorch tensor, reshape it to rank-0, length-1 vector first and set it. | ||
result.data.reshape([-1])[0] = 1 | ||
return result | ||
|
||
@Tensor.register | ||
def _( | ||
self: CompatScalar, | ||
number: typing.Any, | ||
names: list[str] | None = None, | ||
edge_symmetry: list | None = None, | ||
edge_arrow: list[bool] | None = None, | ||
) -> T: | ||
# Create high rank tensor with only one element | ||
if names is None: | ||
names = [] | ||
if edge_symmetry is None: | ||
edge_symmetry = [None for _ in names] | ||
if edge_arrow is None: | ||
edge_arrow = [False for _ in names] | ||
result = T( | ||
tuple(names), | ||
tuple( | ||
# Create edge for every rank, given the only symmetry(maybe None) and arrow. | ||
E( | ||
fermion=self.symmetry.fermion, | ||
dtypes=self.symmetry.dtypes, | ||
# For every edge, its symmetry is a tuple of all sub symmetry. | ||
symmetry=tuple( | ||
# For every sub symmetry, get the only symmetry for it, since dimension of all edge is 1. | ||
# It should be noticed that the symmetry may be None, tuple or sub symmetry directly. | ||
torch.tensor([self._create_size1_get_subsymmetry(symmetry, index)], dtype=dtype) | ||
for index, dtype in enumerate(self.symmetry.dtypes)), | ||
dimension=1, | ||
arrow=arrow, | ||
) | ||
for symmetry, arrow in zip(edge_symmetry, edge_arrow)), | ||
fermion=self.symmetry.fermion, | ||
dtypes=self.symmetry.dtypes, | ||
dtype=self.dtype, | ||
) | ||
# To set element of rank-0 pytorch tensor, reshape it to rank-0, length-1 vector first and set it. | ||
result.data.reshape([-1])[0] = number | ||
return result | ||
|
||
def _create_size1_get_subsymmetry(self: CompatScalar, sym: object, index: int) -> object: | ||
# pylint: disable=no-else-return | ||
# sym may be None, tuple or sub symmetry directly. | ||
if sym is None: | ||
# If is None, user may want to create symmetric edge with trivial symmetry, which should be 0 for int and | ||
# False for bool, always return 0 here, since it will be converted to correct type by torch.tensor. | ||
return 0 | ||
elif isinstance(sym, tuple): | ||
# If it is tuple, there is no need to do any other check | ||
return sym[index] | ||
else: | ||
# If it is not tuple, it should be sub symmetry directly, so this symmetry only should own single sub | ||
# symmetry, check it. | ||
if len(self.symmetry.fermion) == 1: | ||
return sym | ||
else: | ||
raise TypeError(f"{sym=} is not subscriptable") | ||
|
||
|
||
# Create fake sub module for all symmetry compiled in old version TAT | ||
No: CompatSymmetry = CompatSymmetry(fermion=(), dtypes=()) | ||
Z2: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.bool,)) | ||
U1: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.int,)) | ||
Fermi: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.int,)) | ||
FermiZ2: CompatSymmetry = CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.bool)) | ||
FermiU1: CompatSymmetry = CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.int)) | ||
Parity: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.bool,)) | ||
FermiFermi: CompatSymmetry = CompatSymmetry(fermion=(True, True), dtypes=(torch.int, torch.int)) | ||
Normal: CompatSymmetry = No |
Oops, something went wrong.