This repository contains the code accompanying my master's thesis, an interdisciplinary work at the intersection of data science, deep learning and theoretical molecular chemistry. You can find the PDF of the thesis in the repository or at the official TU Vienna thesis repository.
Master's thesis: Jraph Attention neTworks (JAT), a deep learning architecture to predict the potential energy and forces of organic molecules and ionic liquids.
- predicts the potential energy and atomic forces of molecules
- operates within the message passing neural network (MPNN) framework [2]
- extends the NeuralIL [1] architecture and implementation
- adapts Graph Attention Networks (GAT) [3] to replace the fingerprint features
The JAT architecture takes as input the 3N Cartesian coordinates and N types (atomic species) for a molecule with N atoms. With these, the model
- generates a sparse molecular graph,
- iteratively refines the type embeddings h^0 with multiple message passing layers, using
  - the dynamic linear attention function (GATv2) [4] as the message function (masked multi-headed self-attention),
  - a weighted sum of features and a skip connection as the update function,
  - a pyramidal regression head as the readout function,
- and finally takes a sum over all atomic contributions to obtain the potential energy.
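To make this pipeline concrete, here is a minimal, self-contained JAX sketch of the message-passing-and-readout flow. The toy layer, parameter names and sizes are illustrative assumptions and do not reproduce the actual JatCore implementation (which uses GATv2 attention, masking and a pyramidal readout head); the sparse edge list is assumed to come from the graph generator, sketched under src > jat_model.py > GraphGenerator below.

```python
import jax
import jax.numpy as jnp

def toy_layer(h, senders, receivers):
    # Placeholder for a JAT message-passing layer: a plain mean over neighbour
    # features plus a skip connection (the real layer uses GATv2 attention).
    n = h.shape[0]
    summed = jax.ops.segment_sum(h[senders], receivers, num_segments=n)
    counts = jax.ops.segment_sum(jnp.ones(senders.shape[0]), receivers, num_segments=n)
    return h + summed / jnp.maximum(counts, 1.0)[:, None]

def potential_energy(params, senders, receivers, types, n_layers=4):
    # senders/receivers form the sparse edge list produced by the graph generator.
    h = params["embeddings"][types]      # initial type embeddings h^0
    for _ in range(n_layers):            # T rounds of message passing
        h = toy_layer(h, senders, receivers)
    e_atom = h @ params["readout"]       # toy readout: per-atom energy contribution
    return jnp.sum(e_atom)               # sum over all atomic contributions

# Toy usage: 3 atoms of 2 species, feature width 8, hand-written edge list.
params = {"embeddings": jnp.ones((2, 8)), "readout": jnp.ones((8,))}
senders, receivers = jnp.array([0, 1, 1, 2]), jnp.array([1, 0, 2, 1])
types = jnp.array([0, 1, 1])
print(potential_energy(params, senders, receivers, types))
```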
The architecture uses message passing neural networks (MPNN) [2] with an attentional update function (dynamic linear attention, GATv2) [4], adapting Graph Attention Networks (GAT) [3] to the domain of computational chemistry. The name JAT (Jraph Attention neTworks) derives from adapting Graph Attention Networks in JAX, building upon the Jraph library.
The JAT code and architecture were developed during the master's thesis at TU Vienna with the department of Theoretical Materials Chemistry under the supervision of Dr. Jesús Carrete Montaña. The NeuralIL [1] implementation serves as the baseline, and parts of its code are reused to train the model.
In this thesis, I've
- built a deep learning architecture by adapting a state-of-the-art DL approach (Graph Attention Networks) to the domain of computational chemistry
- adapted the NeuralIL codebase (which uses conventional and computationally expensive fingerprint features and a shallow neural network)
- performed an extensive literature review surveying the state-of-the-art in multiple fields (computational chemistry, graph-based/geometric deep learning, attention & Transformer-based architectures) to extract the most promising approaches
- optimized, debugged and trained the architecture on a small dataset of ionic liquids
- scaled the architecture to the very large ANI-1 dataset
- all while optimizing the architecture for efficiency, achieving a 4x speedup over the supervisors' baseline at comparable accuracy.
Also available as a poster here!
- Domain: Computational chemistry
- Goal: Predict the energy of molecules
- Task: Learn the potential energy surface (PES)
- atomic positions → energy mapping
- high-dimensional task (3N - 6 degrees of freedom for N atoms)
- Problem: "Exponential barrier" of electronic contribution
- Solution: Approximate the Schrödinger Equation efficient Deep Neural Networks
- Obtain 3Natoms forces as gradient of energy w.r.t. positions
- Use atomic forces to integrate Newton's equations of motion
- requires a very small timestep Δt ~ fs (10^-15 s)
- 1M+ steps to watch something interesting happen...
- thus 1M+ model evaluations, making efficiency crucial
- model & its gradient for value (energy) and grad (forces)
- Adapt NeuralIL implementation [1]
- replace spherical Bessel descriptors with a Message Passing Neural Network [2]
- using Graph Attention Networks [3]
- on a sparse molecular graph
- replace computationally expensive descriptors
- implemented in JAX leveraging a GPU
- just-in-time (JIT) compilation
- sparse graph for sparse self-attention: O(N^2) for a fully connected graph → O(N · max(N_neighbours))
- masks for edges & atoms to keep shapes static for JIT (see the sketch after this list)
- invariant to rotation, translation and permutation
- additive ansatz: sum of atomic contributions
- trains on energy or forces
- locality for MD parallelization
- smooth nonlinearity (Swish-1) and loss function (log-cosh)
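Since jax.jit requires static array shapes, the variable-length edge list is padded to a fixed maximum and a boolean mask marks the real entries. A minimal sketch of this idea, in which the maximum size, helper names and the plain neighbour sum are illustrative assumptions rather than the actual JAT code:

```python
import jax
import jax.numpy as jnp

MAX_EDGES = 64  # static upper bound chosen ahead of time (illustrative)

def pad_edges(senders, receivers, max_edges=MAX_EDGES):
    # Pad the variable-length edge list to a static size and return a mask
    # that marks the real (non-padding) edges.
    n_real = senders.shape[0]
    senders = jnp.pad(senders, (0, max_edges - n_real))
    receivers = jnp.pad(receivers, (0, max_edges - n_real))
    mask = jnp.arange(max_edges) < n_real
    return senders, receivers, mask

@jax.jit
def masked_neighbour_sum(h, senders, receivers, mask):
    # Padded edges contribute nothing because their messages are zeroed out.
    messages = jnp.where(mask[:, None], h[senders], 0.0)
    return jax.ops.segment_sum(messages, receivers, num_segments=h.shape[0])

h = jnp.ones((5, 8))
s, r, m = pad_edges(jnp.array([0, 1, 2]), jnp.array([1, 2, 0]))
print(masked_neighbour_sum(h, s, r, m).shape)  # (5, 8), compiled once per static shape
```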
src > jat_model.py > JatCore
src > jat_model.py > JatModel
Visualization of the entire JAT architecture. Using the positions and species (types) as input, the graph generator generates a sparse edge list, with which T message passing steps (here 4) using an attentional update function are performed. The features are iteratively refined over these steps; the readout head then transforms the final features into atomic energy contributions, which are summed to obtain the potential energy.
JatModel is a wrapper around the JatCore model to calculate the potential energy (model.calc_potential_energy) or atomic forces (model.calc_forces).
src > jat_model.py > GraphGenerator
Visualization of the graph generator component of the JAT architecture. Using the Cartesian coordinates (positions) of all atoms, the pairwise distance matrix is calculated using the Euclidean norm, from which the sparse edge list of neighbouring atoms is generated.
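A minimal sketch of this computation; the cutoff value and function name are illustrative assumptions, not the actual GraphGenerator code (which additionally pads and masks the edge list for JIT):

```python
import jax.numpy as jnp

def generate_graph(positions, cutoff=4.0):
    # Pairwise Euclidean distance matrix from the Cartesian coordinates.
    dist = jnp.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)
    # Keep only pairs within the cutoff, excluding self-interactions,
    # and return them as a sparse edge list (senders, receivers).
    within_cutoff = (dist < cutoff) & ~jnp.eye(positions.shape[0], dtype=bool)
    return jnp.nonzero(within_cutoff)

positions = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [8.0, 0.0, 0.0]])
senders, receivers = generate_graph(positions)
print(senders, receivers)  # only the pair within the cutoff appears
```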
src > jat_model.py > JatLayer
Visualization of a single JAT layer, which performs a single round of message passing to update the node feature vectors. The features are updated with an attention-weighted sum of the neighbours' messages and a skip connection.
src > jat_model.py > JatLayer.attention()
Visualization of the attention mechanism of the JAT architecture. For every edge, attention coefficients are computed with the dynamic linear attention function (GATv2) [4] and used to weight the neighbouring atoms' features.
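A sketch of what one such attention step could look like over the sparse edge list. The single head, the absence of masking and the weight shapes are simplifications and assumptions; the actual JatLayer.attention() uses masked multi-headed self-attention:

```python
import jax
import jax.numpy as jnp

def gatv2_attention(h, senders, receivers, w_src, w_dst, a):
    # GATv2-style "dynamic" attention: score each edge after the nonlinearity,
    # then normalize the scores over each receiver's neighbourhood.
    n = h.shape[0]
    z = h @ w_src
    scores = jnp.dot(jax.nn.leaky_relu(z[senders] + (h @ w_dst)[receivers]), a)
    scores = scores - jax.ops.segment_max(scores, receivers, num_segments=n)[receivers]
    alpha = jnp.exp(scores)
    alpha = alpha / jax.ops.segment_sum(alpha, receivers, num_segments=n)[receivers]
    # Attention-weighted sum of the neighbours' features plus a skip connection.
    messages = jax.ops.segment_sum(alpha[:, None] * z[senders], receivers, num_segments=n)
    return h + messages

n_atoms, width = 4, 8
h = jnp.ones((n_atoms, width))
senders, receivers = jnp.array([0, 1, 2, 3, 1]), jnp.array([1, 0, 1, 2, 3])
print(gatv2_attention(h, senders, receivers, jnp.eye(width), jnp.eye(width), jnp.ones(width)).shape)
```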
Clone the repository, create an environment using conda and install dependencies using pip.
git clone https://github.com/stefanhoedl/JAT_potential
cd JAT_potential
conda create -n JAT
conda activate JAT
pip install .
After successful installation, run scripts with python3 ean/train_JAT_EAN.py.
If you have a GPU, install a matching CUDA-supported jaxlib version:
Download jaxlib 0.3.10+cuda11.cudnn805 here to match the installed jax, CUDA & cuDNN versions
ean > train_JAT_EAN.py - minimal training script, val only
ean > full_JAT_training.py - full script with val+test, logging, loading
ean > load_run_JAT_EAN.py - script to load trained model & predict
ean > configurations.json - EAN dataset
ean > models > JAT_EAN15_ep3K.pickle - Model weights after 3K training epochs
Script to train JAT on EAN dataset
- Ionic liquid: anion-cation pairs (salt) in liquid phase at room temperature
- 15 EAN pairs → 225 atoms
- Copyright for the EAN dataset remains with the NeuralIL authors [1]
- sampled from OPLS-AA MD trajectory
- Reference energy & forces from DFT
- 741 configurations
- training on atomic forces
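Training on forces with the log-cosh loss mentioned above might look roughly like this sketch; predict_forces and the loss signature are illustrative assumptions, not the training code in this repository:

```python
import jax
import jax.numpy as jnp

def log_cosh(x):
    # Numerically stable log(cosh(x)) = |x| + log(1 + exp(-2|x|)) - log(2).
    ax = jnp.abs(x)
    return ax + jnp.log1p(jnp.exp(-2.0 * ax)) - jnp.log(2.0)

def force_loss(params, positions, types, target_forces, predict_forces):
    # predict_forces is assumed to return an (N, 3) array of predicted forces.
    predicted = predict_forces(params, positions, types)
    return jnp.mean(log_cosh(predicted - target_forces))

# Gradient of the loss w.r.t. the parameters, ready for a gradient-based optimizer.
loss_grad = jax.grad(force_loss)
```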
Download ANI-1 dataset here and unzip to data directory
ani1 > train_JAT_ANI1.py - full ANI-1 training script
ani1 > ANI-1_release - ANI-1 data directory
> ani_gdb_s01.h5 - Data files, 1-heavy atom subset
Script to train JAT on ANI-1 dataset
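To inspect the raw data before training, the ANI-1 release ships a small pyanitools helper for reading the HDF5 files; the snippet below assumes that helper (and its species/coordinates/energies keys) is available on the PYTHONPATH, so treat it as a sketch rather than part of this repository:

```python
# Sketch: iterate over the 1-heavy-atom subset with the pyanitools helper
# distributed alongside the ANI-1 dataset (key names are assumptions).
from pyanitools import anidataloader

loader = anidataloader('ani1/ANI-1_release/ani_gdb_s01.h5')
for molecule in loader:
    species = molecule['species']          # element symbols, length N
    coordinates = molecule['coordinates']  # (n_conformations, N, 3) positions
    energies = molecule['energies']        # (n_conformations,) reference energies
    print(len(species), coordinates.shape, energies.shape)
loader.cleanup()                           # close the underlying HDF5 store
```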
Clone https://github.com/AIRI-Institute/nablaDFT into ./nablaDFT
Download hamiltonian_databases:
"dataset_train_2k": "https://sc.link/2ZAA"
"dataset_test_conformations_2k": "https://sc.link/0ZyN"
Load the data (positions, types, energy, forces) from the database, pad (mask) it up to a static maximum and train the JAT model on the dataset.
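A sketch of this padding step, analogous to the edge padding sketched earlier; the maximum atom count and field layout are illustrative assumptions:

```python
import numpy as np

MAX_ATOMS = 64  # static maximum for the dataset (illustrative)

def pad_molecule(positions, types, forces, max_atoms=MAX_ATOMS):
    # positions: (N, 3), types: (N,), forces: (N, 3) with N <= max_atoms.
    n = positions.shape[0]
    pad = max_atoms - n
    padded_positions = np.pad(positions, ((0, pad), (0, 0)))
    padded_types = np.pad(types, (0, pad))
    padded_forces = np.pad(forces, ((0, pad), (0, 0)))
    atom_mask = np.arange(max_atoms) < n   # True for real atoms, False for padding
    return padded_positions, padded_types, padded_forces, atom_mask
```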
nablaDFT > train_nabla.py - training script for JAT
nablaDFT > train_nabla - script to only load & prep dataset
nablaDFT > data > dataset_train_2k
> dataset_test_2k_conformers
[1] Hadrián Montes-Campos, Jesús Carrete, Sebastian Bichelmaier, Luis M Varela, and Georg KH Madsen. A differentiable neural-network force field for ionic liquids. Journal of chemical information and modeling, 62(1):88–101, 2021
[2] Justin Gilmer, Samuel S Schoenholz, Patrick F Riley, Oriol Vinyals, and George E Dahl. Neural message passing for quantum chemistry. In International conference on machine learning, pages 1263–1272. PMLR, 2017.
[3] Petar Velickovic, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Lio, and Yoshua Bengio. Graph attention networks. stat, 1050:20, 2017.
[4] Shaked Brody, Uri Alon, and Eran Yahav. How attentive are graph attention networks? arXiv preprint arXiv:2105.14491, 2021.
If you find the thesis or code useful, please cite the following:
@mastersthesis{Hoedl2022,
title={Sparse graph attention networks as efficient ionic liquid potentials},
author={H{\"o}dl, Stefan},
year={2022},
school={Technische Universit{\"a}t Wien},
doi={10.34726/hss.2022.98004},
url={https://github.com/stefanhoedl/JAT_potential},
}
NeuralIL code and Paper