From c67202da276724b18983d5f1451bb2a8f3eab87b Mon Sep 17 00:00:00 2001 From: Robin Kahlow Date: Wed, 4 Oct 2023 22:28:18 +0100 Subject: [PATCH 1/4] add grade involution, clifford conjugation, general inverses --- egga/expression.py | 8 ++ egga/geometric_algebra.py | 174 +++++++++++++++++++++++++---- examples/basic_equation_solving.py | 4 +- tests/test_rules.py | 2 +- 4 files changed, 163 insertions(+), 25 deletions(-) diff --git a/egga/expression.py b/egga/expression.py index 8317fda..8b8b277 100644 --- a/egga/expression.py +++ b/egga/expression.py @@ -45,6 +45,14 @@ def __xor__(self, other: Expression) -> Expression: def __or__(self, other: Expression) -> Expression: ... + @staticmethod + def grade_involution(other: Expression) -> Expression: + ... + + @staticmethod + def clifford_conjugation(other: Expression) -> Expression: + ... + @staticmethod def inverse(other: Expression) -> Expression: ... diff --git a/egga/geometric_algebra.py b/egga/geometric_algebra.py index 64e5470..3c136e0 100644 --- a/egga/geometric_algebra.py +++ b/egga/geometric_algebra.py @@ -70,6 +70,7 @@ def __init__( symplectic=False, eq_solve=False, costs: Optional[Dict[str, int]] = None, + full_inverse=False, ): if costs is None: costs = {} @@ -122,6 +123,14 @@ def __xor__(self, other: MathExpr) -> MathExpr: def __or__(self, other: MathExpr) -> MathExpr: ... + @egraph.function(cost=costs.get("grade_involution")) + def grade_involution(other: MathExpr) -> MathExpr: + ... + + @egraph.function(cost=costs.get("clifford_conjugation")) + def clifford_conjugation(other: MathExpr) -> MathExpr: + ... + @egraph.function(cost=costs.get("inverse")) def inverse(other: MathExpr) -> MathExpr: ... @@ -206,6 +215,8 @@ def sandwich(r: MathExpr, x: MathExpr) -> MathExpr: def diff(value: MathExpr, wrt: MathExpr) -> MathExpr: ... + MathExpr.grade_involution = grade_involution + MathExpr.clifford_conjugation = clifford_conjugation MathExpr.inverse = inverse MathExpr.boolean = boolean MathExpr.scalar = scalar @@ -359,31 +370,113 @@ def register_inner(medium=True): ) def register_division(medium=True): - # TODO: Left / right inverse - egraph.register( - # Divide identity - rewrite(x_1 / scalar_literal(1.0)).to(x_1), - # Divide self - rewrite(x_1 / x_1).to(scalar_literal(1.0), x_1 != scalar_literal(0.0)), + # / is syntactic sugar for multiplication by inverse + birewrite(x_1 / x_2).to(x_1 * inverse(x_2)), + # Inverse of non-zero scalar + rewrite(inverse(scalar_literal(f_1))).to( + scalar_literal(f64(1.0) / f_1), _not_close(f_1, 0.0) + ), + # Inverse of basis vector + rewrite(inverse(e(s_1))).to(e(s_1) * inverse(e(s_1) * e(s_1))), ) - if medium: - egraph.register( - # 1 / x is inverse of x - birewrite(scalar_literal(1.0) / x_1).to(inverse(x_1)), - # Division is right multiplication inverse - birewrite(x_1 / x_2).to(x_1 * inverse(x_2)), - # Inverse basis vector - # TODO: Figure out why this is broken x_1 * x_1 != scalar_literal(0.0) - rule(eq(x_1).to(inverse(x_2))).then( - x_2 * x_2 != scalar_literal(0.0) - ), - birewrite(inverse(x_1)).to( - x_1 / (x_1 * x_1), x_1 * x_1 != scalar_literal(0.0) - ), - ) + if full_inverse: + dims = len(signature) + + if dims > 5: + # # Shirokov inverse https://arxiv.org/abs/2005.04015 Theorem 4 + n = 2 ** ((dims + 1) // 2) + u = x_1 + for k in range(1, n): + c = scalar_literal(n / k) * select_grade(u, scalar_literal(0.0)) + u_minus_c = u - c + u = x_1 * u_minus_c + + # As soon as u is a scalar, we can calculate an inverse + # TODO: might be missing some scalar factors on early + # termination + egraph.register( + rule(eq(x_2).to(inverse(x_1))).then(u), + rewrite(inverse(x_1)).to( + u_minus_c * scalar_literal(f64(1.0) / f_1), + eq(u).to(scalar_literal(f_1)), + _not_close(f_1, 0.0), + ), + ) + + """ + - Example + dims = 2 + n = 2 + u = e_1 + e_2 + + -- Iter k=1 + c = 0 + u_minus_c = e_1 + e_2 + u = (e_1 + e_2) * (e_1 + e_2) = 2 + + -- Result + inv = (e_1 + e_2) / 2 + """ + else: + # Closed form inverses (https://dx.doi.org/10.1016/j.amc.2017.05.027) + # More optimized forms from clifford (https://github.com/pygae/clifford). + # TODO: Might be able to pick lower dimension ones depending + # on which basis vectors are present? + + x_1_conj = clifford_conjugation(x_1) + x_1_x_1_conj = x_1 * x_1_conj + x_1_conj_rev_x_1_x_1_conj = x_1_conj * ~x_1_x_1_conj + x_1_x_1_conj_rev_x_1_x_1_conj = x_1 * x_1_conj_rev_x_1_x_1_conj + + if dims == 0: + numerator = scalar_literal(1.0) + elif dims == 1: + numerator = grade_involution(x_1) + elif dims == 2: + numerator = clifford_conjugation(x_1) + elif dims == 3: + numerator = x_1_conj * ~x_1_x_1_conj + elif dims == 4: + numerator = x_1_conj * ( + x_1_x_1_conj + - scalar_literal(2.0) + * ( + select_grade(x_1_x_1_conj, scalar_literal(3.0)) + + select_grade(x_1_x_1_conj, scalar_literal(4.0)) + ) + ) + elif len(signature) == 5: + numerator = x_1_conj_rev_x_1_x_1_conj * ( + x_1_x_1_conj_rev_x_1_x_1_conj + - scalar_literal(2.0) + * ( + select_grade( + x_1_x_1_conj_rev_x_1_x_1_conj, scalar_literal(1.0) + ) + + select_grade( + x_1_x_1_conj_rev_x_1_x_1_conj, scalar_literal(4.0) + ) + ) + ) + else: + raise NotImplementedError("Unreachable") + denominator = select_grade(x_1 * numerator, scalar_literal(0.0)) + + egraph.register( + rule(eq(x_2).to(inverse(x_1))).then(denominator), + rewrite(inverse(x_1)).to( + numerator * (scalar_literal(f64(1.0) / f_1)), + eq(denominator).to( + scalar_literal(f_1), + ), + _not_close(f_1, 0.0), + ), + ) + + # Multiplicative equation solving with inverses if eq_solve: egraph.register( rule(eq(x_3).to(x_1 * x_2)).then( @@ -451,6 +544,8 @@ def register_scalar(medium=True): birewrite(scalar_variable(s_1)).to(scalar(variable(s_1))), rewrite(scalar_variable(s_1)).to(variable(s_1)), # -0 is 0 (apparently not true for f64) + birewrite(scalar_literal(-0.0)).to(scalar_literal(0.0)), + birewrite(-scalar_literal(0.0)).to(scalar_literal(0.0)), union(scalar_literal(-0.0)).with_(scalar_literal(0.0)), union(-scalar_literal(0.0)).with_(scalar_literal(0.0)), # Scalar @@ -544,7 +639,7 @@ def register_grade(): egraph.register( rewrite(grade(basis_blade)).to( scalar_literal(float(blade_grade)), *conds - ) + ), ) # Select grade @@ -570,7 +665,6 @@ def register_grade(): def register_basic_ga(medium=True): basis_vectors = [e(str(i)) for i in range(len(signature))] - egraph.register( # e_i^2 = signature[i] *map( @@ -670,6 +764,34 @@ def register_reverse(medium=True): birewrite(~(x_1 + x_2)).to(~x_1 + ~x_2), ) + def register_grade_involution(): + egraph.register( + birewrite(grade_involution(x_1 * x_2)).to( + grade_involution(x_1) * grade_involution(x_2) + ), + birewrite(grade_involution(x_1 + x_2)).to( + grade_involution(x_1) + grade_involution(x_2) + ), + rewrite(grade_involution(scalar_literal(f_1))).to(scalar_literal(f_1)), + rewrite(grade_involution(e(s_1))).to(-e(s_1)), + rewrite(grade_involution(grade_involution(x_1))).to(x_1), + ) + + def register_clifford_conjugation(): + egraph.register( + birewrite(clifford_conjugation(x_1 * x_2)).to( + clifford_conjugation(x_2) * clifford_conjugation(x_1) + ), + birewrite(clifford_conjugation(x_1 + x_2)).to( + clifford_conjugation(x_1) + clifford_conjugation(x_2) + ), + rewrite(clifford_conjugation(scalar_literal(f_1))).to( + scalar_literal(f_1) + ), + rewrite(clifford_conjugation(e(s_1))).to(-e(s_1)), + rewrite(clifford_conjugation(clifford_conjugation(x_1))).to(x_1), + ) + def register_equality(): egraph.register( # Comm @@ -794,6 +916,8 @@ def register_diff(medium=True): register_wedge() register_inner() register_reverse() + register_grade_involution() + register_clifford_conjugation() register_diff() set_active_ruleset(medium_ruleset) @@ -815,6 +939,8 @@ def register_diff(medium=True): register_wedge() register_inner() register_reverse() + register_grade_involution() + register_clifford_conjugation() register_diff() set_active_ruleset(fast_ruleset) @@ -836,6 +962,8 @@ def register_diff(medium=True): register_wedge(medium=False) register_inner(medium=False) register_reverse(medium=False) + register_grade_involution() + register_clifford_conjugation() register_diff(medium=False) self.egraph = egraph diff --git a/examples/basic_equation_solving.py b/examples/basic_equation_solving.py index 5028317..d19e18f 100644 --- a/examples/basic_equation_solving.py +++ b/examples/basic_equation_solving.py @@ -18,4 +18,6 @@ # Make LHS equal to RHS ga.egraph.register(union(lhs).with_(rhs)) -assert str(simplify(ga, x)) == str(ga.expr_cls.e("0")) +solved = simplify(ga, x) +print("X:", solved) +assert str(solved) == str(ga.expr_cls.e("0")) diff --git a/tests/test_rules.py b/tests/test_rules.py index 2bfd6d2..39915a7 100644 --- a/tests/test_rules.py +++ b/tests/test_rules.py @@ -13,7 +13,7 @@ ) from itertools import combinations -ga = GeometricAlgebra([0.0, 1.0, 1.0, 1.0, -1.0]) +ga = GeometricAlgebra([0.0, 1.0, 1.0, 1.0, -1.0], eq_solve=False, full_inverse=True) E = ga.expr_cls e_0, e_1, e_2, e_3, e_4 = ga.basis_vectors From 1341bab1c4240fdce3a56740dd7c8c5e35daca4e Mon Sep 17 00:00:00 2001 From: Robin Kahlow Date: Wed, 4 Oct 2023 23:57:32 +0100 Subject: [PATCH 2/4] add all inverses --- egga/geometric_algebra.py | 111 +++++++++++++++++++++++--------------- 1 file changed, 67 insertions(+), 44 deletions(-) diff --git a/egga/geometric_algebra.py b/egga/geometric_algebra.py index 3c136e0..431a24b 100644 --- a/egga/geometric_algebra.py +++ b/egga/geometric_algebra.py @@ -384,47 +384,50 @@ def register_division(medium=True): if full_inverse: dims = len(signature) - if dims > 5: - # # Shirokov inverse https://arxiv.org/abs/2005.04015 Theorem 4 - n = 2 ** ((dims + 1) // 2) - u = x_1 - for k in range(1, n): - c = scalar_literal(n / k) * select_grade(u, scalar_literal(0.0)) - u_minus_c = u - c - u = x_1 * u_minus_c - - # As soon as u is a scalar, we can calculate an inverse - # TODO: might be missing some scalar factors on early - # termination - egraph.register( - rule(eq(x_2).to(inverse(x_1))).then(u), - rewrite(inverse(x_1)).to( - u_minus_c * scalar_literal(f64(1.0) / f_1), - eq(u).to(scalar_literal(f_1)), - _not_close(f_1, 0.0), - ), - ) - - """ - - Example - dims = 2 - n = 2 - u = e_1 + e_2 - - -- Iter k=1 - c = 0 - u_minus_c = e_1 + e_2 - u = (e_1 + e_2) * (e_1 + e_2) = 2 - - -- Result - inv = (e_1 + e_2) / 2 - """ - else: - # Closed form inverses (https://dx.doi.org/10.1016/j.amc.2017.05.027) - # More optimized forms from clifford (https://github.com/pygae/clifford). - # TODO: Might be able to pick lower dimension ones depending - # on which basis vectors are present? + # TODO: 5 dim case of Hitzer inverse gives contradictions? + # if dims > 5: + #if dims > 4: + + # # Shirokov inverse https://arxiv.org/abs/2005.04015 Theorem 4 + n = 2 ** ((dims + 1) // 2) + u = x_1 + for k in range(1, n): + c = scalar_literal(n / k) * select_grade(u, scalar_literal(0.0)) + u_minus_c = u - c + u = x_1 * u_minus_c + + # As soon as u is a scalar, we can calculate an inverse + # TODO: might be missing some scalar factors on early + # termination + egraph.register( + rule(eq(x_2).to(inverse(x_1))).then(u), + rewrite(inverse(x_1)).to( + u_minus_c * scalar_literal(f64(1.0) / f_1), + eq(u).to(scalar_literal(f_1)), + _not_close(f_1, 0.0), + ), + ) + """ + - Example + dims = 2 + n = 2 + u = e_1 + e_2 + + -- Iter k=1 + c = 0 + u_minus_c = e_1 + e_2 + u = (e_1 + e_2) * (e_1 + e_2) = 2 + + -- Result + inv = u_minus_c / u = (e_1 + e_2) / 2 + """ + #else: + # Closed form inverses (https://dx.doi.org/10.1016/j.amc.2017.05.027) + # More optimized forms from clifford (https://github.com/pygae/clifford). + # TODO: Might be able to pick lower dimension ones depending + # on which basis vectors are present? + for dims in range(len(signature) + 1): x_1_conj = clifford_conjugation(x_1) x_1_x_1_conj = x_1 * x_1_conj x_1_conj_rev_x_1_x_1_conj = x_1_conj * ~x_1_x_1_conj @@ -435,13 +438,15 @@ def register_division(medium=True): elif dims == 1: numerator = grade_involution(x_1) elif dims == 2: - numerator = clifford_conjugation(x_1) + numerator = x_1_conj elif dims == 3: numerator = x_1_conj * ~x_1_x_1_conj elif dims == 4: numerator = x_1_conj * ( x_1_x_1_conj - scalar_literal(2.0) + # Invert sign of grade 3 and 4 parts + # TODO: is there a more efficient way? * ( select_grade(x_1_x_1_conj, scalar_literal(3.0)) + select_grade(x_1_x_1_conj, scalar_literal(4.0)) @@ -451,19 +456,23 @@ def register_division(medium=True): numerator = x_1_conj_rev_x_1_x_1_conj * ( x_1_x_1_conj_rev_x_1_x_1_conj - scalar_literal(2.0) + # Invert sign of grade 1 and 4 parts + # TODO: is there a more efficient way? * ( select_grade( - x_1_x_1_conj_rev_x_1_x_1_conj, scalar_literal(1.0) + x_1_x_1_conj_rev_x_1_x_1_conj, + scalar_literal(1.0), ) + select_grade( - x_1_x_1_conj_rev_x_1_x_1_conj, scalar_literal(4.0) + x_1_x_1_conj_rev_x_1_x_1_conj, + scalar_literal(4.0), ) ) ) else: raise NotImplementedError("Unreachable") - denominator = select_grade(x_1 * numerator, scalar_literal(0.0)) + denominator = x_1 * numerator egraph.register( rule(eq(x_2).to(inverse(x_1))).then(denominator), @@ -681,6 +690,20 @@ def register_basic_ga(medium=True): ), # Sandwich rewrite(sandwich(x_1, x_2)).to(x_1 * x_2 * ~x_1), + # # (e_i + e_j)^2 = e_i^2 + e_j^2 + birewrite((e(s_1) + e(s_2)) ** scalar_literal(2.0)).to( + e(s_1) ** scalar_literal(2.0) + e(s_2) ** scalar_literal(2.0), + s_1 != s_2, + ), + # (a + b)^2 = a^2 + ab + ba + b^2 + # birewrite((x_1 + x_2) * (x_1 + x_2)).to( + # x_1 * x_1 + x_1 * x_2 + x_2 * x_1 + x_2 * x_2 + # ), + # rule((x_1 + x_2) * (x_1 + x_2)).then(x_1 * x_2, x_2 * x_1), + # birewrite((x_1 + x_2) * (x_1 + x_2)).to( + # x_1 * x_1 + x_2 * x_2, + # eq(x_1 * x_2).to(-x_2 * x_1), + # ), ) if medium: From 8de960c8741c670873e0461847a8af661b77d3da Mon Sep 17 00:00:00 2001 From: Robin Kahlow Date: Wed, 11 Oct 2023 20:26:19 +0100 Subject: [PATCH 3/4] only use inverse for equation solving if inverse exists --- egga/geometric_algebra.py | 214 +++++++++++++++++--------------------- 1 file changed, 96 insertions(+), 118 deletions(-) diff --git a/egga/geometric_algebra.py b/egga/geometric_algebra.py index 431a24b..6a64165 100644 --- a/egga/geometric_algebra.py +++ b/egga/geometric_algebra.py @@ -3,11 +3,12 @@ from dataclasses import dataclass from functools import partial from itertools import combinations -from typing import Dict, Optional, Tuple, Type +from typing import Dict, List, Optional, Tuple, Type from egglog import ( EGraph, Expr, + Fact, String, StringLike, egraph, @@ -18,7 +19,6 @@ i64Like, rewrite, rule, - set_, union, var, vars_, @@ -49,6 +49,71 @@ class GeometricAlgebraRulesets: fast: egglog_egraph.Ruleset +def _maybe_inverse( + m: type[Expression], x: Expression, dims: int +) -> Tuple[List[Expression], List[Fact], Expression]: + clifford_conjugation = m.clifford_conjugation + grade_involution = m.grade_involution + scalar_literal = m.scalar_literal + select_grade = m.select_grade + + x_conj = clifford_conjugation(x) + x_x_conj = x * x_conj + x_conj_rev_x_x_conj = x_conj * ~x_x_conj + x_x_conj_rev_x_x_conj = x * x_conj_rev_x_x_conj + + if dims == 0: + numerator = scalar_literal(1.0) + elif dims == 1: + numerator = grade_involution(x) + elif dims == 2: + numerator = x_conj + elif dims == 3: + numerator = x_conj * ~x_x_conj + elif dims == 4: + numerator = x_conj * ( + x_x_conj + - scalar_literal(2.0) + # Invert sign of grade 3 and 4 parts + # TODO: is there a more efficient way? + * ( + select_grade(x_x_conj, scalar_literal(3.0)) + + select_grade(x_x_conj, scalar_literal(4.0)) + ) + ) + elif dims == 5: + numerator = x_conj_rev_x_x_conj * ( + x_x_conj_rev_x_x_conj + - scalar_literal(2.0) + # Invert sign of grade 1 and 4 parts + # TODO: is there a more efficient way? + * ( + select_grade( + x_x_conj_rev_x_x_conj, + scalar_literal(1.0), + ) + + select_grade( + x_x_conj_rev_x_x_conj, + scalar_literal(4.0), + ) + ) + ) + else: + raise NotImplementedError("Unreachable") + + denominator = x * numerator + dependencies = [denominator] + + scalar_divisor = var("scalar_divisor", f64) + has_inverse = [ + eq(denominator).to(scalar_literal(scalar_divisor)), + _not_close(scalar_divisor, 0.0), + ] + inverse_x = numerator * (scalar_literal(f64(1.0) / scalar_divisor)) + + return dependencies, has_inverse, inverse_x + + @dataclass class GeometricAlgebra: egraph: EGraph @@ -262,13 +327,10 @@ def register_addition(full=True, medium=True): egraph.register( # Comm rewrite(x_1 + x_2).to(x_2 + x_1), - # Assoc - # rewrite(x_1 + (x_2 + x_3)).to((x_1 + x_2) + x_3), ) if full: egraph.register( # Assoc - # rewrite((x_1 + x_2) + x_3).to(x_1 + (x_2 + x_3)), birewrite(x_1 + (x_2 + x_3)).to((x_1 + x_2) + x_3), ) @@ -372,7 +434,7 @@ def register_inner(medium=True): def register_division(medium=True): egraph.register( # / is syntactic sugar for multiplication by inverse - birewrite(x_1 / x_2).to(x_1 * inverse(x_2)), + rewrite(x_1 / x_2).to(x_1 * inverse(x_2)), # Inverse of non-zero scalar rewrite(inverse(scalar_literal(f_1))).to( scalar_literal(f64(1.0) / f_1), _not_close(f_1, 0.0) @@ -382,125 +444,41 @@ def register_division(medium=True): ) if full_inverse: - dims = len(signature) - - # TODO: 5 dim case of Hitzer inverse gives contradictions? - # if dims > 5: - #if dims > 4: - - # # Shirokov inverse https://arxiv.org/abs/2005.04015 Theorem 4 - n = 2 ** ((dims + 1) // 2) - u = x_1 - for k in range(1, n): - c = scalar_literal(n / k) * select_grade(u, scalar_literal(0.0)) - u_minus_c = u - c - u = x_1 * u_minus_c - - # As soon as u is a scalar, we can calculate an inverse - # TODO: might be missing some scalar factors on early - # termination - egraph.register( - rule(eq(x_2).to(inverse(x_1))).then(u), - rewrite(inverse(x_1)).to( - u_minus_c * scalar_literal(f64(1.0) / f_1), - eq(u).to(scalar_literal(f_1)), - _not_close(f_1, 0.0), - ), - ) - - """ - - Example - dims = 2 - n = 2 - u = e_1 + e_2 - - -- Iter k=1 - c = 0 - u_minus_c = e_1 + e_2 - u = (e_1 + e_2) * (e_1 + e_2) = 2 - - -- Result - inv = u_minus_c / u = (e_1 + e_2) / 2 - """ - #else: - # Closed form inverses (https://dx.doi.org/10.1016/j.amc.2017.05.027) - # More optimized forms from clifford (https://github.com/pygae/clifford). - # TODO: Might be able to pick lower dimension ones depending - # on which basis vectors are present? for dims in range(len(signature) + 1): - x_1_conj = clifford_conjugation(x_1) - x_1_x_1_conj = x_1 * x_1_conj - x_1_conj_rev_x_1_x_1_conj = x_1_conj * ~x_1_x_1_conj - x_1_x_1_conj_rev_x_1_x_1_conj = x_1 * x_1_conj_rev_x_1_x_1_conj - - if dims == 0: - numerator = scalar_literal(1.0) - elif dims == 1: - numerator = grade_involution(x_1) - elif dims == 2: - numerator = x_1_conj - elif dims == 3: - numerator = x_1_conj * ~x_1_x_1_conj - elif dims == 4: - numerator = x_1_conj * ( - x_1_x_1_conj - - scalar_literal(2.0) - # Invert sign of grade 3 and 4 parts - # TODO: is there a more efficient way? - * ( - select_grade(x_1_x_1_conj, scalar_literal(3.0)) - + select_grade(x_1_x_1_conj, scalar_literal(4.0)) - ) - ) - elif len(signature) == 5: - numerator = x_1_conj_rev_x_1_x_1_conj * ( - x_1_x_1_conj_rev_x_1_x_1_conj - - scalar_literal(2.0) - # Invert sign of grade 1 and 4 parts - # TODO: is there a more efficient way? - * ( - select_grade( - x_1_x_1_conj_rev_x_1_x_1_conj, - scalar_literal(1.0), - ) - + select_grade( - x_1_x_1_conj_rev_x_1_x_1_conj, - scalar_literal(4.0), - ) - ) - ) - else: - raise NotImplementedError("Unreachable") - - denominator = x_1 * numerator + deps, x_1_has_inverse, x_1_inverse = _maybe_inverse( + m=MathExpr, x=x_1, dims=dims + ) egraph.register( - rule(eq(x_2).to(inverse(x_1))).then(denominator), - rewrite(inverse(x_1)).to( - numerator * (scalar_literal(f64(1.0) / f_1)), - eq(denominator).to( - scalar_literal(f_1), - ), - _not_close(f_1, 0.0), - ), + rule(eq(x_2).to(inverse(x_1))).then(*deps), + rewrite(inverse(x_1)).to(x_1_inverse, *x_1_has_inverse), ) # Multiplicative equation solving with inverses if eq_solve: - egraph.register( - rule(eq(x_3).to(x_1 * x_2)).then( - x_1 * x_1 != scalar_literal(0.0), - x_2 * x_2 != scalar_literal(0.0), - ), - # Left inverse: x_3 = x_1 * x_2 -> inv(x_1) * x_3 = x_2 - rule(eq(x_3).to(x_1 * x_2), x_1 * x_1 != scalar_literal(0.0)).then( - union(x_2).with_(inverse(x_1) * x_3) - ), - # Right inverse: x_3 = x_1 * x_2 -> x_3 * inv(x_2) = x_1 - rule(eq(x_3).to(x_1 * x_2), x_2 * x_2 != scalar_literal(0.0)).then( - union(x_1).with_(x_3 * inverse(x_2)) - ), - ) + for dims in range(len(signature) + 1): + deps_1, x_1_has_inverse, x_1_inverse = _maybe_inverse( + m=MathExpr, x=x_1, dims=dims + ) + deps_2, x_2_has_inverse, x_2_inverse = _maybe_inverse( + m=MathExpr, x=x_2, dims=dims + ) + + egraph.register( + # x_3 = x_1 * x_2: Figure out if x_1 and x_2 have inverses + rule(eq(x_3).to(x_1 * x_2)).then( + *deps_1, + *deps_2, + ), + # Left inverse: inv(x_1) * x_3 = x_2 + rule(eq(x_3).to(x_1 * x_2), *x_1_has_inverse).then( + union(x_2).with_(x_1_inverse * x_3), + ), + # Right inverse: x_3 * inv(x_2) = x_1 + rule(eq(x_3).to(x_1 * x_2), *x_2_has_inverse).then( + union(x_1).with_(x_3 * x_2_inverse), + ), + ) def register_pow(medium=True): egraph.register( From 8f8d1fc17af7b9498a648f2f6edbda7221529a06 Mon Sep 17 00:00:00 2001 From: Robin Kahlow Date: Wed, 11 Oct 2023 22:51:54 +0100 Subject: [PATCH 4/4] add inverse of product rule --- egga/geometric_algebra.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egga/geometric_algebra.py b/egga/geometric_algebra.py index 6a64165..942b4c1 100644 --- a/egga/geometric_algebra.py +++ b/egga/geometric_algebra.py @@ -441,6 +441,8 @@ def register_division(medium=True): ), # Inverse of basis vector rewrite(inverse(e(s_1))).to(e(s_1) * inverse(e(s_1) * e(s_1))), + # Inverse of product is product of inverses in reverse order + rewrite(inverse(x_1 * x_2)).to(inverse(x_2) * inverse(x_1)), ) if full_inverse: