Skip to content

Commit

Permalink
Implement prototype for torch based fermionic library.
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Nov 5, 2023
0 parents commit 1987c4b
Show file tree
Hide file tree
Showing 12 changed files with 2,565 additions and 0 deletions.
50 changes: 50 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
name: CI

on: [push, pull_request]

jobs:
CI:

runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include:
- python-version: "3.10"
pytorch-version: "1.11"
- python-version: "3.10"
pytorch-version: "1.12"
- python-version: "3.10"
pytorch-version: "1.13"
- python-version: "3.10"
pytorch-version: "2.0"
- python-version: "3.10"
pytorch-version: "2.1"

- python-version: "3.11"
pytorch-version: "1.13"
- python-version: "3.11"
pytorch-version: "2.0"
- python-version: "3.11"
pytorch-version: "2.1"

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install requirements
run: |
pip install pylint mypy pytest pytest-cov
pip install torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu
pip install multimethod
- name: Run pylint
run: pylint tat
working-directory: ${{ github.workspace }}
- name: Run mypy
run: mypy tat
working-directory: ${{ github.workspace }}
- name: Run pytest
run: pytest
working-directory: ${{ github.workspace }}
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
.coverage
__pycache__
env
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# TAT

A Fermionic tensor library based on pytorch.
31 changes: 31 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
[project]
name = "tat"
version = "0.4.0"
authors = [
{email = "zh970205@mail.ustc.edu.cn", name = "Hao Zhang"}
]
description = "A Fermionic tensor library based on pytorch."
readme = "README.md"
requires-python = ">=3.10"
license = {text = "GPL-3.0-or-later"}
dependencies = [
"multimethod",
"torch",
]

[tool.pylint]
max-line-length = 120
generated-members = "torch.*"

[tool.yapf]
based_on_style = "google"
column_limit = 120

[tool.mypy]
check_untyped_defs = true
disallow_untyped_defs = true

[tool.pytest.ini_options]
pythonpath = "."
testpaths = ["tests",]
addopts = "--cov=tat"
6 changes: 6 additions & 0 deletions tat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""
The tat is a Fermionic tensor library based on pytorch.
"""

from .edge import Edge
from .tensor import Tensor
41 changes: 41 additions & 0 deletions tat/_utility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
Some internal utility used by tat.
"""

import torch

# pylint: disable=missing-function-docstring
# pylint: disable=no-else-return


def unsqueeze(tensor: torch.Tensor, index: int, rank: int) -> torch.Tensor:
return tensor.reshape([-1 if i == index else 1 for i in range(rank)])


def neg_symmetry(tensor: torch.Tensor) -> torch.Tensor:
if tensor.dtype is torch.bool:
return tensor
else:
return -tensor


