diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..0accb0e --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,58 @@ +name: CI + +on: + pull_request: + push: + branches: + - master + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + FORCE_COLOR: 1 + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: earthly/actions-setup@v1 + with: + version: v0.7.19 + - uses: actions/checkout@v2 + - name: Put back the git branch into git (Earthly uses it for tagging) + run: | + branch="" + if [ -n "$GITHUB_HEAD_REF" ]; then + branch="$GITHUB_HEAD_REF" + else + branch="${GITHUB_REF##*/}" + fi + git checkout -b "$branch" || true + - name: Run tests + run: earthly +test + - name: Run examples + run: earthly +test-examples + publish: + runs-on: ubuntu-latest + steps: + - uses: earthly/actions-setup@v1 + with: + version: v0.7.19 + - uses: actions/checkout@v2 + - name: Put back the git branch into git (Earthly uses it for tagging) + run: | + branch="" + if [ -n "$GITHUB_HEAD_REF" ]; then + branch="$GITHUB_HEAD_REF" + else + branch="${GITHUB_REF##*/}" + fi + git checkout -b "$branch" || true + - name: Publish test + run: earthly --secret PYPI_TOKEN=${{ secrets.test_pypi_password }} --ci +publish --REPOSITORY=testpypi + - name: Publish + if: contains(github.ref, 'master') + run: earthly --secret PYPI_TOKEN=${{ secrets.pypi_password }} --ci +publish --REPOSITORY=pypi diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d1f2963 --- /dev/null +++ b/.gitignore @@ -0,0 +1,164 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +/render +/render.pdf +/render.html diff --git a/Earthfile b/Earthfile new file mode 100644 index 0000000..b69b02e --- /dev/null +++ b/Earthfile @@ -0,0 +1,60 @@ +VERSION 0.6 + +base-python: + FROM python:3.8 + + # Install Poetry + ENV PIP_CACHE_DIR /var/cache/buildkit/pip + RUN --mount=type=cache,target=$PIP_CACHE_DIR \ + pip install poetry==1.6.1 + RUN --mount=type=cache,target=$PIP_CACHE_DIR \ + poetry config virtualenvs.create false + + # Install graphviz which the tests use + RUN apt-get update && apt-get install -y graphviz && apt-get clean + +build: + FROM +base-python + + WORKDIR /app + + # Copy poetry files + COPY pyproject.toml poetry.lock README.md . + + # We only want to install the dependencies once, so if we copied + # our code here now, we'd reinstall the dependencies ever ytime + # the code changes. Instead, comment out the line making us depend + # on our code, install, then copy our code and install again + # with the line not commented. + RUN sed -e '/packages/ s/^#*/#/' -i pyproject.toml + + # Install dependencies + RUN poetry install + + # Copy without the commented out packages line and install again + COPY --dir egga . + COPY pyproject.toml . + RUN poetry install + +test: + FROM +build + + # Run tests + COPY --dir tests . + RUN poetry run pytest -n auto + +test-examples: + FROM +build + + # Run examples + COPY --dir examples . + FOR example IN $(ls examples/**/*.py) + RUN poetry run python "$example" + END + +publish: + RUN --mount=type=cache,target=$PIP_CACHE_DIR \ + --secret PYPI_TOKEN=+secrets/PYPI_TOKEN \ + poetry publish \ + --build --skip-existing -r $REPOSITORY \ + -u __token__ -p $PYPI_TOKEN diff --git a/README.md b/README.md new file mode 100644 index 0000000..9dcdb68 --- /dev/null +++ b/README.md @@ -0,0 +1,157 @@ +# E-Graph Geometric Algebra (EGGA) + +Symbolic [Geometric Algebra](https://en.wikipedia.org/wiki/Geometric_algebra) with [E-Graphs](https://egraphs-good.github.io/) + +Things you can do with this library + +- Simplify expressions +- Prove equalities +- Solve for variables + +Things that are supported + +- Any signature +- Arbitrary number of basis vectors +- Symplectic Geometric Algebra (aka Weyl Algebras) +- Derivatives +- Add your own expression types and rules (with egglog) + +Based on the [Python bindings](https://github.com/metadsl/egglog-python) for [egglog](https://github.com/egraphs-good/egglog) + +## Setup + +Supports Python 3.8 and higher. + +`pip install egga` + +## Usage + +The first step is to create a `GeometricAlgebra` object with a given signature. +You can then use its basis vectors as well as functions exposed by it. Use the utility methods provided to do things like simplification and +equation solving. In some cases you might need to interface with egglog directly. Below are +some examples for common use-cases. + +Simplification + +```python +from egga.geometric_algebra import GeometricAlgebra +from egga.utils import simplify + +ga = GeometricAlgebra(signature=[1.0, 1.0]) +e_0, e_1 = ga.basis_vectors +e_01 = e_0 * e_1 + +# Build an expression to simplify +expr = e_01 * e_0 * ~e_01 + +# Prints Simplified: -e("0") +print("Simplified:", simplify(ga, expr)) +``` + +Equation solving + +```python +from egglog import union + +from egga.geometric_algebra import GeometricAlgebra +from egga.utils import simplify + +# Pass eq_solve=True to enable the equation solving rules +ga = GeometricAlgebra( + signature=[1.0, 1.0], + eq_solve=True, +) + +e_0, e_1 = ga.basis_vectors +e_01 = e_0 * e_1 + +# Solve e_01 * x * ~e_01 = e_0 for x +x = ga.expr_cls.variable("x") +lhs = e_01 * x * ~e_01 +rhs = -e_0 + +# Make LHS equal to RHS +ga.egraph.register(union(lhs).with_(rhs)) + +assert str(simplify(ga, x)) == str(ga.expr_cls.e("0")) +``` + +Equality check + +```python +from egga.geometric_algebra import GeometricAlgebra +from egga.utils import check_equality + +ga = GeometricAlgebra(signature=[1.0, 1.0]) +e_0, e_1 = ga.basis_vectors +e_01 = e_0 * e_1 + +# Build an lhs to check for equality to an rhs +lhs = e_01 * e_01 +rhs = ga.expr_cls.scalar_literal(-1.0) + +assert check_equality(ga, lhs, rhs) +``` + +The [/examples](examples) as well as the [/tests](tests) directories contain more examples. +The [/experimental_examples](experimental_examples) directory has more examples but are not tested as part +of continuous integration, so they can be outdated or broken. + +## List of expressions + +### Operators + +| Code | Description | +| ------------ | ------------------------------------------------------------------------------------- | +| `x_1 + x_2` | Addition of x_1 and x_2 | +| `x_1 - x_2` | Subtraction of x_1 and x_2 | +| `x_1 * x_2` | Multiplication of x_1 and x_2 (aka the Geometric Product) | +| `x_1 ^ x_2` | Wedge / exterior / outer product of x_1 and x_2 | +| `x_1 \| x_2` | Inner ("fat dot") product of x_1 and x_2 | +| `-x_1` | Negation of x_1 | +| `~x_1` | Reversion of x_1 | +| `x_1 ** x_2` | x_1 to the power of x_2 | +| `x_1 / x_2` | x_1 divided by x_2 (more generally, x_1 right-multiplied by the right inverse of x_2) | + +### Functions + +| Code | Description | +| ------------------------ | ------------------------------------------------------------------------------------------------------------------------- | +| `inverse(x)` | Right-multiplicative inverse of x | +| `scalar(x)` | Mark x as a scalar | +| `scalar_literal(f)` | Create a scalar constant | +| `scalar_variable(s)` | Create a scalar variable | +| `e(s)` | Basis vector | +| `e2(s1, s2)` | Basis bivector | +| `e3(s1, s2, s3)` | Basis trivector | +| `variable(s)` | Create a variable | +| `cos(x)` | Cos of x | +| `sin(x)` | Sin of x | +| `cosh(x)` | Cosh of x | +| `sinh(x)` | Sinh of x | +| `exp(x)` | Exponential function of x | +| `grade(x)` | Grade of x | +| `mix_grades(x_1, x_2)` | Represents the mixture of two grades. If the grades of x_1 and x_2 are the same, this will be simplified to `grade(x_1)`. | +| `select_grade(x_1, x_2)` | Selects the grade x_2 part of x_1 | +| `abs(x)` | Absolute value of x | +| `rotor(x_1, x_2)` | Shorthand for `exp(scalar_literal(-0.5) * scalar(x_1) * x_2)` | +| `sandwich(x_1, x_2)` | Shorthand for `x_1 * x_2 * ~x_1` | +| `diff(x_1, x_2)` | Derivative of x_1 with respect to x_2 | + +### Unsupported but exists, might or might not work + +| Code | Description | +| -------------------- | ------------------------------ | +| `boolean(x)` | Mark x as a boolean | +| `x_1.equal(x_2)` | Whether x_1 equals x_2 | +| `x_1.not_equal(x_2)` | Whether x_1 does not equal x_2 | + +## Caveats + +- Egraphs are bad with associativity (combined with commutativity?) so things can blow up +- Most operations aren't "fully" implemented (eg. `pow` only supports powers of two right now) + +## Contributing + +Code contributions as well as suggestions and comments about things that don't work yet are appreciated. +You can reach me by email at `tora@warlock.ai` or in the [Bivector Discord](https://discord.gg/vGY6pPk). diff --git a/egga/expression.py b/egga/expression.py new file mode 100644 index 0000000..8317fda --- /dev/null +++ b/egga/expression.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from typing import Protocol + +from egglog import StringLike, f64Like, i64Like + + +class Expression(Protocol): + def __ne__(self, other: Expression) -> Expression: + ... + + def __eq__(self, other: Expression) -> Expression: + ... + + def equal(self, other: Expression) -> Expression: + ... + + def not_equal(self, other: Expression) -> Expression: + ... + + def __add__(self, other: Expression) -> Expression: + ... + + def __sub__(self, other: Expression) -> Expression: + ... + + def __mul__(self, other: Expression) -> Expression: + ... + + def __neg__(self) -> Expression: + ... + + def __invert__(self) -> Expression: + ... + + def __pow__(self, other: Expression) -> Expression: + ... + + def __truediv__(self, other: Expression) -> Expression: + ... + + def __xor__(self, other: Expression) -> Expression: + ... + + def __or__(self, other: Expression) -> Expression: + ... + + @staticmethod + def inverse(other: Expression) -> Expression: + ... + + @staticmethod + def boolean(b: i64Like) -> Expression: + ... + + @staticmethod + def scalar(value: Expression) -> Expression: + ... + + @staticmethod + def scalar_literal(value: f64Like) -> Expression: + ... + + @staticmethod + def scalar_variable(value: StringLike) -> Expression: + ... + + @staticmethod + def e(s: StringLike) -> Expression: + ... + + @staticmethod + def e2(s_1: StringLike, s_2: StringLike) -> Expression: + ... + + @staticmethod + def e3(s_1: StringLike, s_2: StringLike, s_3: StringLike) -> Expression: + ... + + @staticmethod + def variable(name: StringLike) -> Expression: + ... + + @staticmethod + def cos(value: Expression) -> Expression: + ... + + @staticmethod + def sin(value: Expression) -> Expression: + ... + + @staticmethod + def cosh(value: Expression) -> Expression: + ... + + @staticmethod + def sinh(value: Expression) -> Expression: + ... + + @staticmethod + def exp(value: Expression) -> Expression: + ... + + @staticmethod + def grade(value: Expression) -> Expression: + ... + + @staticmethod + def mix_grades(a: Expression, b: Expression) -> Expression: + ... + + @staticmethod + def select_grade(value: Expression, grade: Expression) -> Expression: + ... + + @staticmethod + def abs_(value: Expression) -> Expression: + ... + + @staticmethod + def rotor(basis_blade: Expression, angle: Expression) -> Expression: + ... + + @staticmethod + def sandwich(r: Expression, x: Expression) -> Expression: + ... + + @staticmethod + def diff(value: Expression, wrt: Expression) -> Expression: + ... diff --git a/egga/geometric_algebra.py b/egga/geometric_algebra.py new file mode 100644 index 0000000..d1464e5 --- /dev/null +++ b/egga/geometric_algebra.py @@ -0,0 +1,811 @@ +from __future__ import annotations + +from dataclasses import dataclass +from functools import partial +from itertools import combinations +from typing import Type + +from egglog import ( + BaseExpr, + EGraph, + String, + StringLike, + egraph, + eq, + f64, + f64Like, + i64, # noqa: F401 + i64Like, + rewrite, + rule, + set_, + union, + var, + vars_, +) +from egglog import ( + egraph as egglog_egraph, +) + +from egga.expression import Expression + +birewrite = egraph.birewrite + + +@dataclass +class GeometricAlgebraRulesets: + full: egglog_egraph.Ruleset + medium: egglog_egraph.Ruleset + fast: egglog_egraph.Ruleset + + +@dataclass +class GeometricAlgebra: + egraph: EGraph + expr_cls: Type[Expression] + rulesets: GeometricAlgebraRulesets + signature: tuple[float] + + @property + def basis_vectors(self): + return [self.expr_cls.e(str(i)) for i in range(len(self.signature))] + + @property + def symplectic_dual_basis_vectors(self): + return [self.expr_cls.e(f"{i}*") for i in range(len(self.signature))] + + def __init__( + self, + signature: tuple[float], + symplectic=False, + eq_solve=False, + costs: dict[str, int] | None = None, + ): + if costs is None: + costs = {} + + egraph = EGraph() + + @egraph.class_ + class MathExpr(BaseExpr): + def equal(self, other: MathExpr) -> MathExpr: + ... + + def not_equal(self, other: MathExpr) -> MathExpr: + ... + + def __add__(self, other: MathExpr) -> MathExpr: + ... + + def __sub__(self, other: MathExpr) -> MathExpr: + ... + + def __mul__(self, other: MathExpr) -> MathExpr: + ... + + def __neg__(self) -> MathExpr: + ... + + def __invert__(self) -> MathExpr: + ... + + def __pow__(self, other: MathExpr) -> MathExpr: + ... + + def __truediv__(self, other: MathExpr) -> MathExpr: + ... + + def __xor__(self, other: MathExpr) -> MathExpr: + ... + + def __or__(self, other: MathExpr) -> MathExpr: + ... + + @egraph.function(cost=costs.get("inverse")) + def inverse(other: MathExpr) -> MathExpr: + ... + + @egraph.function(cost=costs.get("boolean")) + def boolean(b: i64Like) -> MathExpr: + ... + + @egraph.function(cost=costs.get("scalar")) + def scalar(value: MathExpr) -> MathExpr: + ... + + @egraph.function(cost=costs.get("scalar_literal")) + def scalar_literal(value: f64Like) -> MathExpr: + ... + + @egraph.function(cost=costs.get("scalar_variable")) + def scalar_variable(value: StringLike) -> MathExpr: + ... + + @egraph.function(cost=costs.get("e")) + def e(s: StringLike) -> MathExpr: + ... + + @egraph.function(cost=costs.get("e2")) + def e2(s_1: StringLike, s_2: StringLike) -> MathExpr: + ... + + @egraph.function(cost=costs.get("e3")) + def e3(s_1: StringLike, s_2: StringLike, s_3: StringLike) -> MathExpr: + ... + + @egraph.function(cost=costs.get("variable")) + def variable(name: StringLike) -> MathExpr: + ... + + @egraph.function(cost=costs.get("cos")) + def cos(value: MathExpr) -> MathExpr: + ... + + @egraph.function(cost=costs.get("sin")) + def sin(value: MathExpr) -> MathExpr: + ... + + @egraph.function(cost=costs.get("cosh")) + def cosh(value: MathExpr) -> MathExpr: + ... + + @egraph.function(cost=costs.get("sinh")) + def sinh(value: MathExpr) -> MathExpr: + ... + + @egraph.function(cost=costs.get("exp")) + def exp(value: MathExpr) -> MathExpr: + ... + + @egraph.function(cost=costs.get("grade")) + def grade(value: MathExpr) -> MathExpr: + ... + + @egraph.function(cost=costs.get("mix_grades")) + def mix_grades(a: MathExpr, b: MathExpr) -> MathExpr: + ... + + @egraph.function(cost=costs.get("select_grade")) + def select_grade(value: MathExpr, grade: MathExpr) -> MathExpr: + ... + + @egraph.function(cost=costs.get("abs")) + def abs_(value: MathExpr) -> MathExpr: + ... + + @egraph.function(cost=costs.get("rotor")) + def rotor(basis_blade: MathExpr, angle: MathExpr) -> MathExpr: + ... + + @egraph.function(cost=costs.get("sandwich")) + def sandwich(r: MathExpr, x: MathExpr) -> MathExpr: + ... + + @egraph.function(cost=costs.get("diff")) + def diff(value: MathExpr, wrt: MathExpr) -> MathExpr: + ... + + MathExpr.inverse = inverse + MathExpr.boolean = boolean + MathExpr.scalar = scalar + MathExpr.scalar_literal = scalar_literal + MathExpr.scalar_variable = scalar_variable + MathExpr.e = e + MathExpr.e2 = e2 + MathExpr.e3 = e3 + MathExpr.variable = variable + MathExpr.cos = cos + MathExpr.sin = sin + MathExpr.cosh = cosh + MathExpr.sinh = sinh + MathExpr.exp = exp + MathExpr.grade = grade + MathExpr.mix_grades = mix_grades + MathExpr.select_grade = select_grade + MathExpr.abs = abs_ + MathExpr.rotor = rotor + MathExpr.sandwich = sandwich + MathExpr.diff = diff + + x_1, x_2, x_3 = vars_("x_1 x_2 x_3", MathExpr) + f_1, f_2 = vars_("f_1 f_2", f64) + s_1, s_2, s_3 = vars_("s_1 s_2 s_3", String) + + orig_rewrite = rewrite + orig_birewrite = birewrite + orig_rule = rule + + def set_active_ruleset(ruleset): + global rewrite, birewrite, rule + rewrite = partial(orig_rewrite, ruleset=ruleset) + birewrite = partial(orig_birewrite, ruleset=ruleset) + rule = partial(orig_rule, ruleset=ruleset) + + def register_addition(full=True, medium=True): + egraph.register( + # Identity + rewrite(x_1 + scalar_literal(0.0)).to(x_1), + ) + if medium: + egraph.register( + # Comm + rewrite(x_1 + x_2).to(x_2 + x_1), + # Assoc + # rewrite(x_1 + (x_2 + x_3)).to((x_1 + x_2) + x_3), + ) + if full: + egraph.register( + # Assoc + # rewrite((x_1 + x_2) + x_3).to(x_1 + (x_2 + x_3)), + birewrite(x_1 + (x_2 + x_3)).to((x_1 + x_2) + x_3), + ) + + if eq_solve: + egraph.register( + rule(eq(x_3).to(x_1 + x_2)).then(union(x_1).with_(x_3 - x_2)), + ) + + def register_subtraction(medium=True): + egraph.register( + # Sub self is zero + rewrite(x_1 - x_1).to(scalar_literal(0.0)), + # (a + b) - (a + c) = b - c + rewrite((x_1 + x_2) - (x_1 + x_3)).to(x_2 - x_3), + ) + + if medium: + egraph.register( + # Add negation is subtraction + birewrite(x_1 + -x_2).to(x_1 - x_2), + ) + + def register_negation(): + egraph.register( + # Involute + rewrite(--x_1).to(x_1), + # This makes things blow up + # birewrite(-x_1).to(scalar_literal(-1.0) * x_1), + ) + + def register_multiplication(full=True, medium=True): + egraph.register( + # Identity + rewrite(scalar_literal(1.0) * x_1).to(x_1), + # Zero + rewrite(scalar_literal(0.0) * x_1).to(scalar_literal(0.0)), + # x + x = 2x + birewrite(x_1 + x_1).to(scalar_literal(2.0) * x_1), + # ax + bx = (a+b)x + birewrite(scalar(x_2) * x_1 + scalar(x_3) * x_1).to( + (scalar(x_2) + scalar(x_3)) * x_1 + ), + # ax + x = (a+1)x + birewrite(scalar(x_2) * x_1 + x_1).to( + (scalar(x_2) + scalar_literal(1.0)) * x_1 + ), + ) + + if medium: + egraph.register( + # Scalar comm + birewrite(x_1 * scalar(x_2)).to(scalar(x_2) * x_1), + # Assoc + birewrite(x_1 * (x_2 * x_3)).to((x_1 * x_2) * x_3), + # Left distr + birewrite(x_1 * (x_2 + x_3)).to(x_1 * x_2 + x_1 * x_3), + # Right distr + birewrite((x_1 + x_2) * x_3).to(x_1 * x_3 + x_2 * x_3), + # Neg + birewrite(-x_1 * x_2).to(-(x_1 * x_2)), + birewrite(-x_1 * x_2).to(x_1 * -x_2), + ) + if full: + pass + + def register_wedge(medium=True): + egraph.register( + birewrite(x_1 ^ x_2).to( + select_grade(x_1 * x_2, grade(x_1) + grade(x_2)), + ), + ) + if medium: + egraph.register( + # TODO: Can we do without these rules? + # Without them, (e1 + e12) ^ e12 = e12 fails. + # Assoc + birewrite(x_1 ^ (x_2 ^ x_3)).to((x_1 ^ x_2) ^ x_3), + # Left distr + birewrite(x_1 ^ (x_2 + x_3)).to((x_1 ^ x_2) + (x_1 ^ x_3)), + # Right distr + birewrite((x_1 + x_2) ^ x_3).to((x_1 ^ x_3) + (x_2 ^ x_3)), + ) + + def register_inner(medium=True): + egraph.register( + birewrite(x_1 | x_2).to( + select_grade(x_1 * x_2, abs_(grade(x_1) - grade(x_2))), + ), + ) + if medium: + egraph.register( + # TODO: Can we do without these rules? + # Assoc + birewrite(x_1 | (x_2 | x_3)).to((x_1 | x_2) | x_3), + # Left distr + birewrite(x_1 | (x_2 + x_3)).to((x_1 | x_2) + (x_1 | x_3)), + # Right distr + birewrite((x_1 + x_2) | x_3).to((x_1 | x_3) + (x_2 | x_3)), + ) + + def register_division(medium=True): + # TODO: Left / right inverse + + egraph.register( + # Divide identity + rewrite(x_1 / scalar_literal(1.0)).to(x_1), + # Divide self + rewrite(x_1 / x_1).to(scalar_literal(1.0), x_1 != scalar_literal(0.0)), + ) + + if medium: + egraph.register( + # 1 / x is inverse of x + birewrite(scalar_literal(1.0) / x_1).to(inverse(x_1)), + # Division is right multiplication inverse + birewrite(x_1 / x_2).to(x_1 * inverse(x_2)), + # Inverse basis vector + # TODO: Figure out why this is broken x_1 * x_1 != scalar_literal(0.0) + rule(eq(x_1).to(inverse(x_2))).then( + x_2 * x_2 != scalar_literal(0.0) + ), + birewrite(inverse(x_1)).to( + x_1 / (x_1 * x_1), x_1 * x_1 != scalar_literal(0.0) + ), + ) + + if eq_solve: + egraph.register( + rule(eq(x_3).to(x_1 * x_2)).then( + x_1 * x_1 != scalar_literal(0.0), + x_2 * x_2 != scalar_literal(0.0), + ), + # Left inverse: x_3 = x_1 * x_2 -> inv(x_1) * x_3 = x_2 + rule(eq(x_3).to(x_1 * x_2), x_1 * x_1 != scalar_literal(0.0)).then( + union(x_2).with_(inverse(x_1) * x_3) + ), + # Right inverse: x_3 = x_1 * x_2 -> x_3 * inv(x_2) = x_1 + rule(eq(x_3).to(x_1 * x_2), x_2 * x_2 != scalar_literal(0.0)).then( + union(x_1).with_(x_3 * inverse(x_2)) + ), + ) + + def register_pow(medium=True): + egraph.register( + # pow zero + rewrite(x_1 ** scalar_literal(0.0)).to(scalar_literal(1.0)), + # pow one + rewrite(x_1 ** scalar_literal(1.0)).to(x_1), + ) + if medium: + egraph.register( + # expand pow 2 to mul + birewrite(x_1 ** scalar_literal(2.0)).to(x_1 * x_1), + ) + + def register_exp(medium=True): + egraph.register( + # B^2 = -1 + birewrite(exp(scalar(x_2) * x_1)).to( + cos(scalar(x_2)) + x_1 * sin(scalar(x_2)), + eq(x_1 * x_1).to(scalar_literal(-1.0)), + ), + # B^2 = 0 + birewrite(exp(scalar(x_2) * x_1)).to( + scalar_literal(1.0) + x_1 * scalar(x_2), + eq(x_1 * x_1).to(scalar_literal(0.0)), + ), + # B^2 = +1 + birewrite(exp(scalar(x_2) * x_1)).to( + cosh(scalar(x_2)) + x_1 * sinh(scalar(x_2)), + eq(x_1 * x_1).to(scalar_literal(1.0)), + ), + # Euler's formula etc. require adding B^2 so the rule can get matched. + rule(eq(x_2).to(exp(x_1))).then(x_1 * x_1), + ) + + def register_scalar(medium=True): + if medium: + egraph.register( + birewrite(scalar_literal(f_1)).to(scalar(scalar_literal(f_1))), + birewrite(scalar_variable(s_1)).to(scalar(scalar_variable(s_1))), + birewrite(scalar_variable(s_1)).to(scalar(variable(s_1))), + rewrite(scalar_variable(s_1)).to(variable(s_1)), + # Scalar + rewrite(scalar(x_1) + scalar(x_2)).to(scalar(x_1 + x_2)), + rewrite(scalar(x_1) - scalar(x_2)).to(scalar(x_1 - x_2)), + rewrite(-scalar(x_1)).to(scalar(-x_1)), + rewrite(scalar(x_1) * scalar(x_2)).to(scalar(x_1 * x_2)), + rewrite(scalar(x_1) / scalar(x_2)).to(scalar(x_1 / x_2)), + # Scalar literal + rewrite(scalar_literal(f_1) + scalar_literal(f_2)).to( + scalar_literal(f_1 + f_2) + ), + rewrite(scalar_literal(f_1) - scalar_literal(f_2)).to( + scalar_literal(f_1 - f_2) + ), + rewrite(-scalar_literal(f_1)).to(scalar_literal(-f_1)), + rewrite(scalar_literal(f_1) * scalar_literal(f_2)).to( + scalar_literal(f_1 * f_2) + ), + rewrite(scalar_literal(f_1) / scalar_literal(f_2)).to( + scalar_literal(f_1 / f_2), f_2 != 0.0 + ), + # Scalar literal - abs + rewrite(abs_(scalar_literal(f_1))).to( + scalar_literal(f_1), f_1 >= 0.0 + ), + rewrite(abs_(scalar_literal(f_1))).to( + scalar_literal(-f_1), f_1 < 0.0 + ), + ) + + def register_grade(): + # -- Idea: + # mix_grades(x, y) -> x if x == y + # grade(s) -> 0 + # grade(e(i)) -> 1 + # grade(s * e(i)) -> 1 + # grade(x + y) -> mix_grades(grade(x), grade(y)) + # -- Example: + # grade(e1 + e2 + e3) -> + # mix_grades(grade(e1), grade(e2 + e3)) -> + # mix_grades(grade(e1), mix_grades(grade(e2), grade(e3))) + + egraph.register( + # mix_grades(x, y) -> x if x == y + rewrite(mix_grades(x_1, x_2)).to(x_1, eq(x_1).to(x_2)), + # mix_grades comm + rewrite(mix_grades(x_1, x_2)).to(mix_grades(x_2, x_1)), + # negation doesn't affect grade + rewrite(grade(-x_1)).to(grade(x_1)), + # Grade is scalar + birewrite(grade(x_1)).to(scalar(grade(x_1))), + # Grade of scalar is 0, if scalar is not zero + # rule(eq(x_2).to(grade(scalar(x_1)))).then( + # x_1 != scalar_literal(0.0) + # ), + rewrite(grade(scalar_literal(f_1))).to(scalar_literal(0.0), f_1 != 0.0), + # grade(a + b) -> mix_grades(grade(a), grade(b)), if a + b is not zero + # rule(eq(x_1).to(grade(x_2 + x_3))).then( + # x_2 + x_3 != scalar_literal(0.0) + # ), + rewrite(grade(x_1 + x_2)).to( + mix_grades(grade(x_1), grade(x_2)), + # x_1 + x_2 != scalar_literal(0.0), + ), + # grade(s * x) -> grade(x) if s != 0 + # With scalar coef, if scalar coef is not zero + # rule(eq(x_1).to(scalar(x_2) * x_3)).then( + # x_2 != scalar_literal(0.0) + # ), + rewrite(grade(scalar_literal(f_1) * x_2)).to(grade(x_2), f_1 != 0.0), + ) + + # Basis blade grades + for blade_grade in range(1, len(signature)): + basis_blade = None + basis_vector_names = [var(f"s_{i}", String) for i in range(blade_grade)] + basis_vectors = [e(name) for name in basis_vector_names] + for basis_vector in basis_vectors: + if basis_blade is None: + basis_blade = basis_vector + else: + basis_blade *= basis_vector + conds = [] + for name_1, name_2 in combinations(basis_vector_names, 2): + conds.append(name_1 != name_2) + egraph.register( + rewrite(grade(basis_blade)).to( + scalar_literal(float(blade_grade)), *conds + ) + ) + + # Select grade + egraph.register( + # select_grade(x + y, z) -> select_grade(x, z) + select_grade(y, z) + rewrite(select_grade(x_1 + x_2, x_3)).to( + select_grade(x_1, x_3) + select_grade(x_2, x_3) + ), + # select_grade(x, y) -> 0 if grade(x) != y + rule( + eq(x_2).to(select_grade(x_1, scalar_literal(f_1))), + eq(grade(x_1)).to(scalar_literal(f_2)), + f_1 != f_2, + ).then( + set_(select_grade(x_1, scalar_literal(f_1))).to(scalar_literal(0.0)) + ), + # select_grade(x, y) -> x if grade(x) == y + rule(eq(x_3).to(select_grade(x_1, scalar_literal(f_1)))).then( + grade(x_1) + ), + rewrite(select_grade(x_1, scalar_literal(f_1))).to( + x_1, eq(grade(x_1)).to(scalar_literal(f_1)) + ), + ) + + def register_basic_ga(medium=True): + basis_vectors = [e(str(i)) for i in range(len(signature))] + + egraph.register( + # e_i^2 = signature[i] + *map( + lambda e, s: rewrite(e * e).to(scalar_literal(s)), + basis_vectors, + signature, + ), + # e_i e_j = e_ij + birewrite(e(s_1) * e(s_2)).to(e2(s_1, s_2)), + birewrite(e(s_1) * e(s_2) * e(s_3)).to(e3(s_1, s_2, s_3)), + # Rotor + birewrite(exp(x_2 * scalar(x_1) * scalar_literal(-0.5))).to( + rotor(x_2, scalar(x_1)) + ), + # Sandwich + rewrite(sandwich(x_1, x_2)).to(x_1 * x_2 * ~x_1), + ) + + if medium: + egraph.register( + # e_i e_j = -e_j e_i, i != j + birewrite(-e(s_1) * e(s_2)).to(e(s_2) * e(s_1), s_1 != s_2), + # Sandwich + rewrite(x_1 * x_2 * ~x_1).to(sandwich(x_1, x_2)), + ) + + def register_symplectic_ga(medium=True): + q_vectors = [e(str(i)) for i in range(len(signature))] + p_vectors = [e(f"{i}*") for i in range(len(signature))] + + if True: + egraph.register( + # q p - p q = 1 + *map( + lambda q, p, s: birewrite(q * p - p * q).to( + scalar_literal(2 * s) + ), + q_vectors, + p_vectors, + signature, + ), + # q p = p q + 1 + *map( + lambda q, p, s: birewrite(q * p).to( + p * q + scalar_literal(2 * s) + ), + q_vectors, + p_vectors, + signature, + ), + # p q = q p - 1 + *map( + lambda q, p, s: birewrite(p * q).to( + q * p - scalar_literal(2 * s) + ), + q_vectors, + p_vectors, + signature, + ), + ) + + egraph.register( + # cutoff e_i^n = 0 + rewrite(e(s_1) * e(s_1) * e(s_1) * e(s_1)).to(scalar_literal(0.0)), + # e_i e_j = e_ij + birewrite(e(s_1) * e(s_2)).to(e2(s_1, s_2)), + birewrite(e(s_1) * e(s_2) * e(s_3)).to(e3(s_1, s_2, s_3)), + ) + + # # # q_i p_j = p_j q_i, i != j + # for i, q in enumerate(q_vectors): + # for j, p in enumerate(p_vectors): + # if i != j: + # egraph.register(birewrite(q * p).to(p * q)) + # # q_i q_j = q_j q_i, i != j + # for i, q_1 in enumerate(q_vectors): + # for j, q_2 in enumerate(q_vectors): + # if i != j: + # egraph.register(birewrite(q_1 * q_2).to(q_2 * q_1)) + # # p_i p_j = p_j p_i, i != j + # for i, p_1 in enumerate(p_vectors): + # for j, p_2 in enumerate(p_vectors): + # if i != j: + # egraph.register(birewrite(p_1 * p_2).to(p_2 * p_1)) + + def register_reverse(medium=True): + if not symplectic: + egraph.register( + rewrite(~~x_1).to(x_1), + rewrite(~scalar(x_1)).to(scalar(x_1)), + rewrite(~e(s_1)).to(e(s_1)), + ) + + if medium: + egraph.register( + birewrite(~(x_1 * x_2)).to(~x_2 * ~x_1), + birewrite(~(x_1 + x_2)).to(~x_1 + ~x_2), + ) + + def register_equality(): + egraph.register( + # Comm + birewrite(x_1.equal(x_2)).to(x_2.equal(x_1)), + birewrite(x_1.not_equal(x_2)).to(x_2.not_equal(x_1)), + # Eq / Ne + rewrite(x_1.equal(x_1)).to(boolean(True)), + rewrite(x_1.not_equal(x_1)).to(boolean(False)), + rewrite(x_1.equal(x_2)).to(boolean(False), x_1 != x_2), + rewrite(x_1.not_equal(x_2)).to(boolean(True), x_1 != x_2), + ) + + def register_trigonometry(): + egraph.register( + # scalar + birewrite(sin(x_1)).to(scalar(sin(x_1))), + birewrite(cos(x_1)).to(scalar(cos(x_1))), + birewrite(sinh(x_1)).to(scalar(sinh(x_1))), + birewrite(cosh(x_1)).to(scalar(cosh(x_1))), + # sin/cos + rewrite(cos(scalar_literal(0.0))).to(scalar_literal(1.0)), + rewrite(sin(scalar_literal(0.0))).to(scalar_literal(0.0)), + rewrite(cos(x_2) * cos(x_2) + sin(x_1) * sin(x_1)).to( + scalar_literal(1.0) + ), + rewrite(-sin(-x_1)).to(sin(x_1)), + rewrite(cos(-x_1)).to(cos(x_1)), + birewrite(cos(x_1) * cos(x_2)).to( + scalar_literal(0.5) * (cos(x_1 - x_2) + cos(x_1 + x_2)) + ), + birewrite(sin(x_1) * sin(x_2)).to( + scalar_literal(0.5) * (cos(x_1 - x_2) - cos(x_1 + x_2)) + ), + birewrite(sin(x_1) * cos(x_2)).to( + scalar_literal(0.5) * (sin(x_1 + x_2) + sin(x_1 - x_2)) + ), + birewrite(cos(x_1) * sin(x_2)).to( + scalar_literal(0.5) * (sin(x_1 + x_2) - sin(x_1 - x_2)) + ), + # sinh/cosh + rewrite(cosh(x_2) * cosh(x_2) - sinh(x_1) * sinh(x_1)).to( + scalar_literal(1.0) + ), + rewrite(-sin(-x_1)).to(sin(x_1)), + rewrite(cos(-x_1)).to(cos(x_1)), + ) + + def register_diff(medium=True): + # TODO: maybe add constant() to unify scalar_literal and e? + + if medium: + # Linearity + egraph.register( + # Addition + birewrite(diff(x_1 + x_2, x_3)).to(diff(x_1, x_3) + diff(x_2, x_3)), + # Constant multiplication + birewrite(diff(scalar_literal(f_1) * x_2, x_3)).to( + scalar_literal(f_1) * diff(x_2, x_3) + ), + birewrite(diff(e(s_1) * x_2, x_3)).to(e(s_1) * diff(x_2, x_3)), + birewrite(diff(x_2 * e(s_1), x_3)).to(diff(x_2, x_3) * e(s_1)), + ) + + # Concrete derivatives + egraph.register( + # wrt self + # TODO: Fix not constant condition + # rewrite(diff(x_1, x_1)).to( + # scalar_literal(1.0)#, x_1 != scalar_literal(f_1) + # ), + rewrite(diff(variable(s_1), variable(s_1))).to(scalar_literal(1.0)), + # wrt other + # rewrite(diff(variable(s_1), variable(s_2))).to( + # scalar_literal(0.0), s_1 != s_2 + # ), + # constant: scalar_literal + rewrite(diff(scalar_literal(f_1), x_1)).to(scalar_literal(0.0)), + # constant: e + rewrite(diff(e(s_1), x_1)).to(scalar_literal(0.0)), + # x * y + rewrite(diff(x_1 * x_2, x_3)).to( + diff(x_1, x_3) * x_2 + x_1 * diff(x_2, x_3) + ), + # sin(x) + rewrite(diff(sin(x_1), x_2)).to(cos(x_1) * diff(x_1, x_2)), + # cos(x) + rewrite(diff(cos(x_1), x_2)).to(-sin(x_1) * diff(x_1, x_2)), + # sinh(x) + rewrite(diff(sinh(x_1), x_2)).to(cosh(x_1) * diff(x_1, x_2)), + # cosh(x) + rewrite(diff(cosh(x_1), x_2)).to(sinh(x_1) * diff(x_1, x_2)), + # reverse + rewrite(diff(~x_1, x_2)).to(~diff(x_1, x_2)), + # negative + rewrite(diff(-x_1, x_2)).to(-diff(x_1, x_2)), + # square + rewrite(diff(x_1 * x_1, x_1)).to(scalar_literal(2.0) * x_1), + ) + + full_ruleset = egraph.ruleset("full") + medium_ruleset = egraph.ruleset("medium") + fast_ruleset = egraph.ruleset("fast") + + set_active_ruleset(full_ruleset) + register_addition() + register_negation() + register_subtraction() + register_multiplication() + register_division() + register_pow() + register_trigonometry() + register_exp() + register_scalar() + register_equality() + register_grade() + if symplectic: + register_symplectic_ga() + else: + register_basic_ga() + register_wedge() + register_inner() + register_reverse() + register_diff() + + set_active_ruleset(medium_ruleset) + register_addition(full=False) + register_negation() + register_subtraction() + register_multiplication(full=False) + register_division() + register_pow() + register_trigonometry() + register_exp() + register_scalar() + register_equality() + register_grade() + if symplectic: + register_symplectic_ga() + else: + register_basic_ga() + register_wedge() + register_inner() + register_reverse() + register_diff() + + set_active_ruleset(fast_ruleset) + register_addition(medium=False, full=False) + register_negation() + register_subtraction(medium=False) + register_multiplication(medium=False, full=False) + register_division(medium=False) + register_pow(medium=False) + register_trigonometry() + register_exp(medium=False) + register_scalar(medium=False) + register_equality() + register_grade() + if symplectic: + register_symplectic_ga(medium=False) + else: + register_basic_ga(medium=False) + register_wedge(medium=False) + register_inner(medium=False) + register_reverse(medium=False) + register_diff(medium=False) + + self.egraph = egraph + self.expr_cls = MathExpr + self.rulesets = GeometricAlgebraRulesets( + full=full_ruleset, + medium=medium_ruleset, + fast=fast_ruleset, + ) + self.signature = signature diff --git a/egga/utils.py b/egga/utils.py new file mode 100644 index 0000000..83de9e1 --- /dev/null +++ b/egga/utils.py @@ -0,0 +1,124 @@ +from dataclasses import dataclass +from typing import Literal, Optional, Union + +from egglog import Fact, egraph, eq, run + +from egga.expression import Expression +from egga.geometric_algebra import GeometricAlgebra + + +@dataclass +class RunRulesetOptions: + ruleset: Union[Literal["fast", "medium", "full"], egraph.Ruleset] = "full" + limit: int = 10 + until: Optional[Fact] = None + + +def run_ruleset(ga: GeometricAlgebra, options: RunRulesetOptions = RunRulesetOptions()): + if isinstance(options.ruleset, str): + ruleset = ga.rulesets.__dict__[options.ruleset] + else: + ruleset = options.ruleset + + ga.egraph.run( + run( + ruleset, + options.limit, + *([options.until] if options.until is not None else []), + ).saturate() + ) + + +@dataclass +class RunScheduledOptions: + limit: int = 10 + full_interval: int = 4 + fast_count: int = 4 + until: Optional[Fact] = None + + +def run_scheduled( + ga: GeometricAlgebra, options: RunScheduledOptions = RunScheduledOptions() +): + def do_step(ruleset: egraph.Ruleset, step_limit: int): + run_ruleset( + ga, + options=RunRulesetOptions( + ruleset=ruleset, limit=step_limit, until=options.until + ), + ) + + step = 0 + while step < options.limit: + # Fast step + do_step(ruleset=ga.rulesets.fast, step_limit=options.fast_count) + + # Medium or full step + do_step( + ruleset=ga.rulesets.full + if step % options.full_interval != 0 + else ga.rulesets.medium, + step_limit=1, + ) + step += 1 + + +def simplify( + ga: GeometricAlgebra, + expression: Expression, + options: RunScheduledOptions = RunScheduledOptions(), +) -> Expression: + """ + Returns a simplified expression. + """ + ga.egraph.push() + + ga.egraph.register(expression) + run_scheduled(ga=ga, options=options) + simplified = ga.egraph.extract(expression) + + ga.egraph.pop() + + return simplified + + +@dataclass +class CheckEqualityOptions: + limit: int = 10 + full_interval: int = 4 + fast_count: int = 4 + equal: bool = True + + +def check_equality( + ga: GeometricAlgebra, + lhs: Expression, + rhs: Expression, + options: CheckEqualityOptions = CheckEqualityOptions(), +) -> bool: + """ + Returns whether two expressions are equal. + """ + ga.egraph.push() + + ga.egraph.register(lhs) + ga.egraph.register(rhs) + predicate = eq(lhs).to(rhs) + run_scheduled( + ga=ga, + options=RunScheduledOptions( + limit=options.limit, + full_interval=options.full_interval, + fast_count=options.fast_count, + until=predicate if options.equal else None, + ), + ) + check_fn = ga.egraph.check if options.equal else ga.egraph.check_fail + check_passed = True + try: + check_fn(predicate) + except: + check_passed = False + + ga.egraph.pop() + return check_passed diff --git a/examples/basic_equality_check.py b/examples/basic_equality_check.py new file mode 100644 index 0000000..453fc3b --- /dev/null +++ b/examples/basic_equality_check.py @@ -0,0 +1,12 @@ +from egga.geometric_algebra import GeometricAlgebra +from egga.utils import check_equality + +ga = GeometricAlgebra(signature=[1.0, 1.0]) +e_0, e_1 = ga.basis_vectors +e_01 = e_0 * e_1 + +# Build an lhs to check for equality to an rhs +lhs = e_01 * e_01 +rhs = ga.expr_cls.scalar_literal(-1.0) + +assert check_equality(ga, lhs, rhs) diff --git a/examples/basic_equation_solving.py b/examples/basic_equation_solving.py new file mode 100644 index 0000000..21ecb2b --- /dev/null +++ b/examples/basic_equation_solving.py @@ -0,0 +1,23 @@ +from egglog import union + +from egga.geometric_algebra import GeometricAlgebra +from egga.utils import simplify + +# Pass eq_solve=True to enable the equation solving rules +ga = GeometricAlgebra( + signature=[1.0, 1.0], + eq_solve=True, +) + +e_0, e_1 = ga.basis_vectors +e_01 = e_0 * e_1 + +# Solve e_01 * x * ~e_01 = e_0 for x +x = ga.expr_cls.variable("x") +lhs = e_01 * x * ~e_01 +rhs = -e_0 + +# Make LHS equal to RHS +ga.egraph.register(union(lhs).with_(rhs)) + +assert str(simplify(ga, x)) == str(ga.expr_cls.e("0")) diff --git a/examples/basic_simplification.py b/examples/basic_simplification.py new file mode 100644 index 0000000..7ac0cbf --- /dev/null +++ b/examples/basic_simplification.py @@ -0,0 +1,12 @@ +from egga.geometric_algebra import GeometricAlgebra +from egga.utils import simplify + +ga = GeometricAlgebra(signature=[1.0, 1.0]) +e_0, e_1 = ga.basis_vectors +e_01 = e_0 * e_1 + +# Build an expression to simplify +expr = e_01 * e_0 * ~e_01 + +# Prints Simplified: -e("0") +print("Simplified:", simplify(ga, expr)) diff --git a/poetry.lock b/poetry.lock new file mode 100644 index 0000000..7423e66 --- /dev/null +++ b/poetry.lock @@ -0,0 +1,299 @@ +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. + +[[package]] +name = "black" +version = "23.9.1" +description = "The uncompromising code formatter." +optional = false +python-versions = ">=3.8" +files = [ + {file = "black-23.9.1-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:d6bc09188020c9ac2555a498949401ab35bb6bf76d4e0f8ee251694664df6301"}, + {file = "black-23.9.1-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:13ef033794029b85dfea8032c9d3b92b42b526f1ff4bf13b2182ce4e917f5100"}, + {file = "black-23.9.1-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:75a2dc41b183d4872d3a500d2b9c9016e67ed95738a3624f4751a0cb4818fe71"}, + {file = "black-23.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13a2e4a93bb8ca74a749b6974925c27219bb3df4d42fc45e948a5d9feb5122b7"}, + {file = "black-23.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:adc3e4442eef57f99b5590b245a328aad19c99552e0bdc7f0b04db6656debd80"}, + {file = "black-23.9.1-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:8431445bf62d2a914b541da7ab3e2b4f3bc052d2ccbf157ebad18ea126efb91f"}, + {file = "black-23.9.1-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:8fc1ddcf83f996247505db6b715294eba56ea9372e107fd54963c7553f2b6dfe"}, + {file = "black-23.9.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:7d30ec46de88091e4316b17ae58bbbfc12b2de05e069030f6b747dfc649ad186"}, + {file = "black-23.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:031e8c69f3d3b09e1aa471a926a1eeb0b9071f80b17689a655f7885ac9325a6f"}, + {file = "black-23.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:538efb451cd50f43aba394e9ec7ad55a37598faae3348d723b59ea8e91616300"}, + {file = "black-23.9.1-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:638619a559280de0c2aa4d76f504891c9860bb8fa214267358f0a20f27c12948"}, + {file = "black-23.9.1-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:a732b82747235e0542c03bf352c126052c0fbc458d8a239a94701175b17d4855"}, + {file = "black-23.9.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:cf3a4d00e4cdb6734b64bf23cd4341421e8953615cba6b3670453737a72ec204"}, + {file = "black-23.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf99f3de8b3273a8317681d8194ea222f10e0133a24a7548c73ce44ea1679377"}, + {file = "black-23.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:14f04c990259576acd093871e7e9b14918eb28f1866f91968ff5524293f9c573"}, + {file = "black-23.9.1-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:c619f063c2d68f19b2d7270f4cf3192cb81c9ec5bc5ba02df91471d0b88c4c5c"}, + {file = "black-23.9.1-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:6a3b50e4b93f43b34a9d3ef00d9b6728b4a722c997c99ab09102fd5efdb88325"}, + {file = "black-23.9.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:c46767e8df1b7beefb0899c4a95fb43058fa8500b6db144f4ff3ca38eb2f6393"}, + {file = "black-23.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50254ebfa56aa46a9fdd5d651f9637485068a1adf42270148cd101cdf56e0ad9"}, + {file = "black-23.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:403397c033adbc45c2bd41747da1f7fc7eaa44efbee256b53842470d4ac5a70f"}, + {file = "black-23.9.1-py3-none-any.whl", hash = "sha256:6ccd59584cc834b6d127628713e4b6b968e5f79572da66284532525a042549f9"}, + {file = "black-23.9.1.tar.gz", hash = "sha256:24b6b3ff5c6d9ea08a8888f6977eae858e1f340d7260cf56d70a49823236b62d"}, +] + +[package.dependencies] +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.7.4)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + +[[package]] +name = "click" +version = "8.1.7" +description = "Composable command line interface toolkit" +optional = false +python-versions = ">=3.7" +files = [ + {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, + {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + +[[package]] +name = "egglog" +version = "0.5.1" +description = "e-graphs in Python built around the the egglog rust library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "egglog-0.5.1-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:9884318657ea62f163c488ff42746dd0de5cf93cc59a28bb674033839442da67"}, + {file = "egglog-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61e3c72c9c9c3353b1f83310ac862469be802d21cac6c18417b6a5661bf74f6b"}, + {file = "egglog-0.5.1-cp310-none-win_amd64.whl", hash = "sha256:e4aee50f82f865ca52a9bdc4b8042783014961389a6e6a7833c10af2850b0890"}, + {file = "egglog-0.5.1-cp311-cp311-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:08853a6e75b834a58ce4fe09734cabb26b92b89f30117bf6df1e02df526404df"}, + {file = "egglog-0.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac9f7d0a9f15b2a31ef23c9ff404e69443b59e8ed9735dbf73797290a66cf6f1"}, + {file = "egglog-0.5.1-cp311-none-win_amd64.whl", hash = "sha256:9d3595296a8f2d5178459fcce02b12caf03935b32e3ee09d5dfffea521c4654d"}, + {file = "egglog-0.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31fe46d31b098e6c241c4938935ba27a467be7e8aea0eee6aa9a649e6b7fd521"}, + {file = "egglog-0.5.1-cp37-cp37m-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:e8c10af8bce4830d8722a3e6d3977e2717587b711185328710c84ef1c52eb821"}, + {file = "egglog-0.5.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba5137221c2d26c67cfcd05bdb38b6030b32f849359d33471005798b1ca6850c"}, + {file = "egglog-0.5.1-cp37-none-win_amd64.whl", hash = "sha256:88b42b8c0522f4c7a54a9a20fb11eabd87070288a9bdb2cbe3a5a1fa1338acee"}, + {file = "egglog-0.5.1-cp38-cp38-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:0624fd6646bc53c2e0214c5eb0da3617c88cf8d7ee2bb7e0d40781a7581f602e"}, + {file = "egglog-0.5.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3665a3f39fc87d4c3e347db57b5053f83287d51318555734bfed89df9eaef61e"}, + {file = "egglog-0.5.1-cp38-none-win_amd64.whl", hash = "sha256:ca074fa8fb5c186d4d3f1ff5e5803ec3577552ed9485823b812517104c62bcf6"}, + {file = "egglog-0.5.1-cp39-cp39-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:fb1579c598ea4ad484ad9c238583b66caa841a5bd7f1d9df7f159998e730ced4"}, + {file = "egglog-0.5.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:606b1fcdc4f092399506715792b0c2a61d441330a5952f74e8f81f7376d4867c"}, + {file = "egglog-0.5.1-cp39-none-win_amd64.whl", hash = "sha256:66f91d18f60cea88d8042d2f897e61b4094ad44c6610fc25748840a54c2f8abb"}, + {file = "egglog-0.5.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7624e603336f3b4f7e5efdbd469720e1e3f7b47b54e1d0832b51dabd664575c"}, + {file = "egglog-0.5.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8f1523a7d0c4c6eee96369a42c89b37f052e92ecb6bcc3858bd6b0b08000b65"}, + {file = "egglog-0.5.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6eeb5ac2d11a476b5db16e0732a8455a8f53cebcb2e364b12fbbfb57bb3dec5"}, + {file = "egglog-0.5.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:136acffe72fd618d6d98a7e13d69adf543ede28ba11d8383e9ef15213e4ce005"}, + {file = "egglog-0.5.1.tar.gz", hash = "sha256:e114da890ad4a0ec78d207d38d9757f157e7d605247a3738aebfd475f97a6114"}, +] + +[package.dependencies] +black = "*" +graphviz = "*" +typing-extensions = "*" + +[package.extras] +dev = ["black", "flake8", "isort", "mypy", "pre-commit"] +docs = ["matplotlib", "myst-nb", "nbconvert", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinx-gallery"] +test = ["mypy", "pytest"] + +[[package]] +name = "exceptiongroup" +version = "1.1.3" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.1.3-py3-none-any.whl", hash = "sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3"}, + {file = "exceptiongroup-1.1.3.tar.gz", hash = "sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9"}, +] + +[package.extras] +test = ["pytest (>=6)"] + +[[package]] +name = "execnet" +version = "2.0.2" +description = "execnet: rapid multi-Python deployment" +optional = false +python-versions = ">=3.7" +files = [ + {file = "execnet-2.0.2-py3-none-any.whl", hash = "sha256:88256416ae766bc9e8895c76a87928c0012183da3cc4fc18016e6f050e025f41"}, + {file = "execnet-2.0.2.tar.gz", hash = "sha256:cc59bc4423742fd71ad227122eb0dd44db51efb3dc4095b45ac9a08c770096af"}, +] + +[package.extras] +testing = ["hatch", "pre-commit", "pytest", "tox"] + +[[package]] +name = "graphviz" +version = "0.20.1" +description = "Simple Python interface for Graphviz" +optional = false +python-versions = ">=3.7" +files = [ + {file = "graphviz-0.20.1-py3-none-any.whl", hash = "sha256:587c58a223b51611c0cf461132da386edd896a029524ca61a1462b880bf97977"}, + {file = "graphviz-0.20.1.zip", hash = "sha256:8c58f14adaa3b947daf26c19bc1e98c4e0702cdc31cf99153e6f06904d492bf8"}, +] + +[package.extras] +dev = ["flake8", "pep8-naming", "tox (>=3)", "twine", "wheel"] +docs = ["sphinx (>=5)", "sphinx-autodoc-typehints", "sphinx-rtd-theme"] +test = ["coverage", "mock (>=4)", "pytest (>=7)", "pytest-cov", "pytest-mock (>=3)"] + +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + +[[package]] +name = "packaging" +version = "23.1" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.7" +files = [ + {file = "packaging-23.1-py3-none-any.whl", hash = "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61"}, + {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, +] + +[[package]] +name = "pathspec" +version = "0.11.2" +description = "Utility library for gitignore style pattern matching of file paths." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pathspec-0.11.2-py3-none-any.whl", hash = "sha256:1d6ed233af05e679efb96b1851550ea95bbb64b7c490b0f5aa52996c11e92a20"}, + {file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"}, +] + +[[package]] +name = "platformdirs" +version = "3.10.0" +description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +optional = false +python-versions = ">=3.7" +files = [ + {file = "platformdirs-3.10.0-py3-none-any.whl", hash = "sha256:d7c24979f292f916dc9cbf8648319032f551ea8c49a4c9bf2fb556a02070ec1d"}, + {file = "platformdirs-3.10.0.tar.gz", hash = "sha256:b45696dab2d7cc691a3226759c0d3b00c47c8b6e293d96f6436f733303f77f6d"}, +] + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"] + +[[package]] +name = "pluggy" +version = "1.3.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"}, + {file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] +name = "pytest" +version = "7.4.2" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-7.4.2-py3-none-any.whl", hash = "sha256:1d881c6124e08ff0a1bb75ba3ec0bfd8b5354a01c194ddd5a0a870a48d99b002"}, + {file = "pytest-7.4.2.tar.gz", hash = "sha256:a766259cfab564a2ad52cb1aae1b881a75c3eb7e34ca3779697c23ed47c47069"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "pytest-xdist" +version = "3.3.1" +description = "pytest xdist plugin for distributed testing, most importantly across multiple CPUs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-xdist-3.3.1.tar.gz", hash = "sha256:d5ee0520eb1b7bcca50a60a518ab7a7707992812c578198f8b44fdfac78e8c93"}, + {file = "pytest_xdist-3.3.1-py3-none-any.whl", hash = "sha256:ff9daa7793569e6a68544850fd3927cd257cc03a7ef76c95e86915355e82b5f2"}, +] + +[package.dependencies] +execnet = ">=1.1" +pytest = ">=6.2.0" + +[package.extras] +psutil = ["psutil (>=3.0)"] +setproctitle = ["setproctitle"] +testing = ["filelock"] + +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] + +[[package]] +name = "typing-extensions" +version = "4.8.0" +description = "Backported and Experimental Type Hints for Python 3.8+" +optional = false +python-versions = ">=3.8" +files = [ + {file = "typing_extensions-4.8.0-py3-none-any.whl", hash = "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0"}, + {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"}, +] + +[metadata] +lock-version = "2.0" +python-versions = "^3.8" +content-hash = "1677f7f88c065da95eb9ef9316be3023175bef6cd466f730773de00cddca00fe" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..9754e36 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,56 @@ +[tool.poetry] +name = "egga" +version = "0.1.0" +description = "Symbolic Geometric Algebra with E-Graphs " +authors = ["Robin Kahlow "] +readme = "README.md" +packages = [{ include = "egga" }] +homepage = "https://github.com/RobinKa/egga" +repository = "https://github.com/RobinKa/egga" +license = "MIT" +keywords = [ + "geometric-algebra", + "clifford-algebra", + "simplification", + "equations", + "egraph", + "egglog", + "multi-vector", + "para-vector", + "mathematics", +] +classifiers = [ + "License :: OSI Approved :: MIT License", + "Development Status :: 3 - Alpha", + + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3 :: Only", + + "Intended Audience :: Education", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Physics", + "Topic :: Scientific/Engineering :: Mathematics", + + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", +] + +[tool.poetry.dependencies] +python = "^3.8" +egglog = "^0.5.1" + +[tool.poetry.group.dev.dependencies] +pytest = "^7.4.0" +pytest-xdist = "^3.3.1" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 0000000..0408ac1 --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1 @@ +/renders \ No newline at end of file diff --git a/tests/test_rules.py b/tests/test_rules.py new file mode 100644 index 0000000..5929bbc --- /dev/null +++ b/tests/test_rules.py @@ -0,0 +1,349 @@ +import pytest +from egglog import eq + +from egga.geometric_algebra import GeometricAlgebra +from egga.utils import ( + CheckEqualityOptions, + RunRulesetOptions, + check_equality, + run_ruleset, +) + +ga = GeometricAlgebra([0.0, 1.0, 1.0, 1.0, -1.0]) +E = ga.expr_cls + +e_0, e_1, e_2, e_3, e_4 = ga.basis_vectors + +e_12 = e_1 * e_2 +e_23 = e_2 * e_3 +e_31 = e_3 * e_1 +e_123 = e_1 * e_2 * e_3 + + +@pytest.mark.parametrize( + "equation", + [ + (-e_1 * e_2, e_2 * e_1), + (e_1 * e_1, E.scalar_literal(1.0)), + (-e_1 * e_2 * e_1, e_2), + (~(e_1 * e_2), -e_1 * e_2), + (e_12 * e_1 * ~e_12, -e_1), + (e_23 * e_1 * ~e_23, e_1), + (~(e_12 + e_23), -e_12 - e_23), + ((e_12 + e_23) * e_1, -e_2 + e_123), + (~(e_12 + e_23 + e_31), -e_12 - e_23 - e_31), + ((e_12 + e_23 + e_31) * e_1, -e_2 + e_123 + e_3), + (e_1 * ~(e_12 + e_23 + e_31), -e_2 - e_123 + e_3), + (e_23 ** E.scalar_literal(1.0), e_23), + (e_23 * e_23, E.scalar_literal(-1.0)), + (e_1.equal(E.scalar_literal(1.0) * e_1), E.boolean(True)), + (e_1.not_equal(E.scalar_literal(1.0) * e_1), E.boolean(False)), + (e_1.equal(-E.scalar_literal(1.0) * e_1), E.boolean(False)), + (e_1.not_equal(-E.scalar_literal(1.0) * e_1), E.boolean(True)), + (E.inverse(e_1), e_1), + (E.inverse(e_12), -e_12), + (E.scalar_literal(1.0) / e_12, -e_12), + (E.scalar_literal(5.0) / e_12, -E.scalar_literal(5.0) * e_12), + (e_123 / e_12, e_3), + (e_12 / e_3, e_123), + (e_1 ^ e_2, e_12), + (e_1 ^ e_1, E.scalar_literal(0.0)), + (e_1 | e_2, E.scalar_literal(0.0)), + (e_1 | e_1, E.scalar_literal(1.0)), + (e_0 * e_0, E.scalar_literal(0.0)), + (e_4 * e_4, E.scalar_literal(-1.0)), + (e_1 * e_4 * e_1 * e_4, E.scalar_literal(1.0)), + ( + E.exp(E.scalar_variable("x") * e_23), + E.cos(E.scalar_variable("x")) + e_23 * E.sin(E.scalar_variable("x")), + ), + ( + E.exp(e_23), + E.cos(E.scalar_literal(1.0)) + e_23 * E.sin(E.scalar_literal(1.0)), + ), + (E.exp(e_0 * e_1), E.scalar_literal(1.0) + e_0 * e_1), + (E.exp(e_0 * e_4), E.scalar_literal(1.0) + e_0 * e_4), + ( + E.exp(e_1 * e_4), + E.cosh(E.scalar_literal(1.0)) + e_1 * e_4 * E.sinh(E.scalar_literal(1.0)), + ), + (E.scalar_literal(22.0) * e_1, e_1 * E.scalar_literal(22.0)), + ( + E.scalar(E.sin(E.scalar_literal(22.0))) * e_1, + e_1 * E.scalar(E.sin(E.scalar_literal(22.0))), + ), + ( + E.sin(E.scalar_literal(1.0)) * e_1, + e_1 * E.sin(E.scalar_literal(1.0)), + ), + ((e_1 + e_12) ^ e_2, e_12), + (e_123 ^ e_2, E.scalar_literal(0.0)), + (e_1 ^ e_2, e_123, False), + (E.grade(e_1), E.scalar_literal(1.0)), + (E.grade(e_1), E.scalar_literal(0.0), False), + ( + E.select_grade(e_1 * e_2, E.scalar_literal(2.0)), + e_1 * e_2, + ), + (E.grade(-e_1), E.scalar_literal(1.0)), + (E.grade(E.scalar_literal(-1.0) * e_1), E.scalar_literal(1.0)), + (E.grade(e_1 ^ e_2), E.scalar_literal(2.0)), + (E.grade(e_1 ^ e_2), E.scalar_literal(1.0), False), + (E.grade(e_1 ^ e_2), E.scalar_literal(0.0), False), + (e_23 ^ e_1, e_123), + (E.scalar_literal(1.0), E.scalar_literal(2.0), False), + (e_1 ^ e_2, E.scalar_literal(0.0), False), + ( + E.select_grade(e_12, E.scalar_literal(2.0)), + E.scalar_literal(0.0), + False, + ), + ( + E.select_grade(E.scalar_literal(5.0) * e_1 * e_2, E.scalar_literal(2.0)), + E.scalar_literal(0.0), + False, + ), + ( + E.grade(E.scalar_literal(5.0) * e_1 * e_2), + E.scalar_literal(0.0), + False, + ), + (E.grade(E.scalar_literal(5.0) * e_1), E.scalar_literal(1.0)), + (E.scalar_literal(5.0) * e_1, E.scalar(E.scalar_literal(5.0)) * e_1), + (E.grade(E.scalar_literal(5.0) * e_1 * e_2), E.scalar_literal(2.0)), + (E.grade(e_1 * e_2), E.scalar_literal(2.0)), + (E.grade(e_1 * e_2), E.scalar_literal(0.0), False), + (E.grade(e_1), E.scalar_literal(2.0), False), + ( + E.select_grade(e_1, E.scalar_literal(1.0)), + E.scalar_literal(1.0), + False, + ), + (E.select_grade(e_1, E.scalar_literal(1.0)), e_1), + (e_1, E.scalar_literal(1.0), False), + (e_123, E.scalar_literal(0.0), False), + ( + (e_1 * E.scalar_literal(5.0)) ^ (e_2 * e_3), + e_123 * E.scalar_literal(5.0), + ), + (e_1 ^ (e_2 * e_3), e_123), + (e_1 ^ (e_2 ^ e_3), e_123), + (e_1 | e_23, E.scalar_literal(0.0)), + (e_1 * (e_2 * e_3), (e_2 * e_3) * e_1), + (e_1 * e_2 * e_3 - e_2 * e_3 * e_1, E.scalar_literal(0.0)), + (e_1 * e_2 * e_3, e_2 * e_3 * e_1), + ((E.scalar_literal(5.0) * e_12) ^ e_1, E.scalar_literal(0.0)), + ( + e_2 * (E.scalar_literal(-5.0) ^ e_1), + (e_1 * e_2) * E.scalar_literal(5.0), + ), + ( + (e_2 * e_1) * (E.scalar_literal(-5.0) ^ e_1), + E.scalar_literal(-5.0) * e_2, + ), + ((E.scalar_literal(5.0) * e_1) ^ e_1, E.scalar_literal(0.0)), + (E.abs(E.grade(e_1) - E.grade(e_2)), E.scalar_literal(0.0)), + ( + E.select_grade(e_1 * e_2, E.scalar_literal(0.0)), + E.scalar_literal(0.0), + ), + ((e_1 ^ e_2) + (e_1 | e_2), e_12), + (e_12, e_1 ^ e_2), + (e_12, (e_1 ^ e_2) + (e_1 | e_2)), + (E.grade(e_1 + e_2), E.scalar_literal(1.0)), + (E.mix_grades(E.grade(e_1), E.grade(e_2)), E.scalar_literal(1.0)), + (E.grade(e_1 + e_2 + e_3), E.scalar_literal(1.0)), + ( + E.grade(e_1 + e_23), + E.mix_grades(E.scalar_literal(1.0), E.scalar_literal(2.0)), + ), + ( + E.grade(e_1 + e_23), + E.mix_grades(E.scalar_literal(2.0), E.scalar_literal(1.0)), + ), + (E.grade(E.scalar_literal(3.0)), E.scalar_literal(0.0)), + (E.select_grade(e_1, E.scalar_literal(2.0)), E.scalar_literal(0.0)), + (E.select_grade(e_1 + e_23, E.scalar_literal(1.0)), e_1), + (E.select_grade(e_1 + e_23, E.scalar_literal(2.0)), e_23), + ( + E.select_grade(E.scalar_literal(5.0) + e_1 + e_23, E.scalar_literal(0.0)), + E.scalar_literal(5.0), + ), + (e_1 ^ e_2 ^ e_3, e_1 * e_2 * e_3), + ( + e_1 ^ e_2 ^ (E.scalar_literal(5.0) * e_3), + E.scalar_literal(5.0) * e_1 * e_2 * e_3, + ), + (E.abs(-E.scalar_literal(2.0)), E.scalar_literal(2.0)), + (E.abs(E.grade(e_1) - E.grade(e_23)), E.scalar_literal(1.0)), + # (E.inverse(e_0), e_0, False), + (e_1 ^ e_2, E.select_grade(e_1 * e_2, E.grade(e_1) + E.grade(e_2))), + (e_1 ^ e_2, E.select_grade(e_1 * e_2, E.scalar_literal(2.0))), + ( + E.select_grade(e_1 * e_2, E.scalar_literal(0.0)), + e_1 * e_2, + False, + ), + ( + E.diff(E.scalar_literal(5.0), E.variable("x")), + E.scalar_literal(0.0), + ), + ( + E.diff(E.variable("x"), E.scalar_literal(5.0)), + E.scalar_literal(0.0), + False, + ), + ( + E.diff(E.scalar_literal(5.0), E.scalar_literal(5.0)), + E.scalar_literal(0.0), + ), + (E.diff(E.variable("x"), E.variable("x")), E.scalar_literal(1.0)), + # (E.diff(E.variable("x"), E.variable("y")), E.scalar_literal(0.0)), + ( + E.diff(E.variable("x") + E.variable("y"), E.variable("z")), + E.diff(E.variable("x"), E.variable("z")) + + E.diff(E.variable("y"), E.variable("z")), + ), + ( + E.diff(E.variable("x") * E.variable("y"), E.variable("z")), + E.diff(E.variable("x"), E.variable("z")) * E.variable("y") + + E.variable("x") * E.diff(E.variable("y"), E.variable("z")), + ), + # ( + # E.diff( + # E.variable("x") * e_1 + E.variable("y") * e_2, E.variable("x") + # ), + # e_1, + # ), + # ( + # E.diff( + # e_1 * E.variable("x") + e_2 * E.variable("y"), E.variable("x") + # ), + # e_1, + # ), + (E.diff(e_12 * E.variable("x"), E.variable("x")), e_12), + ( + E.diff( + e_1 * E.variable("x") * e_1, + E.variable("x"), + ), + E.scalar_literal(1.0), + ), + ( + E.diff(E.sin(E.variable("x")), E.variable("x")), + E.cos(E.variable("x")), + ), + ( + E.diff(E.cos(E.variable("x")), E.variable("x")), + -E.sin(E.variable("x")), + ), + ( + E.diff(E.sinh(E.variable("x")), E.variable("x")), + E.cosh(E.variable("x")), + ), + ( + E.diff(E.cosh(E.variable("x")), E.variable("x")), + E.sinh(E.variable("x")), + ), + # ( + # E.diff( + # E.rotor(e_12, E.scalar_variable("phi")) + # * e_1 + # * ~E.rotor(e_12, E.scalar_variable("phi")), + # E.scalar_variable("phi"), + # ), + # # R' ~R + # E.diff( + # E.rotor(e_12, E.scalar_variable("phi")), + # E.scalar_variable("phi"), + # ) + # * ~E.rotor(e_12, E.scalar_variable("phi")), + # ), + ( + E.rotor(e_12, E.scalar_variable("phi")) + * ~E.rotor(e_12, E.scalar_variable("phi")), + E.scalar_literal(1.0), + ), + ( + E.sin(E.variable("x")) * E.sin(E.variable("x")), + E.scalar_literal(0.5) + * (E.scalar_literal(1.0) - E.cos(E.scalar_literal(2.0) * E.variable("x"))), + ), + ( + E.rotor(e_12, E.scalar_variable("phi")) + * e_1 + * ~E.rotor(e_12, E.scalar_variable("phi")), + E.cos(E.scalar_variable("phi")) * e_1 + + E.sin(E.scalar_variable("phi")) * e_2, + ), + ( + E.rotor(e_12, E.scalar_variable("phi")) + * (E.scalar_variable("x") * e_1 + E.scalar_variable("y") * e_2) + * ~E.rotor(e_12, E.scalar_variable("phi")), + e_1 + * ( + E.scalar_variable("x") * E.cos(E.scalar_variable("phi")) + - E.scalar_variable("y") * E.sin(E.scalar_variable("phi")) + ) + + e_2 + * ( + E.scalar_variable("x") * E.sin(E.scalar_variable("phi")) + + E.scalar_variable("y") * E.cos(E.scalar_variable("phi")) + ), + ), + (e_1 + e_1, E.scalar_literal(2.0) * e_1), + (E.scalar_literal(1.5) * e_1 + e_1, E.scalar_literal(2.5) * e_1), + (e_1 + E.scalar_literal(1.5) * e_1, E.scalar_literal(2.5) * e_1), + ( + E.scalar_literal(1.5) * e_1 + E.scalar_literal(5.5) * e_1, + E.scalar_literal(7.0) * e_1, + ), + ( + E.scalar_literal(1.5) * e_1 - E.scalar_literal(0.1) * e_1, + E.scalar_literal(1.4) * e_1, + ), + ( + E.scalar_literal(1.5) * e_1 + E.scalar_literal(-0.1) * e_1, + E.scalar_literal(1.4) * e_1, + ), + ], +) +def test_rules(equation): + ga.egraph.push() + + caught = None + + try: + should_be_equal = True + eq_len = len(equation) + if eq_len == 2: + lhs, rhs = equation + elif eq_len == 3: + lhs, rhs, should_be_equal = equation + else: + raise ValueError("Invalid number of values in equation tuple") + + stop_cond = None + if should_be_equal: + stop_cond = eq(lhs).to(rhs) + + ga.egraph.register(lhs, rhs) + run_ruleset(ga, options=RunRulesetOptions(limit=100, until=stop_cond)) + + # Check if LHS is (not) equal to RHS, don't run any additional steps + check_passes = check_equality( + ga, lhs, rhs, options=CheckEqualityOptions(limit=0, equal=should_be_equal) + ) + + assert check_passes + + # Check against some basic contradictions + ga.egraph.check_fail(eq(E.scalar_literal(0.0)).to(E.scalar_literal(1.0))) + ga.egraph.check_fail(eq(E.scalar_literal(1.0)).to(E.scalar_literal(2.0))) + except Exception as e: + caught = e + + ga.egraph.pop() + + if caught is not None: + raise caught diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..7d4f194 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,33 @@ +import pytest + +from egga.geometric_algebra import GeometricAlgebra +from egga.utils import CheckEqualityOptions, check_equality, simplify + + +@pytest.fixture +def ga() -> GeometricAlgebra: + return GeometricAlgebra([1.0, 1.0]) + + +@pytest.fixture +def ga_eq_solve() -> GeometricAlgebra: + return GeometricAlgebra([1.0, 1.0], eq_solve=True) + + +def test_simplify(ga: GeometricAlgebra): + simplified = simplify(ga, ga.basis_vectors[0] * ga.basis_vectors[0]) + expected = ga.expr_cls.scalar_literal(1.0) + assert str(simplified) == str(expected) + + +def test_check_equality(ga: GeometricAlgebra): + lhs = ga.basis_vectors[0] * ga.basis_vectors[0] + rhs = ga.expr_cls.scalar_literal(1.0) + not_rhs = ga.expr_cls.scalar_literal(2.0) + + assert check_equality(ga, lhs, rhs) + assert not check_equality(ga, lhs, not_rhs) + + # Check negated too + assert not check_equality(ga, lhs, rhs, options=CheckEqualityOptions(equal=False)) + assert check_equality(ga, lhs, not_rhs, options=CheckEqualityOptions(equal=False))