Merge pull request #26 from choderalab/split-comb-calcs
Rework Combination class
kaminow authored Nov 1, 2023
2 parents 54c94b0 + 58c43f2 commit 6f6d8e8
Showing 13 changed files with 873 additions and 270 deletions.
74 changes: 74 additions & 0 deletions README_COMBINATION.md
@@ -0,0 +1,74 @@
As of v0.4.0 the `Combination` class has been reworked so that it can run on normal-sized
GPUs. Because the all-atom protein-ligand complex representation is large, storing the
autograd computation graphs for every pose at once used all of the GPU memory. By writing
the gradient of the combined prediction as a function of the gradient from each pose, we
only need to keep one computation graph in memory at a time. This document contains the
derivation of the split-up math.

# `MSE Loss`
```math
L = (\Delta G_{\mathrm{pred}} \left ( \theta \right ) - \Delta G_{\mathrm{target}})^2
```
```math
\frac{\partial L}{\partial \theta} = 2(\Delta G_{\mathrm{pred}} \left ( \theta \right ) - \Delta G_{\mathrm{target}}) \frac{\partial \Delta G_{\mathrm{pred}} \left ( \theta \right )}{\partial \theta}
```
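
The chain rule above can be checked numerically. The following is a minimal sketch in plain Python rather than the package's actual code; `toy_pred` and `toy_pred_grad` are hypothetical stand-ins for a model with a single scalar parameter `theta`.

```python
def mse_loss(pred, target):
    """L = (pred - target)^2."""
    return (pred - target) ** 2


def mse_loss_grad(pred, target, dpred_dtheta):
    """dL/dtheta = 2 * (pred - target) * dpred/dtheta."""
    return 2.0 * (pred - target) * dpred_dtheta


# Hypothetical toy model (an assumption for illustration only)
def toy_pred(theta):
    return 3.0 * theta**2 - theta


def toy_pred_grad(theta):
    return 6.0 * theta - 1.0


theta, target, eps = 0.5, -1.0, 1e-6

# Analytic gradient from the chain rule
analytic = mse_loss_grad(toy_pred(theta), target, toy_pred_grad(theta))

# Central finite difference of the loss, for comparison
numeric = (
    mse_loss(toy_pred(theta + eps), target)
    - mse_loss(toy_pred(theta - eps), target)
) / (2 * eps)

print(analytic, numeric)
```

The point to notice is that the loss gradient only needs the scalar prediction and the gradient of the prediction, so the remaining sections only have to supply that gradient for each combination scheme.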

# `MeanCombination`
Take the mean of the per-pose predictions, so the gradient is just the mean of the per-pose gradients:
```math
\Delta G(\theta) = \frac{1}{N} \sum_{n=1}^{N} \Delta G_n (\theta)
```
```math
\frac{\partial \Delta G(\theta)}{\partial \theta} = \frac{1}{N} \sum_{n=1}^{N} \frac{\partial \Delta G_n (\theta)}{\partial \theta}
```
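
As a sketch of how this lets us visit one pose at a time, the loop below accumulates the per-pose predictions and gradients as plain numbers. The per-pose models (`POSES`, `pose_pred`, `pose_grad`) are hypothetical quadratics in a scalar `theta`, chosen only so the result can be checked against a finite difference.

```python
# Hypothetical per-pose models: dG_n(theta) = a_n * theta^2 + b_n * theta
POSES = [(1.0, -2.0), (0.5, 0.3), (-0.2, 1.5)]


def pose_pred(a, b, theta):
    return a * theta**2 + b * theta


def pose_grad(a, b, theta):
    return 2.0 * a * theta + b


def mean_combination(poses, theta):
    """Return (dG, d dG / d theta), processing one pose per iteration."""
    pred_total = 0.0
    grad_total = 0.0
    for a, b in poses:
        # Only plain numbers are kept between iterations
        pred_total += pose_pred(a, b, theta)
        grad_total += pose_grad(a, b, theta)
    n = len(poses)
    return pred_total / n, grad_total / n


theta, eps = 0.7, 1e-6
mean_pred, mean_grad = mean_combination(POSES, theta)

# Central finite difference of the combined prediction
numeric_grad = (
    mean_combination(POSES, theta + eps)[0]
    - mean_combination(POSES, theta - eps)[0]
) / (2 * eps)

print(mean_pred, mean_grad, numeric_grad)
```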

# `MaxCombination`
Combine according to a smooth max approximation using LogSumExp (LSE), where `t` is a scaling parameter and `Q` is shorthand for the LSE term:
```math
\Delta G(\theta) = \frac{-1}{t} \mathrm{ln} \sum_{n=1}^N \mathrm{exp} (-t \Delta G_n (\theta))
```
```math
Q = \mathrm{ln} \sum_{n=1}^N \mathrm{exp} (-t \Delta G_n (\theta))
```
```math
\frac{\partial \Delta G(\theta)}{\partial \theta} = \frac{1}{\sum_{n=1}^N \mathrm{exp} (-t \Delta G_n (\theta))} \sum_{n=1}^N \left[ \frac{\partial \Delta G_n (\theta)}{\partial \theta} \mathrm{exp} (-t \Delta G_n (\theta)) \right]
```
```math
\frac{\partial \Delta G(\theta)}{\partial \theta} = \frac{1}{\mathrm{exp}(Q)} \sum_{n=1}^N \left[ \mathrm{exp} \left( -t \Delta G_n (\theta) \right) \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right]
```
```math
\frac{\partial \Delta G(\theta)}{\partial \theta} = \sum_{n=1}^N \left[ \mathrm{exp} \left( -t \Delta G_n (\theta) - Q \right) \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right]
```
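
The final expression says that each pose's gradient is weighted by `exp(-t * dG_n - Q)`, and these weights sum to one. Below is a minimal plain-Python sketch under the same assumptions as before: the per-pose quadratics are hypothetical, and the max-shifted `logsumexp` helper is one common way to evaluate `Q` stably, not necessarily what the package does. Note that for `t > 0` the expression as written concentrates on the pose with the lowest predicted value as `t` grows.

```python
import math

# Hypothetical per-pose models: dG_n(theta) = a_n * theta^2 + b_n * theta
POSES = [(1.0, -2.0), (0.5, 0.3), (-0.2, 1.5)]


def pose_pred(a, b, theta):
    return a * theta**2 + b * theta


def pose_grad(a, b, theta):
    return 2.0 * a * theta + b


def logsumexp(values):
    """ln(sum(exp(v))) computed with the usual max shift for stability."""
    shift = max(values)
    return shift + math.log(sum(math.exp(v - shift) for v in values))


def max_combination(preds, grads, t):
    """Return (dG, d dG / d theta, weights) for the LSE combination."""
    scaled = [-t * p for p in preds]
    q = logsumexp(scaled)
    # Per-pose weights exp(-t * dG_n - Q)
    weights = [math.exp(s - q) for s in scaled]
    combined = -q / t
    combined_grad = sum(w * g for w, g in zip(weights, grads))
    return combined, combined_grad, weights


def combined_at(theta, t):
    preds = [pose_pred(a, b, theta) for a, b in POSES]
    grads = [pose_grad(a, b, theta) for a, b in POSES]
    return max_combination(preds, grads, t)


theta, t, eps = 0.7, 2.0, 1e-6
combined, combined_grad, weights = combined_at(theta, t)

# Central finite difference of the combined prediction
numeric_grad = (
    combined_at(theta + eps, t)[0] - combined_at(theta - eps, t)[0]
) / (2 * eps)

print(combined, combined_grad, numeric_grad)
```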

# `BoltzmannCombination`
Combine according to Boltzmann weighting:
```math
\Delta G(\theta) = \sum_{n=1}^{N} w_n \Delta G_n (\theta)
```

```math
w_n = \mathrm{exp} \left[ -\Delta G_n (\theta) - \mathrm{ln} \sum_{i=1}^N \mathrm{exp} (-\Delta G_i (\theta)) \right]
```

```math
Q = \mathrm{ln} \sum_{n=1}^N \mathrm{exp} (-\Delta G_n (\theta))
```

```math
\frac{\partial \Delta G(\theta)}{\partial \theta} = \sum_{n=1}^N \left[ \frac{\partial w_n}{\partial \theta} \Delta G_n (\theta) + w_n \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right]
```

```math
\frac{\partial w_n}{\partial \theta} = \mathrm{exp} \left[ -\Delta G_n (\theta) - Q \right] \left[ \frac{-\partial \Delta G_n (\theta)}{\partial \theta} - \frac{\partial Q}{\partial \theta} \right]
```

```math
\frac{\partial Q}{\partial \theta} = \frac{1}{\sum_{n=1}^N \mathrm{exp} (-\Delta G_n (\theta))} \sum_{n=1}^{N} \left[ \mathrm{exp} (-\Delta G_n (\theta)) \frac{-\partial \Delta G_n (\theta)}{\partial \theta} \right]
```

```math
\frac{\partial Q}{\partial \theta} = \frac{-1}{\mathrm{exp} (Q)} \sum_{n=1}^{N} \left[ \mathrm{exp} (-\Delta G_n (\theta)) \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right]
```

```math
\frac{\partial Q}{\partial \theta} = -\sum_{n=1}^{N} \left[ \mathrm{exp} (-\Delta G_n (\theta) - Q) \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right]
```
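
Putting the pieces above together, the sketch below evaluates the Boltzmann-weighted gradient from per-pose values and gradients and compares it with a finite difference. It is plain Python with hypothetical per-pose quadratics, not the package's implementation. Substituting the expression for `dQ/dtheta` into `dw_n/dtheta` also gives a per-pose coefficient `w_n * (1 - dG_n + dG)`, which the code uses as a cross-check. Since that coefficient depends on the combined value, this sketch assumes all per-pose predictions are available as plain numbers before the gradients are combined.

```python
import math

# Hypothetical per-pose models: dG_n(theta) = a_n * theta^2 + b_n * theta
POSES = [(1.0, -2.0), (0.5, 0.3), (-0.2, 1.5)]


def pose_pred(a, b, theta):
    return a * theta**2 + b * theta


def pose_grad(a, b, theta):
    return 2.0 * a * theta + b


def boltzmann_combination(preds, grads):
    """Return (dG, d dG / d theta, weights) for Boltzmann weighting."""
    neg = [-p for p in preds]
    shift = max(neg)
    # Q = ln sum_n exp(-dG_n), with a max shift for stability
    q = shift + math.log(sum(math.exp(v - shift) for v in neg))
    weights = [math.exp(v - q) for v in neg]
    combined = sum(w * p for w, p in zip(weights, preds))

    # dQ/dtheta = -sum_n exp(-dG_n - Q) * d dG_n / d theta
    dq = -sum(w * g for w, g in zip(weights, grads))

    combined_grad = 0.0
    for w, p, g in zip(weights, preds, grads):
        dw = w * (-g - dq)  # dw_n/dtheta
        combined_grad += dw * p + w * g
    return combined, combined_grad, weights


def combined_at(theta):
    preds = [pose_pred(a, b, theta) for a, b in POSES]
    grads = [pose_grad(a, b, theta) for a, b in POSES]
    return boltzmann_combination(preds, grads)


theta, eps = 0.7, 1e-6
combined, combined_grad, weights = combined_at(theta)

# Central finite difference of the combined prediction
numeric_grad = (combined_at(theta + eps)[0] - combined_at(theta - eps)[0]) / (2 * eps)

# Cross-check with the per-pose coefficient w_n * (1 - dG_n + dG)
preds = [pose_pred(a, b, theta) for a, b in POSES]
grads = [pose_grad(a, b, theta) for a, b in POSES]
coeff_grad = sum(
    w * (1.0 - p + combined) * g for w, p, g in zip(weights, preds, grads)
)

print(combined_grad, numeric_grad, coeff_grad)
```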
1 change: 1 addition & 0 deletions devtools/conda-envs/test_env.yaml
@@ -13,6 +13,7 @@ dependencies:
- dgllife
- dgl
- rdkit
- ase
# testing dependencies
- pytest
- pytest-cov
3 changes: 2 additions & 1 deletion environment-gpu.yml
@@ -1,7 +1,6 @@
name: mtenn-gpu
channels:
- conda-forge
- dglteam
dependencies:
- pytorch
- pytorch-gpu
@@ -14,3 +13,5 @@ dependencies:
- e3nn
- dgllife
- dgl
- rdkit
- ase
3 changes: 2 additions & 1 deletion environment.yml
@@ -1,7 +1,6 @@
name: mtenn
channels:
- conda-forge
- dglteam
dependencies:
- pytorch
- pytorch_geometric
@@ -13,3 +12,5 @@ dependencies:
- e3nn
- dgllife
- dgl
- rdkit
- ase