Skip to content

Commit

Permalink
0.1.3, fix some contradictions because of 0.0 != -0.0, fix variable c…
Browse files Browse the repository at this point in the history
…ost for eq solving, fix some readme issues
  • Loading branch information
RobinKa committed Sep 28, 2023
1 parent 9f05b69 commit e970a0a
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 40 deletions.
52 changes: 25 additions & 27 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,9 @@ from egglog import union
from egga.geometric_algebra import GeometricAlgebra
from egga.utils import simplify

# Pass eq_solve=True to enable the equation solving rules
ga = GeometricAlgebra(
signature=[1.0, 1.0],
eq_solve=True,
)
# Pass eq_solve=True to enable the equation solving rules.
# Add a cost to variable to it gets rewritten to something else.
ga = GeometricAlgebra(signature=[1.0, 1.0], eq_solve=True, costs={"variable": 1_000})

e_0, e_1 = ga.basis_vectors
e_01 = e_0 * e_1
Expand Down Expand Up @@ -113,28 +111,28 @@ The [/examples](examples) as well as the [/tests](tests) directories contain mor

### Functions

| Code | Description |
| ------------------------ | ------------------------------------------------------------------------------------------------------------------------- |
| `inverse(x)` | Right-multiplicative inverse of x |
| `scalar(x)` | Mark x as a scalar |
| `scalar_literal(f)` | Create a scalar constant |
| `scalar_variable(s)` | Create a scalar variable |
| `e(s)` | Basis vector |
| `e2(s1, s2)` | Basis bivector |
| `e3(s1, s2, s3)` | Basis trivector |
| `variable(s)` | Create a variable |
| `cos(x)` | Cos of x |
| `sin(x)` | Sin of x |
| `cosh(x)` | Cosh of x |
| `sinh(x)` | Sinh of x |
| `exp(x)` | Exponential function of x |
| `grade(x)` | Grade of x |
| `mix_grades(x_1, x_2)` | Represents the mixture of two grades. If the grades of x_1 and x_2 are the same, this will be simplified to `grade(x_1)`. |
| `select_grade(x_1, x_2)` | Selects the grade x_2 part of x_1 |
| `abs(x)` | Absolute value of x |
| `rotor(x_1, x_2)` | Shorthand for `exp(scalar_literal(-0.5) * scalar(x_2) * x_1)` |
| `sandwich(x_1, x_2)` | Shorthand for `x_1 * x_2 * ~x_1` |
| `diff(x_1, x_2)` | Derivative of x_1 with respect to x_2 |
| Code | Description |
| ------------------------ | ---------------------------------------------------------------------------------------------------- |
| `inverse(x)` | Right-multiplicative inverse of x |
| `scalar(x)` | Mark x as a scalar |
| `scalar_literal(f)` | Create a scalar constant |
| `scalar_variable(s)` | Create a scalar variable |
| `e(s)` | Basis vector |
| `e2(s_1, s_2)` | Basis bivector |
| `e3(s_1, s_2, s_3)` | Basis trivector |
| `variable(s)` | Create a variable |
| `cos(x)` | Cos of x |
| `sin(x)` | Sin of x |
| `cosh(x)` | Cosh of x |
| `sinh(x)` | Sinh of x |
| `exp(x)` | Exponential function of x |
| `grade(x)` | Grade of x |
| `mix_grades(x_1, x_2)` | Represents the mixture of two grades. If x_1 and x_2 are the same, this will be simplified to `x_1`. |
| `select_grade(x_1, x_2)` | Selects the grade x_2 part of x_1 |
| `abs(x)` | Absolute value of x |
| `rotor(x_1, x_2)` | Shorthand for `exp(scalar_literal(-0.5) * scalar(x_2) * x_1)` |
| `sandwich(x_1, x_2)` | Shorthand for `x_1 * x_2 * ~x_1` |
| `diff(x_1, x_2)` | Derivative of x_1 with respect to x_2 |

### Unsupported but exists, might or might not work

Expand Down
28 changes: 20 additions & 8 deletions egga/geometric_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@

birewrite = egraph.birewrite

