Skip to content

Commit

Permalink
Implement prototype for torch based fermionic library.
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Nov 6, 2023
0 parents commit ae0845c
Show file tree
Hide file tree
Showing 13 changed files with 3,492 additions and 0 deletions.
48 changes: 48 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
name: CI

on: [push, pull_request]

jobs:
CI:

runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include:
- python-version: "3.10"
pytorch-version: "1.12"
- python-version: "3.10"
pytorch-version: "1.13"
- python-version: "3.10"
pytorch-version: "2.0"
- python-version: "3.10"
pytorch-version: "2.1"

- python-version: "3.11"
pytorch-version: "1.13"
- python-version: "3.11"
pytorch-version: "2.0"
- python-version: "3.11"
pytorch-version: "2.1"

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install requirements
run: |
pip install pylint==2.17 mypy==1.6 pytest==7.4 pytest-cov==4.1
pip install torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu
pip install multimethod
- name: Run pylint
run: pylint tat
working-directory: ${{ github.workspace }}
- name: Run mypy
run: mypy tat
working-directory: ${{ github.workspace }}
- name: Run pytest
run: pytest
working-directory: ${{ github.workspace }}
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
.coverage
.mypy_cache
__pycache__
env
675 changes: 675 additions & 0 deletions LICENSE.md

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# TAT

A Fermionic tensor library based on pytorch.
31 changes: 31 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
[project]
name = "tat"
version = "0.4.0"
authors = [
{email = "zh970205@mail.ustc.edu.cn", name = "Hao Zhang"}
]
description = "A Fermionic tensor library based on pytorch."
readme = "README.md"
requires-python = ">=3.10"
license = {text = "GPL-3.0-or-later"}
dependencies = [
"multimethod>=1.9",
"torch>=1.12",
]

[tool.pylint]
max-line-length = 120
generated-members = "torch.*"

[tool.yapf]
based_on_style = "google"
column_limit = 120

[tool.mypy]
check_untyped_defs = true
disallow_untyped_defs = true

[tool.pytest.ini_options]
pythonpath = "."
testpaths = ["tests",]
addopts = "--cov=tat"
6 changes: 6 additions & 0 deletions tat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""
The tat is a Fermionic tensor library based on pytorch.
"""

from .edge import Edge
from .tensor import Tensor
41 changes: 41 additions & 0 deletions tat/_utility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
Some internal utility used by tat.
"""

import torch

# pylint: disable=missing-function-docstring
# pylint: disable=no-else-return


def unsqueeze(tensor: torch.Tensor, index: int, rank: int) -> torch.Tensor:
return tensor.reshape([-1 if i == index else 1 for i in range(rank)])


def neg_symmetry(tensor: torch.Tensor) -> torch.Tensor:
if tensor.dtype is torch.bool:
return tensor
else:
return -tensor


