Skip to content

Commit

Permalink
Implement QR decomposition by Givens rot and Householder refl.
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Nov 18, 2023
1 parent 52a417d commit 7bf3bde
Show file tree
Hide file tree
Showing 3 changed files with 305 additions and 4 deletions.
233 changes: 233 additions & 0 deletions tat/_qr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
"""
This module implements QR decomposition based on Givens rotation and Householder reflection.
"""

import typing
import torch

# pylint: disable=invalid-name


@torch.jit.script
def _syminvadj(X: torch.Tensor) -> torch.Tensor:
    """
    Return ``X + X^H`` with the real part of its diagonal halved.

    Used by the QR backward pass for deep and square matrices.
    """
    hermitian_sum = X + X.H
    # The diagonal was counted twice by the sum, so scale its real part back in place.
    hermitian_sum.diagonal().real.mul_(0.5)
    return hermitian_sum


@torch.jit.script
def _triliminvadjskew(X: torch.Tensor) -> torch.Tensor:
    """
    Return the lower triangle of ``X - X^H``.

    For complex input the imaginary part of the diagonal is additionally halved.
    Used by the QR backward pass for wide matrices.
    """
    lower = torch.tril(X - X.H)
    if X.is_complex():
        # Only complex tensors have a non-trivial imaginary diagonal to rescale.
        lower.diagonal().imag.mul_(0.5)
    return lower


@torch.jit.script
def _qr_backward(
    Q: torch.Tensor,
    R: torch.Tensor,
    Q_grad: typing.Optional[torch.Tensor],
    R_grad: typing.Optional[torch.Tensor],
) -> typing.Optional[torch.Tensor]:
    """
    Compute the gradient of the decomposed matrix from the gradients of its reduced QR factors.

    Parameters
    ----------
    Q, R : torch.Tensor
        The factors produced by the forward pass.
    Q_grad, R_grad : torch.Tensor, optional
        The incoming gradients of Q and R; either of them may be missing.

    Returns
    -------
    torch.Tensor, optional
        The gradient with respect to the decomposed matrix, or None when both incoming gradients are None.
    """
    # see https://arxiv.org/pdf/2009.10071.pdf section 4.3 and 4.5
    # see pytorch torch/csrc/autograd/FunctionsManual.cpp:linalg_qr_backward
    rows = Q.size(0)
    cols = R.size(1)

    # Collect the contributions of whichever gradients are present.
    if Q_grad is None:
        if R_grad is None:
            return None
        MH = R_grad @ R.H
    elif R_grad is None:
        MH = -Q.H @ Q_grad
    else:
        MH = R_grad @ R.H - Q.H @ Q_grad

    # pylint: disable=no-else-return
    if rows < cols:
        # Wide matrix: solve against the leading square part of R, then pad with zero columns.
        rhs = Q @ _triliminvadjskew(-MH)
        leading = torch.linalg.solve_triangular(R[:, :rows].H, rhs, upper=False, left=False)
        padding = torch.zeros([rows, cols - rows], dtype=leading.dtype, device=leading.device)
        grad = torch.cat((leading, padding), dim=1)
        if R_grad is not None:
            grad = grad + Q @ R_grad
        return grad
    else:
        # Deep and square matrix
        rhs = Q @ _syminvadj(torch.triu(MH))
        if Q_grad is not None:
            rhs = rhs + Q_grad
        return torch.linalg.solve_triangular(R.H, rhs, upper=False, left=False)


class CommonQR(torch.autograd.Function):
    """
    Shared backward pass for the QR autograd functions.

    The backward pass reads ``(Q, R)`` from ``ctx.saved_tensors``, so a forward pass must store them there.
    """

    # pylint: disable=abstract-method

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: typing.Any,
        Q_grad: torch.Tensor | None,
        R_grad: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """
        Turn the gradients of Q and R into the gradient of the decomposed matrix.
        """
        # pylint: disable=arguments-differ
        saved_q, saved_r = ctx.saved_tensors
        return _qr_backward(saved_q, saved_r, Q_grad, R_grad)


