Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add grade involution, clifford conjugation, general inverses #2

Merged
merged 4 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions egga/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ def __xor__(self, other: Expression) -> Expression:
def __or__(self, other: Expression) -> Expression:
...

@staticmethod
def grade_involution(other: Expression) -> Expression:
...

@staticmethod
def clifford_conjugation(other: Expression) -> Expression:
...

@staticmethod
def inverse(other: Expression) -> Expression:
...
Expand Down
215 changes: 173 additions & 42 deletions egga/geometric_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from dataclasses import dataclass
from functools import partial
from itertools import combinations
from typing import Dict, Optional, Tuple, Type
from typing import Dict, List, Optional, Tuple, Type

from egglog import (
EGraph,
Expr,
Fact,
String,
StringLike,
egraph,
Expand All @@ -18,7 +19,6 @@
i64Like,
rewrite,
rule,
set_,
union,
var,
vars_,
Expand Down Expand Up @@ -49,6 +49,71 @@ class GeometricAlgebraRulesets:
fast: egglog_egraph.Ruleset


def _maybe_inverse(
m: type[Expression], x: Expression, dims: int
) -> Tuple[List[Expression], List[Fact], Expression]:
clifford_conjugation = m.clifford_conjugation
grade_involution = m.grade_involution
scalar_literal = m.scalar_literal
select_grade = m.select_grade

x_conj = clifford_conjugation(x)
x_x_conj = x * x_conj
x_conj_rev_x_x_conj = x_conj * ~x_x_conj
x_x_conj_rev_x_x_conj = x * x_conj_rev_x_x_conj

if dims == 0:
numerator = scalar_literal(1.0)
elif dims == 1:
numerator = grade_involution(x)
elif dims == 2:
numerator = x_conj
elif dims == 3:
numerator = x_conj * ~x_x_conj
elif dims == 4:
numerator = x_conj * (
x_x_conj
- scalar_literal(2.0)
# Invert sign of grade 3 and 4 parts
# TODO: is there a more efficient way?
* (
select_grade(x_x_conj, scalar_literal(3.0))
+ select_grade(x_x_conj, scalar_literal(4.0))
)
)
elif dims == 5:
numerator = x_conj_rev_x_x_conj * (
x_x_conj_rev_x_x_conj
- scalar_literal(2.0)
# Invert sign of grade 1 and 4 parts
# TODO: is there a more efficient way?
* (
select_grade(
x_x_conj_rev_x_x_conj,
scalar_literal(1.0),
)
+ select_grade(
x_x_conj_rev_x_x_conj,
scalar_literal(4.0),
)
)
)
else:
raise NotImplementedError("Unreachable")

denominator = x * numerator
dependencies = [denominator]

scalar_divisor = var("scalar_divisor", f64)
has_inverse = [
eq(denominator).to(scalar_literal(scalar_divisor)),
_not_close(scalar_divisor, 0.0),
]
inverse_x = numerator * (scalar_literal(f64(1.0) / scalar_divisor))

return dependencies, has_inverse, inverse_x


@dataclass
class GeometricAlgebra:
egraph: EGraph
Expand All @@ -70,6 +135,7 @@ def __init__(
symplectic=False,
eq_solve=False,
costs: Optional[Dict[str, int]] = None,
full_inverse=False,
):
if costs is None:
costs = {}
Expand Down Expand Up @@ -122,6 +188,14 @@ def __xor__(self, other: MathExpr) -> MathExpr:
def __or__(self, other: MathExpr) -> MathExpr:
...

@egraph.function(cost=costs.get("grade_involution"))
def grade_involution(other: MathExpr) -> MathExpr:
...

@egraph.function(cost=costs.get("clifford_conjugation"))
def clifford_conjugation(other: MathExpr) -> MathExpr:
...

@egraph.function(cost=costs.get("inverse"))
def inverse(other: MathExpr) -> MathExpr:
...
Expand Down Expand Up @@ -206,6 +280,8 @@ def sandwich(r: MathExpr, x: MathExpr) -> MathExpr:
def diff(value: MathExpr, wrt: MathExpr) -> MathExpr:
...

MathExpr.grade_involution = grade_involution
MathExpr.clifford_conjugation = clifford_conjugation
MathExpr.inverse = inverse
MathExpr.boolean = boolean
MathExpr.scalar = scalar
Expand Down Expand Up @@ -251,13 +327,10 @@ def register_addition(full=True, medium=True):
egraph.register(
# Comm
rewrite(x_1 + x_2).to(x_2 + x_1),
# Assoc
# rewrite(x_1 + (x_2 + x_3)).to((x_1 + x_2) + x_3),
)
if full:
egraph.register(
# Assoc
# rewrite((x_1 + x_2) + x_3).to(x_1 + (x_2 + x_3)),
birewrite(x_1 + (x_2 + x_3)).to((x_1 + x_2) + x_3),
)

Expand Down Expand Up @@ -359,46 +432,55 @@ def register_inner(medium=True):
)