def add_symmetry(tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
if tensor_1.dtype is torch.bool:
return torch.logical_xor(tensor_1, tensor_2)
else:
return torch.add(tensor_1, tensor_2)


def zero_symmetry(tensor: torch.Tensor) -> torch.Tensor:
# pylint: disable=singleton-comparison
if tensor.dtype is torch.bool:
return tensor == False
else:
return tensor == 0


def parity(tensor: torch.Tensor) -> torch.Tensor:
if tensor.dtype is torch.bool:
return tensor
else:
return tensor % 2 != 0
231 changes: 231 additions & 0 deletions tat/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
"""
This file implements a compat layer for legacy TAT interface.
"""

from __future__ import annotations
import typing
from multimethod import multimethod
import torch
from .edge import Edge as E
from .tensor import Tensor as T

# pylint: disable=too-few-public-methods
# pylint: disable=too-many-instance-attributes


class CompatSymmetry:
"""
The common Symmetry namespace.
"""

def __init__(self: CompatSymmetry, fermion: tuple[bool, ...], dtypes: tuple[torch.dtype, ...]) -> None:
# This create fake module like TAT.No, TAT.Z2 or similar things, it need to specify the symmetry attributes.
# symmetry is set by two attributes: fermion and dtypes.
self.fermion: tuple[bool, ...] = fermion
self.dtypes: tuple[torch.dtype, ...] = dtypes

# pylint: disable=invalid-name
self.S: CompatScalar
self.D: CompatScalar
self.C: CompatScalar
self.Z: CompatScalar
self.float32: CompatScalar
self.float64: CompatScalar
self.float: CompatScalar
self.complex64: CompatScalar
self.complex128: CompatScalar
self.complex: CompatScalar

# In old TAT, something like TAT.No.D is a sub module for tensor with specific scalar type.
# In this compat library, it is implemented by another fake module: CompatScalar.
self.S = self.float32 = CompatScalar(self, torch.float32)
self.D = self.float64 = self.float = CompatScalar(self, torch.float64)
self.C = self.complex64 = CompatScalar(self, torch.complex64)
self.Z = self.complex128 = self.complex = CompatScalar(self, torch.complex128)

def _parse_segments(self: CompatSymmetry, segments: list) -> tuple[tuple[torch.Tensor, ...], int]:
# In TAT, user could use [Sym] or [(Sym, Size)] to set segments of a edge, where [(Sym, Size)] is nothing but
# the symmetry and size of every blocks. While [Sym] acts like [(Sym, 1)], so try to treat input as
# [(Sym, Size)] First, if error raised, convert it from [Sym] to [(Sym, 1)] and try again.
try:
# try [(Sym, Size)] first
return self._parse_segments_kernel(segments)
except TypeError:
# Cannot unpack is a type error, value[index] is a type error, too. So only catch TypeError here.
# convert [Sym] to [(Sym, Size)]
return self._parse_segments_kernel([(sym, 1) for sym in segments])
# This function return the symmetry tuple and dimension

def _parse_segments_kernel(self: CompatSymmetry,
segments: list[tuple[typing.Any, int]]) -> tuple[tuple[torch.Tensor, ...], int]:
# [(Sym, Size)] for every element
dimension = sum(dim for _, dim in segments)
symmetry = tuple(
torch.tensor(
# tat.Edge need torch.Tensor as its symmetry, convert it to torch.Tensor with specific dtype.
sum(
# Concat all segment for this subsymmetry from an empty list
# Every segment is just sym[index] * dim, sometimes sym may be sub symmetry itself directly instead
# of tuple of sub symmetry, so call an utility function _parse_segments_get_subsymmetry here.
([self._parse_segments_get_subsymmetry(sym, index)] * dim
for sym, dim in segments),
[],
),
dtype=sub_symmetry,
)
# Generate subsymmetry one by one
for index, sub_symmetry in enumerate(self.dtypes))
return symmetry, dimension

def _parse_segments_get_subsymmetry(self: CompatSymmetry, sym: object, index: int) -> object:
# Most of time, symmetry is a tuple of sub symmetry
# But if there is only one sub symmetry in the symmetry, it could not be a tuple but subsymmetry itself.
# pylint: disable=no-else-return
if isinstance(sym, tuple):
# If it is tuple, there is no need to do any other check
return sym[index]
else:
# If it is not tuple, it should be sub symmetry directly, so this symmetry only should own single sub
# symmetry, check it.
if len(self.fermion) == 1:
return sym
else:
raise TypeError(f"{sym=} is not subscriptable")

@multimethod
def Edge(self: CompatSymmetry, dimension: int) -> E:
"""
Create edge with compat interface.
It may be created by
1. Edge(dimension) create trivial symmetry with specified dimension.
2. Edge(segments, arrow) create with given segments and arrow.
3. Edge(segments_arrow_tuple) create with a tuple of segments and arrow.
"""
# pylint: disable=invalid-name
# Generate a trivial symmetry tuple. In this tuple, every sub symmetry is a torch.zeros tensor with specific
# dtype and the same dimension.
symmetry = tuple(torch.zeros(dimension, dtype=sub_symmetry) for sub_symmetry in self.dtypes)
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=False)

@Edge.register
def _(self: CompatSymmetry, segments: list, arrow: bool = False) -> E:
symmetry, dimension = self._parse_segments(segments)
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow)