def _close(a: f64Like, b: f64Like, eps: float = 1e-3):
diff = a - b
return diff * diff < eps * eps

def _not_close(a: f64Like, b: f64Like, eps: float = 1e-3):
diff = a - b
return diff * diff >= eps * eps

@dataclass
class GeometricAlgebraRulesets:
Expand Down Expand Up @@ -420,6 +427,9 @@ def register_scalar(medium=True):
birewrite(scalar_variable(s_1)).to(scalar(scalar_variable(s_1))),
birewrite(scalar_variable(s_1)).to(scalar(variable(s_1))),
rewrite(scalar_variable(s_1)).to(variable(s_1)),
# -0 is 0 (apparently not true for f64)
union(scalar_literal(-0.0)).with_(scalar_literal(0.0)),
union(-scalar_literal(0.0)).with_(scalar_literal(0.0)),
# Scalar
rewrite(scalar(x_1) + scalar(x_2)).to(scalar(x_1 + x_2)),
rewrite(scalar(x_1) - scalar(x_2)).to(scalar(x_1 - x_2)),
Expand All @@ -438,7 +448,7 @@ def register_scalar(medium=True):
scalar_literal(f_1 * f_2)
),
rewrite(scalar_literal(f_1) / scalar_literal(f_2)).to(
scalar_literal(f_1 / f_2), f_2 != 0.0
scalar_literal(f_1 / f_2), _not_close(f_2, 0.0)
),
# Scalar literal - abs
rewrite(abs_(scalar_literal(f_1))).to(
Expand Down Expand Up @@ -474,7 +484,9 @@ def register_grade():
# rule(eq(x_2).to(grade(scalar(x_1)))).then(
# x_1 != scalar_literal(0.0)
# ),
rewrite(grade(scalar_literal(f_1))).to(scalar_literal(0.0), f_1 != 0.0),
rewrite(grade(scalar_literal(f_1))).to(
scalar_literal(0.0), _not_close(f_1, 0.0)
),
# grade(a + b) -> mix_grades(grade(a), grade(b)), if a + b is not zero
# rule(eq(x_1).to(grade(x_2 + x_3))).then(
# x_2 + x_3 != scalar_literal(0.0)
Expand All @@ -488,7 +500,9 @@ def register_grade():
# rule(eq(x_1).to(scalar(x_2) * x_3)).then(
# x_2 != scalar_literal(0.0)
# ),
rewrite(grade(scalar_literal(f_1) * x_2)).to(grade(x_2), f_1 != 0.0),
rewrite(grade(scalar_literal(f_1) * x_2)).to(
grade(x_2), _not_close(f_1, 0.0)
),
)

# Basis blade grades
Expand Down Expand Up @@ -517,12 +531,10 @@ def register_grade():
select_grade(x_1, x_3) + select_grade(x_2, x_3)
),
# select_grade(x, y) -> 0 if grade(x) != y
rule(
eq(x_2).to(select_grade(x_1, scalar_literal(f_1))),
rewrite(select_grade(x_1, scalar_literal(f_1))).to(
scalar_literal(0.0),
eq(grade(x_1)).to(scalar_literal(f_2)),
f_1 != f_2,
).then(
set_(select_grade(x_1, scalar_literal(f_1))).to(scalar_literal(0.0))
_not_close(f_1, f_2),
),
# select_grade(x, y) -> x if grade(x) == y
rule(eq(x_3).to(select_grade(x_1, scalar_literal(f_1)))).then(
Expand Down
8 changes: 3 additions & 5 deletions examples/basic_equation_solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
from egga.geometric_algebra import GeometricAlgebra
from egga.utils import simplify

# Pass eq_solve=True to enable the equation solving rules
ga = GeometricAlgebra(
signature=[1.0, 1.0],
eq_solve=True,
)
# Pass eq_solve=True to enable the equation solving rules.
# Add a cost to variable to it gets rewritten to something else.
ga = GeometricAlgebra(signature=[1.0, 1.0], eq_solve=True, costs={"variable": 1_000})

e_0, e_1 = ga.basis_vectors
e_01 = e_0 * e_1
Expand Down

0 comments on commit e970a0a

Please sign in to comment.