Skip to content

Commit

Permalink
add grade involution, clifford conjugation, general inverses
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinKa committed Oct 4, 2023
1 parent 88bc0cc commit c67202d
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 25 deletions.
8 changes: 8 additions & 0 deletions egga/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ def __xor__(self, other: Expression) -> Expression:
def __or__(self, other: Expression) -> Expression:
...

@staticmethod
def grade_involution(other: Expression) -> Expression:
...

@staticmethod
def clifford_conjugation(other: Expression) -> Expression:
...

@staticmethod
def inverse(other: Expression) -> Expression:
...
Expand Down
174 changes: 151 additions & 23 deletions egga/geometric_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
symplectic=False,
eq_solve=False,
costs: Optional[Dict[str, int]] = None,
full_inverse=False,
):
if costs is None:
costs = {}
Expand Down Expand Up @@ -122,6 +123,14 @@ def __xor__(self, other: MathExpr) -> MathExpr:
def __or__(self, other: MathExpr) -> MathExpr:
...

@egraph.function(cost=costs.get("grade_involution"))
def grade_involution(other: MathExpr) -> MathExpr:
...

@egraph.function(cost=costs.get("clifford_conjugation"))
def clifford_conjugation(other: MathExpr) -> MathExpr:
...

@egraph.function(cost=costs.get("inverse"))
def inverse(other: MathExpr) -> MathExpr:
...
Expand Down Expand Up @@ -206,6 +215,8 @@ def sandwich(r: MathExpr, x: MathExpr) -> MathExpr:
def diff(value: MathExpr, wrt: MathExpr) -> MathExpr:
...

MathExpr.grade_involution = grade_involution
MathExpr.clifford_conjugation = clifford_conjugation
MathExpr.inverse = inverse
MathExpr.boolean = boolean
MathExpr.scalar = scalar
Expand Down Expand Up @@ -359,31 +370,113 @@ def register_inner(medium=True):
)

def register_division(medium=True):
# TODO: Left / right inverse

egraph.register(
# Divide identity
rewrite(x_1 / scalar_literal(1.0)).to(x_1),
# Divide self
rewrite(x_1 / x_1).to(scalar_literal(1.0), x_1 != scalar_literal(0.0)),
# / is syntactic sugar for multiplication by inverse
birewrite(x_1 / x_2).to(x_1 * inverse(x_2)),
# Inverse of non-zero scalar
rewrite(inverse(scalar_literal(f_1))).to(
scalar_literal(f64(1.0) / f_1), _not_close(f_1, 0.0)
),
# Inverse of basis vector
rewrite(inverse(e(s_1))).to(e(s_1) * inverse(e(s_1) * e(s_1))),
)

if medium:
egraph.register(
# 1 / x is inverse of x
birewrite(scalar_literal(1.0) / x_1).to(inverse(x_1)),
# Division is right multiplication inverse
birewrite(x_1 / x_2).to(x_1 * inverse(x_2)),
# Inverse basis vector
# TODO: Figure out why this is broken x_1 * x_1 != scalar_literal(0.0)
rule(eq(x_1).to(inverse(x_2))).then(
x_2 * x_2 != scalar_literal(0.0)
),
birewrite(inverse(x_1)).to(
x_1 / (x_1 * x_1), x_1 * x_1 != scalar_literal(0.0)
),
)
if full_inverse:
dims = len(signature)