@Edge.register
def _(self: CompatSymmetry, segments_and_bool: tuple[list, bool]) -> E:
segments, arrow = segments_and_bool
symmetry, dimension = self._parse_segments(segments)
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow)


class CompatScalar:
"""
The common Scalar namespace.
"""

def __init__(self: CompatScalar, symmetry: CompatSymmetry, dtype: torch.dtype) -> None:
# This is fake module like TAT.No.D, TAT.Fermi.complex, so it records the parent symmetry information and its
# own dtype.
self.symmetry: CompatSymmetry = symmetry
self.dtype: torch.dtype = dtype

@multimethod
def Tensor(self: CompatScalar, names: list[str], edges: list) -> T:
"""
Create tensor with compat names and edges.
It may be create by
1. Tensor(names, edges) The most used interface.
2. Tensor() Create a rank-0 tensor, fill with number 1.
3. Tensor(number, names=[], edge_symmetry=[], edge_arrow=[]) Create a size-1 tensor, with specified edge, and
filled with specified number.
"""
# pylint: disable=invalid-name
return T(
tuple(names),
tuple(self.symmetry.Edge(edge) for edge in edges),
fermion=self.symmetry.fermion,
dtypes=self.symmetry.dtypes,
dtype=self.dtype,
)

@Tensor.register
def _(self: CompatScalar) -> T:
result = T(
(),
(),
fermion=self.symmetry.fermion,
dtypes=self.symmetry.dtypes,
data=torch.ones([], dtype=self.dtype),
)
return result

@Tensor.register
def _(
self: CompatScalar,
number: typing.Any,
names: list[str] | None = None,
edge_symmetry: list | None = None,
edge_arrow: list[bool] | None = None,
) -> T:
# Create high rank tensor with only one element
if names is None:
names = []
if edge_symmetry is None:
edge_symmetry = [None for _ in names]
if edge_arrow is None:
edge_arrow = [False for _ in names]
result = T(
tuple(names),
tuple(
# Create edge for every rank, given the only symmetry(maybe None) and arrow.
E(
fermion=self.symmetry.fermion,
dtypes=self.symmetry.dtypes,
# For every edge, its symmetry is a tuple of all sub symmetry.
symmetry=tuple(
# For every sub symmetry, get the only symmetry for it, since dimension of all edge is 1.
# It should be noticed that the symmetry may be None, tuple or sub symmetry directly.
torch.tensor([self._create_size1_get_subsymmetry(symmetry, index)], dtype=dtype)
for index, dtype in enumerate(self.symmetry.dtypes)),
dimension=1,
arrow=arrow,
)
for symmetry, arrow in zip(edge_symmetry, edge_arrow)),
fermion=self.symmetry.fermion,
dtypes=self.symmetry.dtypes,
data=torch.full([1 for _ in names], number, dtype=self.dtype),
)
return result

def _create_size1_get_subsymmetry(self: CompatScalar, sym: object, index: int) -> object:
# pylint: disable=no-else-return
# sym may be None, tuple or sub symmetry directly.
if sym is None:
# If is None, user may want to create symmetric edge with trivial symmetry, which should be 0 for int and
# False for bool, always return 0 here, since it will be converted to correct type by torch.tensor.
return 0
elif isinstance(sym, tuple):
# If it is tuple, there is no need to do any other check
return sym[index]
else:
# If it is not tuple, it should be sub symmetry directly, so this symmetry only should own single sub
# symmetry, check it.
if len(self.symmetry.fermion) == 1:
return sym
else:
raise TypeError(f"{sym=} is not subscriptable")


# Create fake sub module for all symmetry compiled in old version TAT
No: CompatSymmetry = CompatSymmetry(fermion=(), dtypes=())
Z2: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.bool,))
U1: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.int,))
Fermi: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.int,))
FermiZ2: CompatSymmetry = CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.bool))
FermiU1: CompatSymmetry = CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.int))
Parity: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.bool,))
FermiFermi: CompatSymmetry = CompatSymmetry(fermion=(True, True), dtypes=(torch.int, torch.int))
Normal: CompatSymmetry = No
Loading

0 comments on commit ae0845c

Please sign in to comment.