
Annotated implementations of equivariant graph neural networks in Jax: EGNN, SEGNN, NequIP.



smsharma/eqnn-jax


$E(3)$ Equivariant Graph Neural Networks in Jax

License: CC BY 4.0

Implementation of $E(3)$ equivariant graph neural networks in Jax.

Models

The following equivariant models are implemented: EGNN, SEGNN, and NequIP.

Additionally, non-equivariant models are implemented, including a standard GNN.
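To make "E(3) equivariant" concrete, here is a minimal NumPy sketch of a single EGNN-style layer on a fully connected graph. It is not the implementation in this repository: `init_mlp`, `mlp`, and `egnn_layer` are hypothetical helpers, the MLP sizes are arbitrary, and the feature update uses a residual connection (a common variant). Messages depend only on node features and squared distances, and positions move along relative vectors, so rotating, reflecting, or translating the input transforms the output positions the same way and leaves the features unchanged.

```python
import numpy as np


def init_mlp(rng, sizes):
    """Random weights for a small MLP (hypothetical helper for this sketch)."""
    return [(rng.normal(size=(m, n)) / np.sqrt(m), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]


def mlp(params, x):
    """Apply the MLP with SiLU activations between layers and a linear output."""
    for i, (w, b) in enumerate(params):
        x = x @ w + b
        if i < len(params) - 1:
            x = x / (1.0 + np.exp(-x))  # SiLU: x * sigmoid(x)
    return x


def egnn_layer(params, h, x):
    """One EGNN-style layer on a fully connected graph (hypothetical helper).

    h: (N, F) invariant node features; x: (N, 3) node positions.
    """
    n = h.shape[0]
    diff = x[:, None, :] - x[None, :, :]               # (N, N, 3): x_i - x_j
    dist2 = np.sum(diff**2, axis=-1, keepdims=True)    # (N, N, 1): invariant
    h_i = np.repeat(h[:, None, :], n, axis=1)          # (N, N, F)
    h_j = np.repeat(h[None, :, :], n, axis=0)          # (N, N, F)
    mask = 1.0 - np.eye(n)[..., None]                  # drop self-edges
    m = mlp(params["edge"], np.concatenate([h_i, h_j, dist2], axis=-1)) * mask
    # Positions move along relative vectors scaled by invariant weights,
    # so the update commutes with rotations, reflections and translations.
    x_new = x + np.sum(diff * mlp(params["coord"], m) * mask, axis=1) / (n - 1)
    # Features are updated from invariant messages only (residual variant).
    h_new = h + mlp(params["node"], np.concatenate([h, m.sum(axis=1)], axis=-1))
    return h_new, x_new


rng = np.random.default_rng(0)
F, M = 4, 8
params = {
    "edge": init_mlp(rng, [2 * F + 1, 16, M]),
    "coord": init_mlp(rng, [M, 16, 1]),
    "node": init_mlp(rng, [F + M, 16, F]),
}
h = rng.normal(size=(5, F))
x = rng.normal(size=(5, 3))

# Random orthogonal matrix (rotation or reflection) plus a translation.
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
t = rng.normal(size=3)

h1, x1 = egnn_layer(params, h, x)
h2, x2 = egnn_layer(params, h, x @ Q.T + t)
print(np.allclose(h1, h2, atol=1e-6), np.allclose(x1 @ Q.T + t, x2, atol=1e-6))
```

The same code ports almost line for line to `jax.numpy`. SEGNN and NequIP instead carry higher-order (spherical harmonic) features, which this sketch does not attempt to show.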

Requirements and tests

To install requirements:

pip install -r requirements.txt

To run the tests (which check equivariance and periodic boundary conditions):

cd tests
pytest .
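As an illustration of the two properties these tests target, the sketch below checks them on pairwise displacement vectors, which is what these models consume. It is not the repository's test suite. It assumes a cubic periodic box of side `L` and the minimum-image convention, and `minimum_image` and `pairwise_displacements` are hypothetical helpers. Rotations are checked without the box, since a cubic periodic box is only symmetric under a discrete set of rotations.

```python
import numpy as np


def minimum_image(dx, box_size):
    """Wrap displacements into [-L/2, L/2] for a cubic periodic box of side L."""
    return dx - box_size * np.round(dx / box_size)


def pairwise_displacements(x, box_size=None):
    """All x_i - x_j, optionally using the minimum-image convention."""
    dx = x[:, None, :] - x[None, :, :]
    if box_size is not None:
        dx = minimum_image(dx, box_size)
    return dx


rng = np.random.default_rng(42)
L = 1.0
x = rng.uniform(0, L, size=(10, 3))
d_ref = pairwise_displacements(x, L)

# Periodic boundary conditions: shifting particles by whole box lengths, or
# translating everything and wrapping back into the box, must not change the
# minimum-image displacements.
shift = rng.integers(-2, 3, size=(10, 3)) * L
t = rng.uniform(0, L, size=3)
d_shift = pairwise_displacements(x + shift, L)
d_trans = pairwise_displacements((x + t) % L, L)

# Equivariance (no box): rotating positions rotates every displacement vector.
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
d_rot = pairwise_displacements(x @ Q.T)

print(np.allclose(d_shift, d_ref), np.allclose(d_trans, d_ref),
      np.allclose(d_rot, pairwise_displacements(x) @ Q.T))
```

A full model-level test applies the same idea one step later: transform the inputs, run the network, and compare against the correspondingly transformed outputs.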

Basic usage and examples

See notebooks/examples.ipynb for example usage of GNN, SEGNN, NequIP, and EGNN.

Cosmological benchmark

The cosmological benchmarking dataset, available in TFRecord format, can be downloaded from Zenodo under the DOI 10.5281/zenodo.11479419. To download the dataset into benchmarks/galaxies/quijote_records, run:

bash benchmarks/galaxies/download_tfrecords.sh

To run the graph-level task:

python benchmarks/galaxies/train_cosmology.py

To run the node-level task:

python benchmarks/galaxies/train_velocities.py
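The two benchmark tasks differ mainly in the readout applied after the message-passing layers. The sketch below assumes, judging only from the script names, that the graph-level task predicts one target vector per point cloud and the node-level task predicts one vector per node. `graph_readout` and `node_readout` are hypothetical helpers, the target sizes are illustrative, and random features stand in for GNN embeddings. Mean pooling makes the graph-level output independent of node ordering, while the node-level output is simply permuted along with the nodes.

```python
import numpy as np


def graph_readout(node_feats, w, b):
    """Graph-level head: mean-pool node features, then a linear map (one vector per graph)."""
    return node_feats.mean(axis=0) @ w + b


def node_readout(node_feats, w, b):
    """Node-level head: the same linear map applied independently to every node."""
    return node_feats @ w + b


rng = np.random.default_rng(0)
n_nodes, n_feats = 100, 16
h = rng.normal(size=(n_nodes, n_feats))  # stand-in for GNN output embeddings

w_g, b_g = rng.normal(size=(n_feats, 2)), np.zeros(2)  # illustrative: 2 graph-level targets
w_n, b_n = rng.normal(size=(n_feats, 3)), np.zeros(3)  # illustrative: a 3-vector per node

y_graph = graph_readout(h, w_g, b_g)  # shape (2,)
y_nodes = node_readout(h, w_n, b_n)   # shape (100, 3)

perm = rng.permutation(n_nodes)
print(np.allclose(graph_readout(h[perm], w_g, b_g), y_graph),
      np.allclose(node_readout(h[perm], w_n, b_n), y_nodes[perm]))
```

A linear head on invariant features gives rotation-invariant outputs. For vector-valued node targets that should rotate with the input, an equivariant model would instead produce vector outputs (for example, EGNN's coordinate-style updates).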

Attribution

See CITATION.cff for citation information. The implementation of SEGNN was partially inspired by segnn-jax.