@torch.jit.script
def _normalize_diagonal(a: torch.Tensor) -> torch.Tensor:
    """
    Return the phase ``a / |a|`` of every element, substituting one where the element is zero.
    """
    zero = torch.zeros([], dtype=a.dtype, device=a.device)
    one = torch.ones([], dtype=a.dtype, device=a.device)
    magnitude = torch.sqrt(a.conj() * a)
    # The division yields NaN for zero entries, but those positions are replaced by one.
    return torch.where(magnitude == zero, one, a / magnitude)


@torch.jit.script
def _givens_parameter(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return the pair ``(a / r, b / r)`` with ``r = sqrt(|a|^2 + |b|^2)``.

    Where ``b`` is zero the pair is ``(1, 0)`` instead, i.e. the identity rotation.
    """
    zero = torch.zeros([], dtype=a.dtype, device=a.device)
    one = torch.ones([], dtype=a.dtype, device=a.device)
    nothing_to_rotate = b == zero
    radius = torch.sqrt(a.conj() * a + b.conj() * b)
    cosine = torch.where(nothing_to_rotate, one, a / radius)
    sine = torch.where(nothing_to_rotate, zero, b / radius)
    return cosine, sine


@torch.jit.script
def _givens_rotate(Q: torch.Tensor, R: torch.Tensor, i: int, col: int) -> None:
    """
    Rotate row ``i`` into row ``i - 1`` so that ``R[i, col]`` becomes zero.

    The same rotation is applied to ``Q``; both tensors are updated in place.
    """
    j = i - 1
    c, s = _givens_parameter(R[j, col], R[i, col])
    # Both right-hand sides are evaluated before either row is overwritten.
    Q[i], Q[j] = -s * Q[j] + c * Q[i], c.conj() * Q[j] + s.conj() * Q[i]
    R[i], R[j] = -s * R[j] + c * R[i], c.conj() * R[j] + s.conj() * R[i]


@torch.jit.script
def _givens_qr(A: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the reduced QR decomposition of a matrix using Givens rotations.

    Parameters
    ----------
    A : torch.Tensor
        The matrix to decompose; its shape is unpacked as ``(m, n)``.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        ``Q`` with ``min(m, n)`` columns and ``R`` with ``min(m, n)`` rows, where the diagonal of ``R`` is
        normalized to be non-negative.
    """
    m, n = A.shape
    k = min(m, n)
    # Q accumulates the applied rotations row-wise; it is sliced and conjugate-transposed at the end.
    Q = torch.eye(m, dtype=A.dtype, device=A.device)
    R = A.clone()

    # Parallel strategy
    # Every row is rotated into the nearest row above. Within one value of g the rotations act on disjoint
    # pairs of rows. A plain column-by-column elimination would instead rotate every row i > j into row j.
    for g in range(m - 1, 0, -1):
        # eliminate R[g, 0], R[g+2, 1], R[g+4, 2], ...
        for i, col in zip(range(g, m, 2), range(n)):
            _givens_rotate(Q, R, i, col)
    for g in range(1, k):
        # eliminate R[g+1, g], R[g+1+2, g+1], R[g+1+4, g+2], ...
        for i, col in zip(range(g + 1, m, 2), range(g, n)):
            _givens_rotate(Q, R, i, col)

    # Make diagonal positive
    c = _normalize_diagonal(R.diagonal()).conj()
    Q[:k] *= torch.unsqueeze(c, 1)
    R[:k] *= torch.unsqueeze(c, 1)

    Q, R = Q[:k].H, R[:k]
    return Q, R


class GivensQR(CommonQR):
    """
    Autograd function computing the reduced QR decomposition by Givens rotation.
    """

    # pylint: disable=abstract-method

    @staticmethod
    def forward(  # type: ignore[override]
        ctx: torch.autograd.function.FunctionCtx,
        A: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Decompose ``A`` and keep both factors for the backward pass.
        """
        # pylint: disable=arguments-differ
        factors = _givens_qr(A)
        ctx.save_for_backward(factors[0], factors[1])
        return factors[0], factors[1]


@torch.jit.script
def _normalize_delta(a: torch.Tensor) -> torch.Tensor:
    """
    Scale ``a`` to unit norm, returning zeros when the norm of ``a`` is zero.
    """
    zero = torch.zeros([], dtype=a.dtype, device=a.device)
    length = a.norm()
    # The division yields NaN for a zero vector, but the result is then replaced by zeros.
    return torch.where(length == zero, zero, a / length)


@torch.jit.script
def _reflect_target(x: torch.Tensor) -> torch.Tensor:
    """
    Return the norm of ``x`` multiplied by the phase of its first element.
    """
    phase = _normalize_diagonal(x[0])
    return torch.norm(x) * phase


@torch.jit.script
def _householder_qr(A: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the reduced QR decomposition of a matrix using Householder reflections.

    Parameters
    ----------
    A : torch.Tensor
        The matrix to decompose; its shape is unpacked as ``(m, n)``.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        ``Q`` with ``min(m, n)`` columns and ``R`` with ``min(m, n)`` rows, where the diagonal of ``R`` is
        normalized to be non-negative.
    """
    m, n = A.shape
    k = min(m, n)
    # Q accumulates the applied reflections row-wise; it is sliced and conjugate-transposed at the end.
    Q = torch.eye(m, dtype=A.dtype, device=A.device)
    R = A.clone()

    for i in range(k):
        x = R[i:, i]
        v = torch.zeros_like(x)
        # For a complex matrix a reflection from x to v requires <v|x> = <x|v>, so v[0] must share the phase
        # of x[0] up to a sign. Take the opposite sign: v - x then adds magnitudes instead of cancelling them,
        # which keeps the reflection accurate when x is already almost aligned with its first component.
        # The sign is undone by the diagonal normalization below.
        v[0] = -_reflect_target(x)
        # Reflect x to v
        delta = _normalize_delta(v - x)
        # H = 1 - 2 |Delta><Delta|
        R[i:] -= 2 * torch.outer(delta, delta.conj() @ R[i:])
        Q[i:] -= 2 * torch.outer(delta, delta.conj() @ Q[i:])

    # Make diagonal positive
    c = _normalize_diagonal(R.diagonal()).conj()
    Q[:k] *= torch.unsqueeze(c, 1)
    R[:k] *= torch.unsqueeze(c, 1)

    Q, R = Q[:k].H, R[:k]
    return Q, R


class HouseholderQR(CommonQR):
    """
    Autograd function computing the reduced QR decomposition by Householder reflection.
    """

    # pylint: disable=abstract-method

    @staticmethod
    def forward(  # type: ignore[override]
        ctx: torch.autograd.function.FunctionCtx,
        A: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Decompose ``A`` and keep both factors for the backward pass.
        """
        # pylint: disable=arguments-differ
        factors = _householder_qr(A)
        ctx.save_for_backward(factors[0], factors[1])
        return factors[0], factors[1]


# Function-style entry points bound to the ``apply`` method of each autograd function.
givens_qr = GivensQR.apply
householder_qr = HouseholderQR.apply
17 changes: 13 additions & 4 deletions tat/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from multimethod import multimethod
import torch
from . import _utility
from ._qr import givens_qr
from .edge import Edge

# pylint: disable=too-many-public-methods
Expand Down Expand Up @@ -1739,7 +1740,12 @@ def qr(
names=("QR_Q", "QR_R"),
)

data_q, data_r = torch.linalg.qr(tensor.data, mode="reduced")
if self.fermion:
# Blocked tensor, use Givens rotation
data_q, data_r = givens_qr(tensor.data)
else:
# Non-blocked tensor, use Householder reflection
data_q, data_r = torch.linalg.qr(tensor.data, mode="reduced")

free_edge_q = tensor.edges[0]
common_edge_q = Tensor._guess_edge(torch.abs(data_q), free_edge_q, True)
Expand All @@ -1750,17 +1756,20 @@ def qr(
dtypes=self.dtypes,
data=data_q,
)
tensor_q._ensure_mask() # pylint: disable=protected-access
# tensor_q._ensure_mask()
free_edge_r = tensor.edges[1]
common_edge_r = Tensor._guess_edge(torch.abs(data_r).transpose(0, 1), free_edge_r, False)
# common_edge_r = Tensor._guess_edge(torch.abs(data_r).transpose(0, 1), free_edge_r, False)
# Sometimes the R matrix may be singular, in which case the guessed edge can have an arbitrary symmetry, so do not use the guessed edge.
# On the other hand, QR based on Givens rotation always gives a blocked result, which can be trusted.
common_edge_r = common_edge_q.conjugate()
tensor_r = Tensor(
names=(common_name_r, "QR_R"),
edges=(common_edge_r, free_edge_r),
fermion=self.fermion,
dtypes=self.dtypes,
data=data_r,
)
tensor_r._ensure_mask() # pylint: disable=protected-access
# tensor_r._ensure_mask()
assert common_edge_q.conjugate() == common_edge_r

tensor_q = tensor_q.split_edge({"QR_Q": ordered_free_edges_q}, False, set())
Expand Down
59 changes: 59 additions & 0 deletions tests/test_qr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"Test QR"

import torch
from tat._qr import givens_qr, householder_qr

# pylint: disable=missing-function-docstring
# pylint: disable=invalid-name


def check_givens(A: torch.Tensor) -> None:
    """Check reconstruction, orthonormality and gradients of the Givens QR of ``A``."""
    rows, cols = A.size()
    Q, R = givens_qr(A)
    identity = torch.eye(min(rows, cols), dtype=A.dtype, device=A.device)
    assert torch.allclose(A, Q @ R)
    assert torch.allclose(Q.H @ Q, identity)
    assert torch.autograd.gradcheck(givens_qr, A, eps=1e-8, atol=1e-4)


def test_qr_real_givens() -> None:
    """Run the Givens checks on deep, square and wide real matrices."""
    for shape in ((7, 5), (5, 5), (5, 7)):
        check_givens(torch.randn(shape, dtype=torch.float64, requires_grad=True))


def test_qr_complex_givens() -> None:
    """Run the Givens checks on deep, square and wide complex matrices."""
    for shape in ((7, 5), (5, 5), (5, 7)):
        check_givens(torch.randn(shape, dtype=torch.complex128, requires_grad=True))


def check_householder(A: torch.Tensor) -> None:
    """Check reconstruction, orthonormality and gradients of the Householder QR of ``A``."""
    rows, cols = A.size()
    Q, R = householder_qr(A)
    identity = torch.eye(min(rows, cols), dtype=A.dtype, device=A.device)
    assert torch.allclose(A, Q @ R)
    assert torch.allclose(Q.H @ Q, identity)
    assert torch.autograd.gradcheck(householder_qr, A, eps=1e-8, atol=1e-4)


def test_qr_real_householder() -> None:
    """Run the Householder checks on deep, square and wide real matrices."""
    for shape in ((7, 5), (5, 5), (5, 7)):
        check_householder(torch.randn(shape, dtype=torch.float64, requires_grad=True))


def test_qr_complex_householder() -> None:
    """Run the Householder checks on deep, square and wide complex matrices."""
    for shape in ((7, 5), (5, 5), (5, 7)):
        check_householder(torch.randn(shape, dtype=torch.complex128, requires_grad=True))

0 comments on commit 7bf3bde

Please sign in to comment.