def register_division(medium=True):
# TODO: Left / right inverse

egraph.register(
# Divide identity
rewrite(x_1 / scalar_literal(1.0)).to(x_1),
# Divide self
rewrite(x_1 / x_1).to(scalar_literal(1.0), x_1 != scalar_literal(0.0)),
# / is syntactic sugar for multiplication by inverse
rewrite(x_1 / x_2).to(x_1 * inverse(x_2)),
# Inverse of non-zero scalar
rewrite(inverse(scalar_literal(f_1))).to(
scalar_literal(f64(1.0) / f_1), _not_close(f_1, 0.0)
),
# Inverse of basis vector
rewrite(inverse(e(s_1))).to(e(s_1) * inverse(e(s_1) * e(s_1))),
# Inverse of product is product of inverses in reverse order
rewrite(inverse(x_1 * x_2)).to(inverse(x_2) * inverse(x_1)),
)

if medium:
egraph.register(
# 1 / x is inverse of x
birewrite(scalar_literal(1.0) / x_1).to(inverse(x_1)),
# Division is right multiplication inverse
birewrite(x_1 / x_2).to(x_1 * inverse(x_2)),
# Inverse basis vector
# TODO: Figure out why this is broken x_1 * x_1 != scalar_literal(0.0)
rule(eq(x_1).to(inverse(x_2))).then(
x_2 * x_2 != scalar_literal(0.0)
),
birewrite(inverse(x_1)).to(
x_1 / (x_1 * x_1), x_1 * x_1 != scalar_literal(0.0)
),
)
if full_inverse:
for dims in range(len(signature) + 1):
deps, x_1_has_inverse, x_1_inverse = _maybe_inverse(
m=MathExpr, x=x_1, dims=dims
)

egraph.register(
rule(eq(x_2).to(inverse(x_1))).then(*deps),
rewrite(inverse(x_1)).to(x_1_inverse, *x_1_has_inverse),
)

# Multiplicative equation solving with inverses
if eq_solve:
egraph.register(
rule(eq(x_3).to(x_1 * x_2)).then(
x_1 * x_1 != scalar_literal(0.0),
x_2 * x_2 != scalar_literal(0.0),
),
# Left inverse: x_3 = x_1 * x_2 -> inv(x_1) * x_3 = x_2
rule(eq(x_3).to(x_1 * x_2), x_1 * x_1 != scalar_literal(0.0)).then(
union(x_2).with_(inverse(x_1) * x_3)
),
# Right inverse: x_3 = x_1 * x_2 -> x_3 * inv(x_2) = x_1
rule(eq(x_3).to(x_1 * x_2), x_2 * x_2 != scalar_literal(0.0)).then(
union(x_1).with_(x_3 * inverse(x_2))
),
)
for dims in range(len(signature) + 1):
deps_1, x_1_has_inverse, x_1_inverse = _maybe_inverse(
m=MathExpr, x=x_1, dims=dims
)
deps_2, x_2_has_inverse, x_2_inverse = _maybe_inverse(
m=MathExpr, x=x_2, dims=dims
)

egraph.register(
# x_3 = x_1 * x_2: Figure out if x_1 and x_2 have inverses
rule(eq(x_3).to(x_1 * x_2)).then(
*deps_1,
*deps_2,
),
# Left inverse: inv(x_1) * x_3 = x_2
rule(eq(x_3).to(x_1 * x_2), *x_1_has_inverse).then(
union(x_2).with_(x_1_inverse * x_3),
),
# Right inverse: x_3 * inv(x_2) = x_1
rule(eq(x_3).to(x_1 * x_2), *x_2_has_inverse).then(
union(x_1).with_(x_3 * x_2_inverse),
),
)

def register_pow(medium=True):
egraph.register(
Expand Down Expand Up @@ -451,6 +533,8 @@ def register_scalar(medium=True):
birewrite(scalar_variable(s_1)).to(scalar(variable(s_1))),
rewrite(scalar_variable(s_1)).to(variable(s_1)),
# -0 is 0 (apparently not true for f64)
birewrite(scalar_literal(-0.0)).to(scalar_literal(0.0)),
birewrite(-scalar_literal(0.0)).to(scalar_literal(0.0)),
union(scalar_literal(-0.0)).with_(scalar_literal(0.0)),
union(-scalar_literal(0.0)).with_(scalar_literal(0.0)),
# Scalar
Expand Down Expand Up @@ -544,7 +628,7 @@ def register_grade():
egraph.register(
rewrite(grade(basis_blade)).to(
scalar_literal(float(blade_grade)), *conds
)
),
)

