Commit d45fd2c: Initial version 1.0

hoelzerC committed Aug 21, 2024
Showing 39 changed files with 2,910 additions and 0 deletions.
408 changes: 408 additions & 0 deletions LICENSE


50 changes: 50 additions & 0 deletions README.md
@@ -0,0 +1,50 @@
# ConfRank

[![Python](https://img.shields.io/badge/python-3.11.5-blue.svg)](https://www.python.org)
[![code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)

This is the official repository of the `ConfRank` project developed by the [Grimme](https://www.chemie.uni-bonn.de/grimme) and [Fraunhofer SCAI-VMD](https://www.scai.fraunhofer.de/en/business-research-areas/virtual-material-design.html) groups in Bonn.


<div align="center">
<img src="./assets/logo.png" alt="ConfRank" width="600">
</div>


# Software Setup

You can create the environment using the following command:

```bash
bash setup_environment.sh
```

To activate the conda environment, simply run:

```bash
conda activate confrank
```
The current setup has been tested with Python 3.11.5 and CUDA 11.8.
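
As a quick sanity check of the installed environment, the following minimal Python snippet (an illustration, not part of the repository) prints the relevant versions:

```python
import torch

print(torch.__version__)          # expected: 2.1.0 (per setup_environment.sh)
print(torch.version.cuda)         # expected: 11.8
print(torch.cuda.is_available())  # True on a machine with a CUDA-capable GPU
```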


# Data

The data is available at: [https://zenodo.org/records/13354132](https://zenodo.org/records/13354132)
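
For illustration, a minimal loading sketch using the `ConfRankDataset` class from this repository (the file name follows `scripts/train.py`; the local path and the use of default keyword arguments are assumptions):

```python
from src.data import ConfRankDataset

# hypothetical path: adjust to wherever the Zenodo files were downloaded
dataset = ConfRankDataset(path_to_hdf5="data/confrank_test.h5")
print(len(dataset), dataset[0])
```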


# Citations

When using or referring to the `ConfRank` project, please cite:
- **tbd**


# License

[![CC BY NC 4.0][cc-by-nc-image]][cc-by-nc]

This work is licensed under a
[Creative Commons Attribution-NonCommercial 4.0 International License][cc-by-nc].


[cc-by-nc]: http://creativecommons.org/licenses/by-nc/4.0/
[cc-by-nc-image]: https://i.creativecommons.org/l/by-nc/4.0/88x31.png
Binary file added assets/logo.png
316 changes: 316 additions & 0 deletions scripts/train.py
@@ -0,0 +1,316 @@
"""
Example script for training models in either pointwise or pairwise fashion
"""

import sys
import os

sys.path.append("../")

import torch
import mlflow
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import MLFlowLogger
from torch_geometric.loader import DataLoader
from argparse import ArgumentParser

from src.training.lightning import LightningWrapper
from src.models.DimeNetPP import DimeNetPP
from src.models.SchNet import SchNet
from src.models.MACE import MACE
from src.data import ConfRankDataset, PairDataset
from src.transform import (
Scale,
Rename,
RadiusGraph,
PipelineTransform,
)
from src.util.deployment import save_model

# parse command line inputs
parser = ArgumentParser()
parser.add_argument(
"--data_dir",
type=str,
required=True,
help="Absolute path to directory with ConfRank dataset.",
)
parser.add_argument(
"--model",
type=str,
required=True,
choices=["dimenet", "mace", "schnet", "gemnet-T"],
)
parser.add_argument("--cutoff", type=float, required=True)
parser.add_argument(
"--pairwise",
type=lambda x: x.lower() == "true",
required=False,
default=False,
choices=[True, False],
)

args = parser.parse_args()
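
# example invocation (illustrative values; run from within scripts/ so that
# the sys.path.append("../") above makes the src package importable):
#   python train.py --data_dir /path/to/confrank_data --model schnet --cutoff 5.0 --pairwise true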

mlflow.set_experiment(experiment_name="Train models")
mlflow.pytorch.autolog()

# compensate for the fact that pairwise training results in more model updates per epoch by making some values depend on args.pairwise
training_hyperparams = {
"max_epochs": 100 if args.pairwise else 1000,
"lr": 1e-3,
"weight_decay": 1e-8,
"batch_size": 20 if args.pairwise else 40,
"stopping_patience": 5 if args.pairwise else 15,
"decay_patience": 3 if args.pairwise else 10,
"decay_factor": 0.5,
"energy_key": "total_energy_ref",
"forces_key": None,
"forces_tradeoff": 0.0,
"lowest_k": 1, # always sample the lowest_k pairs with the lowest energies
"additional_k": 19, # in addition, sample additional_k from the remaining datapoints,
"precision": 64,
"trainset_path": [f"{args.data_dir}/confrank_train{i}.h5" for i in range(1, 9)],
"testset_path": [f"{args.data_dir}/confrank_test.h5"],
"seed": 42,
"pairwise": args.pairwise,
}
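
# in pairwise mode, each conformer ensemble should thus yield
# lowest_k + additional_k = 20 pairs; see the PairDataset construction below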

exclude_keys = [
"add._restraining",
"tlist",
"xb",
"imet",
"vtors",
"hbl_e",
"vbond",
"bpair",
"vangl",
"nb",
"hbl",
"blist",
"alist",
"total_charge",
"xb_e",
"hbb",
"uid",
"hbb_e",
"total_energy_gfn2",
"dispersion_energy",
"bonded_atm_energy",
"repulsion_energy",
"hb_energy",
"electrostat_energy",
"bond_energy",
"angle_energy",
"external_energy",
"torsion_energy",
"xb_energy",
]
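
# these HDF5 fields (mostly GFN2/GFN-FF energy terms and bookkeeping data) are
# not needed for training and are excluded when batching in the DataLoaders below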

energy_loss_fn = lambda x, y: torch.nn.functional.l1_loss(x, y)

if training_hyperparams["precision"] == 32:
dtype = torch.float32
elif training_hyperparams["precision"] == 64:
dtype = torch.float64
else:
raise ValueError("Precision must be either 32 or 64")

with mlflow.start_run() as run:
pl.seed_everything(seed=training_hyperparams["seed"], workers=True)

r2scan_atom_refs = {
1: -312.0427605689065,
6: -23687.220998505094,
7: -34221.8360905642,
8: -47026.572451837295,
9: -62579.24268115989,
14: -181528.62693507367,
15: -214078.44768832004,
16: -249752.85985328682,
17: -288725.9515678963,
35: -1615266.7419546635,
53: -186814.76788476118,
} # in kcal/mol
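
# per-element r2SCAN reference energies; they are attached to the model as
# constant atomic contributions via gnn.set_constant_energies(...) further below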

# select model and corresponding hyperparameters:
mlflow.log_param("model", args.model)
if args.model == "dimenet":
model_hyperparams = {
"hidden_channels": 48,
"num_blocks": 3,
"int_emb_size": 32,
"basis_emb_size": 5,
"out_emb_channels": 32,
"num_spherical": 5,
"num_radial": 6,
"cutoff": args.cutoff,
}
gnn = DimeNetPP(**model_hyperparams).to(dtype)
elif args.model == "mace":
model_hyperparams = dict(
r_max=args.cutoff,
num_bessel=8,
num_polynomial_cutoff=6,
max_ell=2,
num_interactions=3,
hidden_irreps="32x0e + 32x1o",
MLP_irreps="32x0e",
atomic_energies={key: 0.0 for key in r2scan_atom_refs},
correlation=3,
)
gnn = MACE(**model_hyperparams).to(dtype)
elif args.model == "schnet":
model_hyperparams = dict(
cutoff=args.cutoff,
hidden_channels=128,
num_filters=64,
num_interactions=3,
num_gaussians=50,
)
gnn = SchNet(**model_hyperparams).to(dtype)  # cast to the configured precision, as for the other models
elif args.model == "gemnet-T":
raise NotImplementedError(
"Currently not supported due to license incompatibility."
)
else:
raise ValueError(f"Unknown model: {args.model}")

transform = PipelineTransform(
[
Scale(scaling={"grad_ref": -1.0}),
Rename(key_mapping={"grad_ref": "forces"}),
RadiusGraph(cutoff=args.cutoff),
]
)
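
# the transform negates the reference gradient to obtain forces (F = -dE/dR),
# renames the field accordingly, and builds the radius graph consumed by the GNNs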
if training_hyperparams["pairwise"]:
trainset, valset, _ = PairDataset(
path_to_hdf5=training_hyperparams["trainset_path"],
sample_pairs_randomly=True,
transform=transform,
lowest_k=training_hyperparams["lowest_k"],
additional_k=training_hyperparams["additional_k"],
dtype=dtype,
).split_by_ensemble(0.92, 0.08, 0.0)

else:
dsets = [
ConfRankDataset(path_to_hdf5=path, transform=transform, dtype=dtype)
for path in training_hyperparams["trainset_path"]
]
_trainset = torch.utils.data.ConcatDataset(dsets)
trainset, valset = torch.utils.data.random_split(_trainset, [0.92, 0.08])
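
# note: pairwise data is split by ensemble, presumably so that conformers of the
# same ensemble do not leak across train/val, whereas pointwise data is split randomly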

gnn.set_constant_energies(
energy_dict=dict(r2scan_atom_refs), freeze=False
)

lightning_module = LightningWrapper(
model=gnn,
energy_key=training_hyperparams["energy_key"],
forces_key=training_hyperparams["forces_key"],
forces_tradeoff=training_hyperparams["forces_tradeoff"],
atomic_numbers_key="z",
decay_factor=training_hyperparams["decay_factor"],
decay_patience=training_hyperparams["decay_patience"],
energy_loss_fn=energy_loss_fn,
weight_decay=training_hyperparams["weight_decay"],
xy_lim=None,
pairwise=training_hyperparams["pairwise"],
)

testset = PairDataset(
path_to_hdf5=training_hyperparams["testset_path"],
sample_pairs_randomly=True,
transform=transform,
lowest_k=training_hyperparams["lowest_k"],
additional_k=training_hyperparams["additional_k"],
dtype=dtype,
)

train_loader = DataLoader(
trainset,
batch_size=training_hyperparams["batch_size"],
shuffle=True,
drop_last=True,
exclude_keys=exclude_keys,
)

val_loader = DataLoader(
valset,
batch_size=training_hyperparams["batch_size"],
exclude_keys=exclude_keys,
drop_last=False,
)

test_loader = DataLoader(
testset,
batch_size=training_hyperparams["batch_size"],
exclude_keys=exclude_keys,
drop_last=False,
)

monitor_metric = f"ptl/val_loss_{'pairwise' if args.pairwise else 'single'}"
checkpoint_callback = pl.callbacks.ModelCheckpoint(
monitor=monitor_metric, save_top_k=3
)

early_stop_callback = EarlyStopping(
monitor=monitor_metric,
min_delta=0.0,
patience=training_hyperparams["stopping_patience"],
verbose=True,
mode="min",
)

lr_monitor = LearningRateMonitor(logging_interval="epoch")
callbacks = [checkpoint_callback, early_stop_callback, lr_monitor]

mlf_logger = MLFlowLogger(run_id=run.info.run_id)

for key, val in training_hyperparams.items():
mlflow.log_param(key, val)

for key, val in model_hyperparams.items():
mlflow.log_param(key, val)

mlflow.log_param("len_trainset", len(trainset))
mlflow.log_param("len_valset", len(valset))
mlflow.log_param("len_testset", len(testset))
mlflow.log_param("num_params", sum(p.numel() for p in gnn.parameters()))

trainer = pl.Trainer(
max_epochs=training_hyperparams["max_epochs"],
enable_progress_bar=True,
callbacks=callbacks,
logger=mlf_logger,
log_every_n_steps=200,
accelerator="gpu" if torch.cuda.is_available() else "cpu",
devices=1,
precision=training_hyperparams["precision"],
inference_mode=training_hyperparams["forces_key"] is None,
# inference mode may only be enabled when no forces are computed;
# force computation requires inference_mode=False
)

trainer.fit(
lightning_module, train_dataloaders=train_loader, val_dataloaders=val_loader
)

# save best model checkpoint
best_model_path = checkpoint_callback.best_model_path
ckpt = torch.load(best_model_path)
lightning_module.load_state_dict(ckpt["state_dict"])
best_model = lightning_module.model
run_id = mlflow.active_run().info.run_id
experiment_id = mlflow.active_run().info.experiment_id
default_root_dir = f"mlruns/{experiment_id}/{run_id}"
model_path = os.path.join(default_root_dir, f"best_model.{args.model}")
save_model(best_model, model_path)
mlflow.log_artifact(model_path)

# always run tests in pairwise mode
lightning_module.pairwise = True
trainer.test(lightning_module, ckpt_path="best", dataloaders=test_loader)
18 changes: 18 additions & 0 deletions setup_environment.sh
@@ -0,0 +1,18 @@
#!/bin/bash

# use micromamba if available, otherwise fall back to conda
mcba(){
if command -v micromamba &>/dev/null; then
micromamba "$@"
else
conda "$@"
fi
}
mcba --version

mcba clean -a -y
mcba create -n confrank python=3.11.5
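# note: `activate` inside a non-interactive script may require initializing the
# shell hook first, e.g. eval "$(conda shell.bash hook)" for conda or
# eval "$(micromamba shell hook -s bash)" for micromamba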
mcba activate confrank
mcba install pytorch=2.1.0 torchvision torchaudio pytorch-cuda=11.8 lightning=2.1.1 torchmetrics=1.2.0 -c pytorch -c nvidia -c conda-forge
pip install torch-cluster==1.6.3 torch_scatter==2.1.2 torch_sparse==0.6.18 -f https://data.pyg.org/whl/torch-2.1.0+cu118.html --no-cache
pip install torch_geometric==2.5.0 h5py==3.10.0 seaborn==0.13.0 rdkit==2023.09.5 mace-torch==0.3.4 mlflow==2.9.1 black[d] numba pytest --no-cache
1 change: 1 addition & 0 deletions src/data/__init__.py
@@ -0,0 +1 @@
from .datasets import ConfRankDataset, PairDataset