if dims > 5:
# # Shirokov inverse https://arxiv.org/abs/2005.04015 Theorem 4
n = 2 ** ((dims + 1) // 2)
u = x_1
for k in range(1, n):
c = scalar_literal(n / k) * select_grade(u, scalar_literal(0.0))
u_minus_c = u - c
u = x_1 * u_minus_c

# As soon as u is a scalar, we can calculate an inverse
# TODO: might be missing some scalar factors on early
# termination
egraph.register(
rule(eq(x_2).to(inverse(x_1))).then(u),
rewrite(inverse(x_1)).to(
u_minus_c * scalar_literal(f64(1.0) / f_1),
eq(u).to(scalar_literal(f_1)),
_not_close(f_1, 0.0),
),
)

"""
- Example
dims = 2
n = 2
u = e_1 + e_2
-- Iter k=1
c = 0
u_minus_c = e_1 + e_2
u = (e_1 + e_2) * (e_1 + e_2) = 2
-- Result
inv = (e_1 + e_2) / 2
"""
else:
# Closed form inverses (https://dx.doi.org/10.1016/j.amc.2017.05.027)
# More optimized forms from clifford (https://github.com/pygae/clifford).
# TODO: Might be able to pick lower dimension ones depending
# on which basis vectors are present?

x_1_conj = clifford_conjugation(x_1)
x_1_x_1_conj = x_1 * x_1_conj
x_1_conj_rev_x_1_x_1_conj = x_1_conj * ~x_1_x_1_conj
x_1_x_1_conj_rev_x_1_x_1_conj = x_1 * x_1_conj_rev_x_1_x_1_conj

if dims == 0:
numerator = scalar_literal(1.0)
elif dims == 1:
numerator = grade_involution(x_1)
elif dims == 2:
numerator = clifford_conjugation(x_1)
elif dims == 3:
numerator = x_1_conj * ~x_1_x_1_conj
elif dims == 4:
numerator = x_1_conj * (
x_1_x_1_conj
- scalar_literal(2.0)
* (
select_grade(x_1_x_1_conj, scalar_literal(3.0))
+ select_grade(x_1_x_1_conj, scalar_literal(4.0))
)
)
elif len(signature) == 5:
numerator = x_1_conj_rev_x_1_x_1_conj * (
x_1_x_1_conj_rev_x_1_x_1_conj
- scalar_literal(2.0)
* (
select_grade(
x_1_x_1_conj_rev_x_1_x_1_conj, scalar_literal(1.0)
)
+ select_grade(
x_1_x_1_conj_rev_x_1_x_1_conj, scalar_literal(4.0)
)
)
)
else:
raise NotImplementedError("Unreachable")

denominator = select_grade(x_1 * numerator, scalar_literal(0.0))

egraph.register(
rule(eq(x_2).to(inverse(x_1))).then(denominator),
rewrite(inverse(x_1)).to(
numerator * (scalar_literal(f64(1.0) / f_1)),
eq(denominator).to(
scalar_literal(f_1),
),
_not_close(f_1, 0.0),
),
)

# Multiplicative equation solving with inverses
if eq_solve:
egraph.register(
rule(eq(x_3).to(x_1 * x_2)).then(
Expand Down Expand Up @@ -451,6 +544,8 @@ def register_scalar(medium=True):
birewrite(scalar_variable(s_1)).to(scalar(variable(s_1))),
rewrite(scalar_variable(s_1)).to(variable(s_1)),
# -0 is 0 (apparently not true for f64)
birewrite(scalar_literal(-0.0)).to(scalar_literal(0.0)),
birewrite(-scalar_literal(0.0)).to(scalar_literal(0.0)),
union(scalar_literal(-0.0)).with_(scalar_literal(0.0)),
union(-scalar_literal(0.0)).with_(scalar_literal(0.0)),
# Scalar
Expand Down Expand Up @@ -544,7 +639,7 @@ def register_grade():
egraph.register(
rewrite(grade(basis_blade)).to(
scalar_literal(float(blade_grade)), *conds
)
),
)

# Select grade
Expand All @@ -570,7 +665,6 @@ def register_grade():

def register_basic_ga(medium=True):
basis_vectors = [e(str(i)) for i in range(len(signature))]

egraph.register(
# e_i^2 = signature[i]
*map(
Expand Down Expand Up @@ -670,6 +764,34 @@ def register_reverse(medium=True):
birewrite(~(x_1 + x_2)).to(~x_1 + ~x_2),
)

def register_grade_involution():
egraph.register(
birewrite(grade_involution(x_1 * x_2)).to(
grade_involution(x_1) * grade_involution(x_2)
),
birewrite(grade_involution(x_1 + x_2)).to(
grade_involution(x_1) + grade_involution(x_2)
),
rewrite(grade_involution(scalar_literal(f_1))).to(scalar_literal(f_1)),
rewrite(grade_involution(e(s_1))).to(-e(s_1)),
rewrite(grade_involution(grade_involution(x_1))).to(x_1),
)

def register_clifford_conjugation():
egraph.register(
birewrite(clifford_conjugation(x_1 * x_2)).to(
clifford_conjugation(x_2) * clifford_conjugation(x_1)
),
birewrite(clifford_conjugation(x_1 + x_2)).to(
clifford_conjugation(x_1) + clifford_conjugation(x_2)
),
rewrite(clifford_conjugation(scalar_literal(f_1))).to(
scalar_literal(f_1)
),
rewrite(clifford_conjugation(e(s_1))).to(-e(s_1)),
rewrite(clifford_conjugation(clifford_conjugation(x_1))).to(x_1),
)

def register_equality():
egraph.register(
# Comm
Expand Down Expand Up @@ -794,6 +916,8 @@ def register_diff(medium=True):
register_wedge()
register_inner()
register_reverse()
register_grade_involution()
register_clifford_conjugation()
register_diff()

set_active_ruleset(medium_ruleset)
Expand All @@ -815,6 +939,8 @@ def register_diff(medium=True):
register_wedge()
register_inner()
register_reverse()
register_grade_involution()
register_clifford_conjugation()
register_diff()

set_active_ruleset(fast_ruleset)
Expand All @@ -836,6 +962,8 @@ def register_diff(medium=True):
register_wedge(medium=False)
register_inner(medium=False)
register_reverse(medium=False)
register_grade_involution()
register_clifford_conjugation()
register_diff(medium=False)

self.egraph = egraph
Expand Down
4 changes: 3 additions & 1 deletion examples/basic_equation_solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@
# Make LHS equal to RHS
ga.egraph.register(union(lhs).with_(rhs))

assert str(simplify(ga, x)) == str(ga.expr_cls.e("0"))
solved = simplify(ga, x)
print("X:", solved)
assert str(solved) == str(ga.expr_cls.e("0"))
2 changes: 1 addition & 1 deletion tests/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from itertools import combinations

ga = GeometricAlgebra([0.0, 1.0, 1.0, 1.0, -1.0])
ga = GeometricAlgebra([0.0, 1.0, 1.0, 1.0, -1.0], eq_solve=False, full_inverse=True)
E = ga.expr_cls

e_0, e_1, e_2, e_3, e_4 = ga.basis_vectors
Expand Down

0 comments on commit c67202d

Please sign in to comment.