This repository is the official implementation of the OM-Diff paper.
This repository builds on the lightning-hydra-template.
Explanations of, e.g., how the config files are resolved can be found in the original template.
After creating an environment and installing the requirements, the repo can either be installed as a package or be used
without being installed by adding the path to the root of the project to the PYTHONPATH environment variable.
A .env file has to be placed at the root of the project. It is used to contain environment variables that can be
accessed in the config files.
An example is provided in .env.example.
At this point, it is only used for defining:
- PROJECT_ROOT: pointing to the root of the project;
- SCRATCH_PATH: pointing to where the logs/checkpoints etc. should be dumped;
- SCRATCH_COMPUTE_PATH: pointing to where the dataset should be copied during training (e.g. a fast-access disk on a compute node).
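To make the expected file content concrete, the sketch below parses a .env-style text defining the three variables listed above. It is a minimal, stdlib-only illustration: the paths are placeholders, and parse_env_text is a hypothetical helper, not part of this repository (the project may load the file through a dedicated library instead).

```python
def parse_env_text(text):
    """Parse KEY=VALUE lines from .env-style text into a dict (hypothetical helper)."""
    variables = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        # Skip blank lines and comments.
        if not line or line.startswith("#"):
            continue
        key, sep, value = line.partition("=")
        if not sep:
            continue  # not a KEY=VALUE line
        # Drop surrounding whitespace and optional quotes around the value.
        variables[key.strip()] = value.strip().strip('"').strip("'")
    return variables


# Placeholder content; replace the paths with ones valid on your machine.
EXAMPLE_ENV = """\
# Example .env (placeholder paths)
PROJECT_ROOT=/path/to/om-diff
SCRATCH_PATH=/path/to/scratch
SCRATCH_COMPUTE_PATH=/path/to/fast/local/disk
"""

if __name__ == "__main__":
    for name, value in parse_env_text(EXAMPLE_ENV).items():
        print(f"{name} -> {value}")
```

The provided .env.example remains the reference for the actual file; the sketch only shows the KEY=VALUE shape of the three entries.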
There is a single entry point for training the different models: src/train.py.
A training run can be launched using:
python src/train.py experiment=<experiment-name>
where <experiment-name> is the name of the experiment to be run.
The description of the experiment to be run should be placed in config/experiment/<experiment-name>.yaml.
A few experiment config files are provided. They allow you to:
- train a diffusion model;
- train a time-conditioned regressor;
- train a property predictor.
By default, all experiments are run on the cross-coupling dataset investigated in the paper:
python src/train.py experiment=train_diffusion_suzuki
python src/train.py experiment=train_time_regressor_suzuki
python src/train.py experiment=train_regressor_suzuki
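If you want to launch the three runs one after another, the commands above can be wrapped in a small script. The following is a minimal sketch, assuming it is executed from the root of the project; build_command, experiment_config and run_all are hypothetical helpers, not part of this repository.

```python
import subprocess
import sys
from pathlib import Path

EXPERIMENTS = [
    "train_diffusion_suzuki",
    "train_time_regressor_suzuki",
    "train_regressor_suzuki",
]


def experiment_config(name):
    """Return the path where the config of experiment `name` is expected."""
    return Path("config/experiment") / f"{name}.yaml"


def build_command(name, python=sys.executable):
    """Return the argument list for one training run."""
    return [python, "src/train.py", f"experiment={name}"]


def run_all(experiments=EXPERIMENTS, dry_run=True):
    """Run the experiments sequentially; with dry_run, only print the commands."""
    for name in experiments:
        cmd = build_command(name)
        if dry_run:
            print(" ".join(cmd))
        else:
            # check=True raises CalledProcessError at the first failing run.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```

Calling run_all(dry_run=False) actually starts the runs; the default only prints the commands so they can be checked first.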
While much of the provided code and configs is specific to the experiments reported in the paper, training on your own data should be straightforward.
Provided that you have saved the data in an ase database, you can easily interface to it using the
generic ASEDataset and ASEDataModule.
Sampling from a trained checkpoint can be done using src/sample_from_ckpt.py.
Conditional sampling from a trained checkpoint can be done using src/cond_sample_from_ckpt.py.
Running predictions on generated samples can be done using src/predict_on_samples.py.