def add_symmetry(tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
if tensor_1.dtype is torch.bool:
return torch.logical_xor(tensor_1, tensor_2)
else:
return torch.add(tensor_1, tensor_2)


def zero_symmetry(tensor: torch.Tensor) -> torch.Tensor:
# pylint: disable=singleton-comparison
if tensor.dtype is torch.bool:
return tensor == False
else:
return tensor == 0


def parity(tensor: torch.Tensor) -> torch.Tensor:
if tensor.dtype is torch.bool:
return tensor
else:
return tensor % 2 != 0
191 changes: 191 additions & 0 deletions tat/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""
This file implement a compat layer for legacy TAT interface.
"""

from __future__ import annotations
import typing
from multimethod import multimethod
import torch
from .edge import Edge as E
from .tensor import Tensor as T

# pylint: disable=too-few-public-methods
# pylint: disable=too-many-instance-attributes


class CompatSymmetry:
"""
The common Symmetry namespace
"""

def __init__(self: CompatSymmetry, fermion: tuple[bool, ...], dtypes: tuple[torch.dtype, ...]) -> None:
self.fermion: tuple[bool, ...] = fermion
self.dtypes: tuple[torch.dtype, ...] = dtypes

# pylint: disable=invalid-name
self.S: CompatScalar
self.D: CompatScalar
self.C: CompatScalar
self.Z: CompatScalar
self.float32: CompatScalar
self.float64: CompatScalar
self.float: CompatScalar
self.complex64: CompatScalar
self.complex128: CompatScalar
self.complex: CompatScalar

self.S = self.float32 = CompatScalar(self, torch.float32)
self.D = self.float64 = self.float = CompatScalar(self, torch.float64)
self.C = self.complex64 = CompatScalar(self, torch.complex64)
self.Z = self.complex128 = self.complex = CompatScalar(self, torch.complex128)

def _parse_segments(self: CompatSymmetry, segments: list) -> tuple[tuple[torch.Tensor, ...], int]:
# Segments may be [Sym] or [(Sym, Size)]
try:
# try [(Sym, Size)] first
return self._parse_segments_kernel(segments)
except TypeError:
# Cannot unpack is a type error, value[index] is a type error, too
# convert [Sym] to [(Sym, Size)]
return self._parse_segments_kernel([(sym, 1) for sym in segments])

def _parse_segments_kernel(self: CompatSymmetry, segments: list) -> tuple[tuple[torch.Tensor, ...], int]:
# [(Sym, Size)] for every element
dimension = sum(dim for _, dim in segments)
symmetry = tuple(
torch.tensor(
sum(
([self._parse_segments_get_subsymmetry(sym, index)] * dim
for sym, dim in segments),
[],
), # Concat all segment for this subsymmetry
dtype=sub_symmetry,
) # Generate subsymmetry one by one
for index, sub_symmetry in enumerate(self.dtypes))
return symmetry, dimension

def _parse_segments_get_subsymmetry(self: CompatSymmetry, sym: object, index: int) -> object:
# Most of time, symmetry is a tuple of subsymmetry
# But if there is only ome subsymmetry, it could not be a tuple but subsymmetry itself.
# pylint: disable=no-else-return
if isinstance(sym, tuple):
return sym[index]
else:
if len(self.fermion) == 1:
return sym
else:
raise TypeError(f"{sym=} is not subscriptable")

@multimethod
def Edge(self: CompatSymmetry, dimension: int) -> E:
"""
Create edge with compat interface.
"""
# pylint: disable=invalid-name
symmetry = tuple(torch.zeros(dimension, dtype=sub_symmetry) for sub_symmetry in self.dtypes)
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=False)

@Edge.register
def _(self: CompatSymmetry, segments: list, arrow: bool = False) -> E:
symmetry, dimension = self._parse_segments(segments)
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow)

@Edge.register
def _(self: CompatSymmetry, segments_and_bool: tuple[list, bool]) -> E:
segments, arrow = segments_and_bool
symmetry, dimension = self._parse_segments(segments)
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow)


class CompatScalar:
"""
The common Scalar namespace.
"""

def __init__(self: CompatScalar, symmetry: CompatSymmetry, dtype: torch.dtype) -> None:
self.symmetry: CompatSymmetry = symmetry
self.dtype: torch.dtype = dtype

@multimethod
def Tensor(self: CompatScalar, names: list[str], edges: list) -> T:
"""
Create tensor with compat names and edges.
"""
# pylint: disable=invalid-name
return T(
tuple(names),
tuple(self.symmetry.Edge(edge) for edge in edges),
fermion=self.symmetry.fermion,
dtypes=self.symmetry.dtypes,
dtype=self.dtype,
)

@Tensor.register
def _(self: CompatScalar) -> T:
result = T(
(),
(),
fermion=self.symmetry.fermion,
dtypes=self.symmetry.dtypes,
dtype=self.dtype,
)
result.data.reshape([-1])[0] = 1
return result

@Tensor.register
def _(
self: CompatScalar,
number: typing.Any,
names: list[str] | None = None,
edge_symmetry: list | None = None,
edge_arrow: list[bool] | None = None,
) -> T:
# Create high rank tensor with only one element
if names is None:
names = []
if edge_symmetry is None:
edge_symmetry = [None for _ in names]
if edge_arrow is None:
edge_arrow = [False for _ in names]
result = T(
tuple(names),
tuple(
E(
fermion=self.symmetry.fermion,
dtypes=self.symmetry.dtypes,
symmetry=tuple(
torch.tensor([self._create_size1_get_subsymmetry(symmetry, index)], dtype=dtype)
for index, dtype in enumerate(self.symmetry.dtypes)),
dimension=1,
arrow=arrow,
)
for symmetry, arrow in zip(edge_symmetry, edge_arrow)),
fermion=self.symmetry.fermion,
dtypes=self.symmetry.dtypes,
dtype=self.dtype,
)
result.data.reshape([-1])[0] = number
return result

def _create_size1_get_subsymmetry(self: CompatScalar, sym: object, index: int) -> object:
# pylint: disable=no-else-return
if sym is None:
return 0
elif isinstance(sym, tuple):
return sym[index]
else:
if len(self.symmetry.fermion) == 1:
return sym
else:
raise TypeError(f"{sym=} is not subscriptable")


No: CompatSymmetry = CompatSymmetry(fermion=(), dtypes=())
Z2: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.bool,))
U1: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.int,))
Fermi: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.int,))
FermiZ2: CompatSymmetry = CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.bool))
FermiU1: CompatSymmetry = CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.int))
Parity: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.bool,))
FermiFermi: CompatSymmetry = CompatSymmetry(fermion=(True, True), dtypes=(torch.int, torch.int))
Normal: CompatSymmetry = No
Loading

0 comments on commit 1987c4b

Please sign in to comment.