# Select grade
Expand All @@ -570,7 +654,6 @@ def register_grade():

def register_basic_ga(medium=True):
basis_vectors = [e(str(i)) for i in range(len(signature))]

egraph.register(
# e_i^2 = signature[i]
*map(
Expand All @@ -587,6 +670,20 @@ def register_basic_ga(medium=True):
),
# Sandwich
rewrite(sandwich(x_1, x_2)).to(x_1 * x_2 * ~x_1),
# # (e_i + e_j)^2 = e_i^2 + e_j^2
birewrite((e(s_1) + e(s_2)) ** scalar_literal(2.0)).to(
e(s_1) ** scalar_literal(2.0) + e(s_2) ** scalar_literal(2.0),
s_1 != s_2,
),
# (a + b)^2 = a^2 + ab + ba + b^2
# birewrite((x_1 + x_2) * (x_1 + x_2)).to(
# x_1 * x_1 + x_1 * x_2 + x_2 * x_1 + x_2 * x_2
# ),
# rule((x_1 + x_2) * (x_1 + x_2)).then(x_1 * x_2, x_2 * x_1),
# birewrite((x_1 + x_2) * (x_1 + x_2)).to(
# x_1 * x_1 + x_2 * x_2,
# eq(x_1 * x_2).to(-x_2 * x_1),
# ),
)

if medium:
Expand Down Expand Up @@ -670,6 +767,34 @@ def register_reverse(medium=True):
birewrite(~(x_1 + x_2)).to(~x_1 + ~x_2),
)

def register_grade_involution():
egraph.register(
birewrite(grade_involution(x_1 * x_2)).to(
grade_involution(x_1) * grade_involution(x_2)
),
birewrite(grade_involution(x_1 + x_2)).to(
grade_involution(x_1) + grade_involution(x_2)
),
rewrite(grade_involution(scalar_literal(f_1))).to(scalar_literal(f_1)),
rewrite(grade_involution(e(s_1))).to(-e(s_1)),
rewrite(grade_involution(grade_involution(x_1))).to(x_1),
)

def register_clifford_conjugation():
egraph.register(
birewrite(clifford_conjugation(x_1 * x_2)).to(
clifford_conjugation(x_2) * clifford_conjugation(x_1)
),
birewrite(clifford_conjugation(x_1 + x_2)).to(
clifford_conjugation(x_1) + clifford_conjugation(x_2)
),
rewrite(clifford_conjugation(scalar_literal(f_1))).to(
scalar_literal(f_1)
),
rewrite(clifford_conjugation(e(s_1))).to(-e(s_1)),
rewrite(clifford_conjugation(clifford_conjugation(x_1))).to(x_1),
)

def register_equality():
egraph.register(
# Comm
Expand Down Expand Up @@ -794,6 +919,8 @@ def register_diff(medium=True):
register_wedge()
register_inner()
register_reverse()
register_grade_involution()
register_clifford_conjugation()
register_diff()

set_active_ruleset(medium_ruleset)
Expand All @@ -815,6 +942,8 @@ def register_diff(medium=True):
register_wedge()
register_inner()
register_reverse()
register_grade_involution()
register_clifford_conjugation()
register_diff()

set_active_ruleset(fast_ruleset)
Expand All @@ -836,6 +965,8 @@ def register_diff(medium=True):
register_wedge(medium=False)
register_inner(medium=False)
register_reverse(medium=False)
register_grade_involution()
register_clifford_conjugation()
register_diff(medium=False)

self.egraph = egraph
Expand Down
4 changes: 3 additions & 1 deletion examples/basic_equation_solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@
# Make LHS equal to RHS
ga.egraph.register(union(lhs).with_(rhs))

assert str(simplify(ga, x)) == str(ga.expr_cls.e("0"))
solved = simplify(ga, x)
print("X:", solved)
assert str(solved) == str(ga.expr_cls.e("0"))
2 changes: 1 addition & 1 deletion tests/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from itertools import combinations

ga = GeometricAlgebra([0.0, 1.0, 1.0, 1.0, -1.0])
ga = GeometricAlgebra([0.0, 1.0, 1.0, 1.0, -1.0], eq_solve=False, full_inverse=True)
E = ga.expr_cls

e_0, e_1, e_2, e_3, e_4 = ga.basis_vectors
Expand Down
Loading