diff --git a/README_COMBINATION.md b/README_COMBINATION.md
new file mode 100644
index 0000000..3c9a5c8
--- /dev/null
+++ b/README_COMBINATION.md
@@ -0,0 +1,74 @@
+As of v0.4.0, the `Combination` class has been reworked so that it can run on
+normal-sized GPUs. Due to the size of the all-atom protein-ligand complex representation,
+storing the autograd computation graph for every pose used all the GPU memory. By
+rewriting the overall gradient as a function of the per-pose gradients, we never need to
+store more than one computation graph at a time. This document contains the derivations
+of the split-up math.
+
+# MSE Loss
+```math
+L = (\Delta G_{\mathrm{pred}} \left ( \theta \right ) - \Delta G_{\mathrm{target}})^2
+```
+```math
+\frac{\partial L}{\partial \theta} = 2(\Delta G_{\mathrm{pred}} \left ( \theta \right ) - \Delta G_{\mathrm{target}}) \frac{\partial \Delta G_{\mathrm{pred}} \left ( \theta \right )}{\partial \theta}
+```
+
+# `MeanCombination`
+Take the mean of all predictions, so the gradient is straightforward:
+```math
+\Delta G(\theta) = \frac{1}{N} \sum_{n=1}^{N} \Delta G_n (\theta)
+```
+```math
+\frac{\partial \Delta G(\theta)}{\partial \theta} = \frac{1}{N} \sum_{n=1}^{N} \frac{\partial \Delta G_n (\theta)}{\partial \theta}
+```
+
+# `MaxCombination`
+Combine according to a smooth max approximation using LogSumExp (LSE), where $t > 0$ is a fixed scale factor:
+```math
+\Delta G(\theta) = \frac{-1}{t} \mathrm{ln} \sum_{n=1}^N \mathrm{exp} (-t \Delta G_n (\theta))
+```
+```math
+Q = \mathrm{ln} \sum_{n=1}^N \mathrm{exp} (-t \Delta G_n (\theta))
+```
+```math
+\frac{\partial \Delta G(\theta)}{\partial \theta} = \frac{1}{\sum_{n=1}^N \mathrm{exp} (-t \Delta G_n (\theta))} \sum_{n=1}^N \left[ \frac{\partial \Delta G_n (\theta)}{\partial \theta} \mathrm{exp} (-t \Delta G_n (\theta)) \right]
+```
+```math
+\frac{\partial \Delta G(\theta)}{\partial \theta} = \frac{1}{\mathrm{exp}(Q)} \sum_{n=1}^N \left[ \mathrm{exp} \left( -t \Delta G_n (\theta) \right) \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right]
+```
+```math
+\frac{\partial \Delta G(\theta)}{\partial \theta} = \sum_{n=1}^N \left[ \mathrm{exp} \left( -t \Delta G_n (\theta) - Q \right) \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right]
+```
+
+# `BoltzmannCombination`
+Combine according to Boltzmann weighting:
+```math
+\Delta G(\theta) = \sum_{n=1}^{N} w_n \Delta G_n (\theta)
+```
+
+```math
+w_n = \mathrm{exp} \left[ -\Delta G_n (\theta) - \mathrm{ln} \sum_{i=1}^N \mathrm{exp} (-\Delta G_i (\theta)) \right]
+```
+
+```math
+Q = \mathrm{ln} \sum_{n=1}^N \mathrm{exp} (-\Delta G_n (\theta))
+```
+
+```math
+\frac{\partial \Delta G(\theta)}{\partial \theta} = \sum_{n=1}^N \left[ \frac{\partial w_n}{\partial \theta} \Delta G_n (\theta) + w_n \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right]
+```
+
+```math
+\frac{\partial w_n}{\partial \theta} = \mathrm{exp} \left[ -\Delta G_n (\theta) - Q \right] \left[ \frac{-\partial \Delta G_n (\theta)}{\partial \theta} - \frac{\partial Q}{\partial \theta} \right]
+```
+
+```math
+\frac{\partial Q}{\partial \theta} = \frac{1}{\sum_{n=1}^N \mathrm{exp} (-\Delta G_n (\theta))} \sum_{i=1}^{N} \left[ \mathrm{exp} (-\Delta G_i (\theta)) \frac{-\partial \Delta G_i (\theta)}{\partial \theta} \right]
+```
+
+```math
+\frac{\partial Q}{\partial \theta} = \frac{-1}{\mathrm{exp} (Q)} \sum_{n=1}^{N} \left[ \mathrm{exp} (-\Delta G_n (\theta)) \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right]
+```
+
+```math
+\frac{\partial Q}{\partial \theta} = -\sum_{n=1}^{N} \left[ \mathrm{exp} (-\Delta G_n (\theta) - Q) \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right]
+```
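To make the split concrete, here is a minimal numerical check (illustrative only; a toy linear function stands in for the per-pose network) that the per-pose `MaxCombination` gradient formula above matches autograd through the full LSE graph:

```python
# Sketch: verify sum_n exp(-t*dG_n - Q) * d(dG_n)/d(theta) == autograd of -LSE(-t*dG)/t.
# The "model" here is a hypothetical stand-in, not part of the package.
import torch

torch.manual_seed(0)
t = 1000.0
theta = torch.randn(3, requires_grad=True)
xs = [torch.randn(3) for _ in range(5)]  # stand-ins for 5 poses


def pred(x, theta):  # toy per-pose "delta G" model
    return (x * theta).sum()


# Reference: keep the full graph and backprop through the LSE directly
preds = torch.stack([pred(x, theta) for x in xs])
ref = -torch.logsumexp(-t * preds, dim=0) / t
ref.backward()
ref_grad = theta.grad.clone()

# Split-up version: one graph per pose, combined with the derived formula
theta.grad = None
pose_grads, pose_preds = [], []
for x in xs:
    p = pred(x, theta)
    (g,) = torch.autograd.grad(p, theta)
    pose_grads.append(g)
    pose_preds.append(p.detach())
adj = -t * torch.stack(pose_preds)
Q = torch.logsumexp(adj, dim=0)
split_grad = sum((a - Q).exp() * g for a, g in zip(adj, pose_grads))

assert torch.allclose(ref_grad, split_grad, atol=1e-6)
```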
diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml
index a953835..bde3c8a 100644
--- a/devtools/conda-envs/test_env.yaml
+++ b/devtools/conda-envs/test_env.yaml
@@ -13,6 +13,7 @@ dependencies:
 - dgllife
 - dgl
 - rdkit
+ - ase
 # testing dependencies
 - pytest
 - pytest-cov
diff --git a/environment-gpu.yml b/environment-gpu.yml
index 67b4f51..de93784 100644
--- a/environment-gpu.yml
+++ b/environment-gpu.yml
@@ -1,7 +1,6 @@
 name: mtenn-gpu
 channels:
 - conda-forge
-- dglteam
 dependencies:
 - pytorch
 - pytorch-gpu
@@ -14,3 +13,5 @@ dependencies:
 - e3nn
 - dgllife
 - dgl
+- rdkit
+- ase
diff --git a/environment.yml b/environment.yml
index 2a415ed..7a679c7 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,7 +1,6 @@
 name: mtenn
 channels:
 - conda-forge
-- dglteam
 dependencies:
 - pytorch
 - pytorch_geometric
@@ -13,3 +12,5 @@ dependencies:
 - e3nn
 - dgllife
 - dgl
+- rdkit
+- ase
\ No newline at end of file
diff --git a/mtenn/combination.py b/mtenn/combination.py
new file mode 100644
index 0000000..2cb4552
--- /dev/null
+++ b/mtenn/combination.py
@@ -0,0 +1,415 @@
+import abc
+import torch
+
+
+class Combination(torch.nn.Module, abc.ABC):
+    @abc.abstractmethod
+    def forward(self, pred_list, grad_dict, param_names, *model_params):
+        """
+        This function signature should be the same for any Combination subclass
+        implementation.
+
+        Parameters
+        ----------
+        pred_list: List[torch.Tensor]
+            List of delta G predictions to be combined
+        grad_dict: dict[str, List[torch.Tensor]]
+            Dict mapping from parameter name to list of gradients
+        param_names: List[str]
+            List of parameter names
+        model_params: torch.Tensor
+            Actual parameters that we'll return the gradients for. Each param
+            should be passed individually for the backward pass to work right.
+        """
+        raise NotImplementedError("Must implement the `forward` method.")
+
+    @staticmethod
+    def split_grad_dict(grad_dict):
+        """
+        Helper method used by all Combination classes to split up the passed grad_dict
+        for saving by the autograd context manager.
+
+        Parameters
+        ----------
+        grad_dict : Dict[str, List[torch.Tensor]]
+            Dict mapping from parameter name to list of gradients
+
+        Returns
+        -------
+        List[str]
+            Keys in grad_dict, corresponding 1:1 with the gradients
+        List[torch.Tensor]
+            Gradients from grad_dict, corresponding 1:1 with the keys
+        """
+        # Deconstruct grad_dict to be saved for backwards
+        grad_dict_keys = [
+            k for k, grad_list in grad_dict.items() for _ in range(len(grad_list))
+        ]
+        grad_dict_tensors = [
+            grad for grad_list in grad_dict.values() for grad in grad_list
+        ]
+
+        return grad_dict_keys, grad_dict_tensors
+
+    @staticmethod
+    def join_grad_dict(grad_dict_keys, grad_dict_tensors):
+        """
+        Helper method used by all Combination classes to reconstruct the grad_dict
+        from keys and grad tensors.
+
+        Parameters
+        ----------
+        grad_dict_keys : List[str]
+            Keys in grad_dict, corresponding 1:1 with the gradients
+        grad_dict_tensors : List[torch.Tensor]
+            Gradients from grad_dict, corresponding 1:1 with the keys
+
+        Returns
+        -------
+        Dict[str, List[torch.Tensor]]
+            Dict mapping from parameter name to list of gradients
+        """
+        # Reconstruct grad_dict
+        grad_dict = {}
+        for k, grad in zip(grad_dict_keys, grad_dict_tensors):
+            try:
+                grad_dict[k].append(grad)
+            except KeyError:
+                grad_dict[k] = [grad]
+
+        return grad_dict
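A quick round-trip sketch of the two helpers above (toy tensors, illustrative only): `split_grad_dict` flattens the per-parameter gradient lists into parallel key/tensor lists suitable for `ctx.save_for_backward`, and `join_grad_dict` rebuilds the dict in `backward`:

```python
import torch
from mtenn.combination import Combination

grad_dict = {
    "w": [torch.rand(2), torch.rand(2)],  # two poses' gradients for param "w"
    "b": [torch.rand(1), torch.rand(1)],
}
keys, tensors = Combination.split_grad_dict(grad_dict)
# keys == ["w", "w", "b", "b"]; tensors is a flat list of the 4 gradient tensors
rebuilt = Combination.join_grad_dict(keys, tensors)
assert all(
    torch.equal(g1, g2)
    for k in grad_dict
    for g1, g2 in zip(grad_dict[k], rebuilt[k])
)
```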
+ """ + + def __init__(self): + super(MeanCombination, self).__init__() + + def forward(self, pred_list, grad_dict, param_names, *model_params): + return _MeanCombinationFunc.apply( + pred_list, grad_dict, param_names, *model_params + ) + + +class _MeanCombinationFunc(torch.autograd.Function): + """ + Custom autograd function that will handle the gradient math for us. + """ + + @staticmethod + def forward(pred_list, grad_dict, param_names, *model_params): + # Return mean of all preds + all_preds = torch.stack(pred_list).flatten() + final_pred = all_preds.mean(axis=None).detach() + + return final_pred + + @staticmethod + def setup_context(ctx, inputs, output): + pred_list, grad_dict, param_names, *model_params = inputs + + grad_dict_keys, grad_dict_tensors = Combination.split_grad_dict(grad_dict) + + # Save non-Tensors for backward + ctx.grad_dict_keys = grad_dict_keys + ctx.param_names = param_names + + # Save Tensors for backward + # Saving: + # * Predictions (1 tensor) + # * Grad tensors (N params * M poses tensors) + # * Model param tensors (N params tensors) + ctx.save_for_backward( + torch.stack(pred_list).flatten(), + *grad_dict_tensors, + *model_params, + ) + + @staticmethod + def backward(ctx, grad_output): + # Unpack saved tensors + preds, *other_tensors = ctx.saved_tensors + + # Split up other_tensors + grad_dict_tensors = other_tensors[: len(ctx.grad_dict_keys)] + + grad_dict = Combination.join_grad_dict(ctx.grad_dict_keys, grad_dict_tensors) + + # Calculate final gradients for each parameter + final_grads = {} + for n, grad_list in grad_dict.items(): + final_grads[n] = torch.stack(grad_list, axis=-1).mean(axis=-1) + + # Adjust gradients by grad_output + for grad in final_grads.values(): + grad *= grad_output + + # Pull out return vals + return_vals = [None] * 3 + [final_grads[n] for n in ctx.param_names] + return tuple(return_vals) + + +class MaxCombination(Combination): + """ + Approximate max/min of the predictions using the LogSumExp function for smoothness. + """ + + def __init__(self, neg=True, scale=1000.0): + """ + Parameters + ---------- + neg : bool, default=True + Negate the predictions before calculating the LSE, effectively finding + the min. Preds are negated again before being returned + scale : float, default=1000.0 + Fixed positive value to scale predictions by before taking the LSE. This + tightens the bounds of the LSE approximation + """ + super(MaxCombination, self).__init__() + + self.neg = neg + self.scale = scale + + def __repr__(self): + return f"MaxCombination(neg={self.neg}, scale={self.scale})" + + def __str__(self): + return repr(self) + + def forward(self, pred_list, grad_dict, param_names, *model_params): + return _MaxCombinationFunc.apply( + self.neg, self.scale, pred_list, grad_dict, param_names, *model_params + ) + + +class _MaxCombinationFunc(torch.autograd.Function): + """ + Custom autograd function that will handle the gradient math for us. + """ + + @staticmethod + def forward(neg, scale, pred_list, grad_dict, param_names, *model_params): + """ + neg: bool + Negate the predictions before calculating the LSE, effectively finding + the min. Preds are negated again before being returned + scale: float + Fixed positive value to scale predictions by before taking the LSE. 
+
+
+class _MaxCombinationFunc(torch.autograd.Function):
+    """
+    Custom autograd function that will handle the gradient math for us.
+    """
+
+    @staticmethod
+    def forward(neg, scale, pred_list, grad_dict, param_names, *model_params):
+        """
+        neg: bool
+            Negate the predictions before calculating the LSE, effectively finding
+            the min. Preds are negated again before being returned
+        scale: float
+            Fixed positive value to scale predictions by before taking the LSE. This
+            tightens the bounds of the LSE approximation
+        pred_list: List[torch.Tensor]
+            List of delta G predictions to be combined using LSE
+        grad_dict: dict[str, List[torch.Tensor]]
+            Dict mapping from parameter name to list of gradients
+        param_names: List[str]
+            List of parameter names
+        model_params: torch.Tensor
+            Actual parameters that we'll return the gradients for. Each param
+            should be passed individually for the backward pass to work right.
+        """
+        # Map neg=True -> -1 (smooth min) and neg=False -> +1 (smooth max)
+        neg = (-1) ** neg
+        # Calculate once for reuse later
+        all_preds = torch.stack(pred_list).flatten()
+        adj_preds = neg * scale * all_preds.detach()
+        Q = torch.logsumexp(adj_preds, dim=0)
+        # Calculate the actual prediction
+        final_pred = (neg * Q / scale).detach()
+
+        return final_pred
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        neg, scale, pred_list, grad_dict, param_names, *model_params = inputs
+
+        grad_dict_keys, grad_dict_tensors = Combination.split_grad_dict(grad_dict)
+
+        # Save non-Tensors for backward
+        ctx.neg = neg
+        ctx.scale = scale
+        ctx.grad_dict_keys = grad_dict_keys
+        ctx.param_names = param_names
+
+        # Save Tensors for backward
+        # Saving:
+        #  * Predictions (1 tensor)
+        #  * Grad tensors (N params * M poses tensors)
+        #  * Model param tensors (N params tensors)
+        ctx.save_for_backward(
+            torch.stack(pred_list).flatten(),
+            *grad_dict_tensors,
+            *model_params,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # Unpack saved tensors
+        preds, *other_tensors = ctx.saved_tensors
+
+        # Split up other_tensors
+        grad_dict_tensors = other_tensors[: len(ctx.grad_dict_keys)]
+
+        grad_dict = Combination.join_grad_dict(ctx.grad_dict_keys, grad_dict_tensors)
+
+        # Begin calculations
+        # Map neg=True -> -1 (smooth min) and neg=False -> +1 (smooth max)
+        neg = (-1) ** ctx.neg
+
+        # Calculate once for reuse later
+        adj_preds = neg * ctx.scale * preds.detach()
+        Q = torch.logsumexp(adj_preds, dim=0)
+
+        # Calculate final gradients for each parameter
+        final_grads = {}
+        for n, grad_list in grad_dict.items():
+            final_grads[n] = (
+                torch.stack(
+                    [
+                        grad * (pred - Q).exp()
+                        for grad, pred in zip(grad_list, adj_preds)
+                    ],
+                    axis=-1,
+                )
+                .detach()
+                .sum(axis=-1)
+            )
+
+        # Adjust gradients by grad_output
+        for grad in final_grads.values():
+            grad *= grad_output
+
+        # Pull out return vals
+        return_vals = [None] * 5 + [final_grads[n] for n in ctx.param_names]
+        return tuple(return_vals)
+ """ + # Save for later so we don't have to keep redoing this + adj_preds = -torch.stack(pred_list).flatten().detach() + + # First calculate the normalization factor + Q = torch.logsumexp(adj_preds, dim=0) + + # Calculate w + w = (adj_preds - Q).exp() + + # Calculate final pred + final_pred = torch.dot(w, -adj_preds) + + return final_pred + + @staticmethod + def setup_context(ctx, inputs, output): + pred_list, grad_dict, param_names, *model_params = inputs + + grad_dict_keys, grad_dict_tensors = Combination.split_grad_dict(grad_dict) + + # Save non-Tensors for backward + ctx.grad_dict_keys = grad_dict_keys + ctx.param_names = param_names + + # Save Tensors for backward + # Saving: + # * Predictions (1 tensor) + # * Grad tensors (N params * M poses tensors) + # * Model param tensors (N params tensors) + ctx.save_for_backward( + torch.stack(pred_list).flatten(), + *grad_dict_tensors, + *model_params, + ) + + @staticmethod + def backward(ctx, grad_output): + # Unpack saved tensors + preds, *other_tensors = ctx.saved_tensors + + # Split up other_tensors + grad_dict_tensors = other_tensors[: len(ctx.grad_dict_keys)] + + grad_dict = Combination.join_grad_dict(ctx.grad_dict_keys, grad_dict_tensors) + + # Begin calculations + # Save for later so we don't have to keep redoing this + adj_preds = -preds.detach() + + # First calculate the normalization factor + Q = torch.logsumexp(adj_preds, dim=0) + + # Calculate w + w = (adj_preds - Q).exp() + + # Calculate dQ/d_theta + dQ = { + n: -torch.stack( + [(pred - Q).exp() * grad for pred, grad in zip(adj_preds, grad_list)], + axis=-1, + ).sum(axis=-1) + for n, grad_list in grad_dict.items() + } + + # Calculate dw/d_theta + dw = { + n: [ + (pred - Q).exp() * (-grad - dQ[n]) + for pred, grad in zip(adj_preds, grad_list) + ] + for n, grad_list in grad_dict.items() + } + + # Calculate final grads + final_grads = {} + for n, grad_list in grad_dict.items(): + final_grads[n] = ( + torch.stack( + [ + w_grad * -pred + w_val * grad + for w_grad, pred, w_val, grad in zip( + dw[n], adj_preds, w, grad_list + ) + ], + axis=-1, + ) + .detach() + .sum(axis=-1) + ) + + # Adjust gradients by grad_output + for grad in final_grads.values(): + grad *= grad_output + + # Pull out return vals + return_vals = [None] * 3 + [final_grads[n] for n in ctx.param_names] + return tuple(return_vals) diff --git a/mtenn/conversion_utils/e3nn.py b/mtenn/conversion_utils/e3nn.py index 2fe30f1..188473d 100644 --- a/mtenn/conversion_utils/e3nn.py +++ b/mtenn/conversion_utils/e3nn.py @@ -3,18 +3,10 @@ """ from copy import deepcopy import torch -from e3nn import o3 from e3nn.nn.models.gate_points_2101 import Network -from ..model import ( - BoltzmannCombination, - ConcatStrategy, - DeltaStrategy, - GroupedModel, - MeanCombination, - Model, - PIC50Readout, -) +from mtenn.model import GroupedModel, Model +from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy class E3NN(Network): @@ -110,6 +102,18 @@ def _get_delta_strategy(self): return DeltaStrategy(self._get_energy_func()) + def _get_complex_only_strategy(self): + """ + Build a ComplexOnlyStrategy object based on the passed model. + + Returns + ------- + ComplexOnlyStrategy + ComplexOnlyStrategy built from `self` + """ + + return ComplexOnlyStrategy(self._get_energy_func()) + def _get_concat_strategy(self): """ Build a ConcatStrategy object using the key "x" to extract the tensor @@ -153,7 +157,7 @@ def get_model( copying over as necessary. 
diff --git a/mtenn/conversion_utils/e3nn.py b/mtenn/conversion_utils/e3nn.py
index 2fe30f1..188473d 100644
--- a/mtenn/conversion_utils/e3nn.py
+++ b/mtenn/conversion_utils/e3nn.py
@@ -3,18 +3,10 @@
 """
 from copy import deepcopy
 import torch
-from e3nn import o3
 from e3nn.nn.models.gate_points_2101 import Network

-from ..model import (
-    BoltzmannCombination,
-    ConcatStrategy,
-    DeltaStrategy,
-    GroupedModel,
-    MeanCombination,
-    Model,
-    PIC50Readout,
-)
+from mtenn.model import GroupedModel, Model
+from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy


 class E3NN(Network):
@@ -110,6 +102,18 @@ def _get_delta_strategy(self):

         return DeltaStrategy(self._get_energy_func())

+    def _get_complex_only_strategy(self):
+        """
+        Build a ComplexOnlyStrategy object based on the passed model.
+
+        Returns
+        -------
+        ComplexOnlyStrategy
+            ComplexOnlyStrategy built from `self`
+        """
+
+        return ComplexOnlyStrategy(self._get_energy_func())
+
     def _get_concat_strategy(self):
         """
         Build a ConcatStrategy object using the key "x" to extract the tensor
@@ -153,7 +157,7 @@ def get_model(
             copying over as necessary.
         strategy: str, default='delta'
             Strategy to use to combine representation of the different parts.
-            Options are ['delta', 'concat']
+            Options are ['delta', 'concat', 'complex']
         combination: Combination, optional
             Combination object to use to combine predictions in a group. A value
             must be passed if `grouped` is `True`.
@@ -182,6 +186,9 @@
         elif strategy == "concat":
             strategy = model._get_concat_strategy()
             reduce_output = True
+        elif strategy == "complex":
+            strategy = model._get_complex_only_strategy()
+            reduce_output = False
         else:
             raise ValueError(f"Unknown strategy: {strategy}")
@@ -191,7 +198,7 @@
     ## Check on `combination`
     if grouped and (combination is None):
         raise ValueError(
-            f"Must pass a value for `combination` if `grouped` is `True`."
+            "Must pass a value for `combination` if `grouped` is `True`."
         )

     if grouped:
diff --git a/mtenn/conversion_utils/gat.py b/mtenn/conversion_utils/gat.py
index 3bdb8f4..544febc 100644
--- a/mtenn/conversion_utils/gat.py
+++ b/mtenn/conversion_utils/gat.py
@@ -6,15 +6,7 @@
 from dgllife.model import GAT as GAT_dgl
 from dgllife.model import WeightedSumAndMax

-from ..model import (
-    BoltzmannCombination,
-    ConcatStrategy,
-    DeltaStrategy,
-    GroupedModel,
-    MeanCombination,
-    LigandOnlyModel,
-    PIC50Readout,
-)
+from mtenn.model import LigandOnlyModel


 class GAT(torch.nn.Module):
diff --git a/mtenn/conversion_utils/schnet.py b/mtenn/conversion_utils/schnet.py
index 45c7c62..c696915 100644
--- a/mtenn/conversion_utils/schnet.py
+++ b/mtenn/conversion_utils/schnet.py
@@ -5,15 +5,8 @@
 import torch
 from torch_geometric.nn.models import SchNet as PygSchNet

-from ..model import (
-    BoltzmannCombination,
-    ConcatStrategy,
-    DeltaStrategy,
-    GroupedModel,
-    MeanCombination,
-    Model,
-    PIC50Readout,
-)
+from mtenn.model import GroupedModel, Model
+from mtenn.strategy import ComplexOnlyStrategy, ConcatStrategy, DeltaStrategy


 class SchNet(PygSchNet):
@@ -103,6 +96,18 @@ def _get_delta_strategy(self):

         return DeltaStrategy(self._get_energy_func())

+    def _get_complex_only_strategy(self):
+        """
+        Build a ComplexOnlyStrategy object based on the passed model.
+
+        Returns
+        -------
+        ComplexOnlyStrategy
+            ComplexOnlyStrategy built from `self`
+        """
+
+        return ComplexOnlyStrategy(self._get_energy_func())
+
     @staticmethod
     def get_model(
         model=None,
@@ -130,7 +135,7 @@
             copying over as necessary.
         strategy: str, default='delta'
             Strategy to use to combine representation of the different parts.
-            Options are ['delta', 'concat']
+            Options are ['delta', 'concat', 'complex']
         combination: Combination, optional
             Combination object to use to combine predictions in a group. A value
             must be passed if `grouped` is `True`.
@@ -159,13 +164,15 @@
             strategy = model._get_delta_strategy()
         elif strategy == "concat":
             strategy = ConcatStrategy()
+        elif strategy == "complex":
+            strategy = model._get_complex_only_strategy()
         else:
             raise ValueError(f"Unknown strategy: {strategy}")

     ## Check on `combination`
     if grouped and (combination is None):
         raise ValueError(
-            f"Must pass a value for `combination` if `grouped` is `True`."
+            "Must pass a value for `combination` if `grouped` is `True`."
         )

     if grouped:
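For reference, building a model with the new `"complex"` strategy looks like this (mirroring the new tests below; the hyperparameters here are arbitrary toy values):

```python
from torch_geometric.nn import SchNet as PygSchNet
from mtenn.conversion_utils import SchNet

base = SchNet(
    PygSchNet(hidden_channels=2, num_filters=2, num_interactions=2, num_gaussians=2)
)
# Predicts E(complex) directly; no delta over parts is taken
model = SchNet.get_model(base, strategy="complex")
```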
diff --git a/mtenn/model.py b/mtenn/model.py
index ad8cef5..4817091 100644
--- a/mtenn/model.py
+++ b/mtenn/model.py
@@ -1,8 +1,10 @@
-from copy import deepcopy
-from itertools import permutations
 import os
 import torch
-from typing import Optional
+
+from mtenn.combination import Combination
+from mtenn.representation import Representation
+from mtenn.strategy import Strategy
+from mtenn.readout import Readout


 class Model(torch.nn.Module):
@@ -54,9 +56,9 @@ def forward(self, comp, *parts):
         energy_val = self.strategy(complex_rep, *parts_rep)

         if self.readout:
-            return self.readout(energy_val)
+            return self.readout(energy_val), [energy_val]
         else:
-            return energy_val
+            return energy_val, [energy_val]

     def _fix_device(self, data):
         ## We'll call this on everything for uniformity, but if we fix_deivec is
@@ -153,8 +155,8 @@ def __init__(
         super(GroupedModel, self).__init__(
             representation, strategy, pred_readout, fix_device
         )
-        self.combination = combination
-        self.readout = comb_readout
+        self.combination: Combination = combination
+        self.comb_readout: Readout = comb_readout

     def forward(self, input_list):
         """
@@ -171,10 +173,10 @@
         torch.Tensor
             Combination of all predictions
         """
-        ## Get predictions for all inputs in the list, and combine them in a
-        ## tensor (while keeping track of gradients)
-        all_reps = []
-        orig_dev = None
+        # Get predictions for all inputs in the list, and combine them in a
+        # tensor (while keeping track of gradients)
+        pred_list = []
+        grad_dict = {}
         for i, inp in enumerate(input_list):
             if "MTENN_VERBOSE" in os.environ:
                 print(f"pose {i}", flush=True)
@@ -191,15 +193,29 @@
                     f"{torch.cuda.memory_allocated():,}",
                     flush=True,
                 )
-            all_reps.append(super(GroupedModel, self).forward(inp))
-        all_reps = torch.stack(all_reps).flatten()
-
-        ## Combine each prediction according to `self.combination`
-        comb_pred = self.combination(all_reps)
-        if self.readout:
-            return self.readout(comb_pred)
+            # First get prediction
+            pred, _ = super().forward(inp)
+            pred_list.append(pred.detach())
+
+            # Get gradient per sample
+            self.zero_grad()
+            pred.backward()
+            for n, p in self.named_parameters():
+                try:
+                    grad_dict[n].append(p.grad.detach())
+                except KeyError:
+                    grad_dict[n] = [p.grad.detach()]
+            # Zero grads again just to make sure nothing gets accumulated
+            self.zero_grad()
+
+        # Separate out param names and params
+        param_names, model_params = zip(*self.named_parameters())
+        comb_pred = self.combination(pred_list, grad_dict, param_names, *model_params)
+
+        if self.comb_readout:
+            return self.comb_readout(comb_pred), pred_list
         else:
-            return comb_pred
+            return comb_pred, pred_list


 class LigandOnlyModel(Model):
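Note the new return convention introduced here: `Model.forward` and `GroupedModel.forward` now return a tuple of the (combined) prediction and the list of per-pose predictions. A sketch of a consuming training step, assuming `model`, `inputs`, and `target` already exist:

```python
import torch

loss_func = torch.nn.MSELoss()
pred, pose_preds = model(inputs)  # pose_preds: detached per-pose delta Gs
loss = loss_func(pred, target)
loss.backward()  # gradients flow through the custom Combination autograd math
```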
- """ - - def __init__(self, energy_func, pic50=True): - super(DeltaStrategy, self).__init__() - self.energy_func: torch.nn.Module = energy_func - self.pic50 = pic50 - - def forward(self, comp, *parts): - ## Calculat delta G - return self.energy_func(comp) - sum([self.energy_func(p) for p in parts]) - - -class ConcatStrategy(Strategy): - """ - Strategy for combining the complex representation and parts representations - in some learned manner, using sum-pooling to ensure permutation-invariance - of the parts. - """ - - def __init__(self, extract_key=None): - """ - Parameters - ---------- - extract_key : str, optional - Key to use to extract representation from a dict - """ - super(ConcatStrategy, self).__init__() - self.reduce_nn: torch.nn.Module = None - self.extract_key = extract_key - - def forward(self, comp, *parts): - ## Extract representation from dict - if self.extract_key: - comp = comp[self.extract_key] - parts = [p[self.extract_key] for p in parts] - - ## Flatten tensors - comp = comp.flatten() - parts = [p.flatten() for p in parts] - - parts_size = sum([len(p) for p in parts]) - if self.reduce_nn is None: - ## These should already by representations, so initialize a Linear - ## module with appropriate input size - input_size = len(comp) + parts_size - self.reduce_nn = torch.nn.Linear(input_size, 1) - - ## Move self.reduce_nn to appropriate torch device - self.reduce_nn = self.reduce_nn.to(comp.device) - - ## Enumerate all possible permutations of parts and add together - parts_cat = torch.zeros((parts_size), device=comp.device) - for idxs in permutations(range(len(parts)), len(parts)): - parts_cat += torch.cat([parts[i] for i in idxs]) - - ## Concat comp w/ permut-invariant parts representation - full_embedded = torch.cat([comp, parts_cat]) - - return self.reduce_nn(full_embedded) - - -class MeanCombination(Combination): - """ - Combine a list of predictions by taking the mean. - """ - - def __init__(self): - super(MeanCombination, self).__init__() - - def forward(self, predictions: torch.Tensor): - return torch.mean(predictions) - - -class MaxCombination(Combination): - """ - Approximate max/min of the predictions using the LogSumExp function for smoothness. - """ - - def __init__(self, neg=True, scale=1000.0): - """ - Parameters - ---------- - neg : bool, default=True - Negate the predictions before calculating the LSE, effectively finding - the min. Preds are negated again before being returned - scale : float, default=1000.0 - Fixed positive value to scale predictions by before taking the LSE. This - tightens the bounds of the LSE approximation - """ - super(MaxCombination, self).__init__() - - self.neg = -1 * neg - self.scale = scale - - def forward(self, predictions: torch.Tensor): - return ( - self.neg - * torch.logsumexp(self.neg * self.scale * predictions, dim=0) - / self.scale - ) - - -class BoltzmannCombination(Combination): - """ - Combine a list of deltaG predictions according to their Boltzmann weight. Use LSE - approximation of min energy to improve numerical stability. Treat energy in implicit - kT units. - """ - - def __init__(self): - super(BoltzmannCombination, self).__init__() - - def forward(self, predictions: torch.Tensor): - # First calculate LSE (no scale here bc math) - lse = torch.logsumexp(-predictions, dim=0) - # Calculate Boltzmann weights for each prediction - w = torch.exp(-predictions - lse) - - return torch.dot(w, predictions) - - -class PIC50Readout(Readout): - """ - Readout implementation to convert delta G values to pIC50 values. 
diff --git a/mtenn/readout.py b/mtenn/readout.py
new file mode 100644
index 0000000..41dbf2c
--- /dev/null
+++ b/mtenn/readout.py
@@ -0,0 +1,83 @@
+import abc
+import torch
+from typing import Optional
+
+
+class Readout(torch.nn.Module, abc.ABC):
+    pass
+
+
+class PIC50Readout(Readout):
+    """
+    Readout implementation to convert delta G values to pIC50 values. This new
+    implementation assumes implicit energy units, WHICH WILL INVALIDATE MODELS TRAINED
+    PRIOR TO v0.3.0.
+
+    Assuming implicit energy units:
+        deltaG = ln(Ki)
+        Ki = exp(deltaG)
+    Using the Cheng-Prusoff equation:
+        Ki = IC50 / (1 + [S]/Km)
+        exp(deltaG) = IC50 / (1 + [S]/Km)
+        IC50 = exp(deltaG) * (1 + [S]/Km)
+        pIC50 = -log10(exp(deltaG) * (1 + [S]/Km))
+        pIC50 = -log10(exp(deltaG)) - log10(1 + [S]/Km)
+        pIC50 = -ln(exp(deltaG))/ln(10) - log10(1 + [S]/Km)
+        pIC50 = -deltaG/ln(10) - log10(1 + [S]/Km)
+    Estimating Ki as the IC50 value:
+        Ki = IC50
+        IC50 = exp(deltaG)
+        pIC50 = -log10(exp(deltaG))
+        pIC50 = -ln(exp(deltaG))/ln(10)
+        pIC50 = -deltaG/ln(10)
+    """
+
+    def __init__(self, substrate: Optional[float] = None, Km: Optional[float] = None):
+        """
+        Initialize conversion with specified substrate concentration and Km. If either
+        is left blank, the IC50 approximation will be used.
+
+        Parameters
+        ----------
+        substrate : float, optional
+            Substrate concentration for use in the Cheng-Prusoff equation. Assumed to be
+            in the same units as Km
+        Km : float, optional
+            Km value for use in the Cheng-Prusoff equation. Assumed to be in the same
+            units as substrate
+        """
+        super(PIC50Readout, self).__init__()
+
+        self.substrate = substrate
+        self.Km = Km
+
+        if substrate and Km:
+            self.cp_val = 1 + substrate / Km
+        else:
+            self.cp_val = None
+
+    def __repr__(self):
+        return f"PIC50Readout(substrate={self.substrate}, Km={self.Km})"
+
+    def __str__(self):
+        return repr(self)
+
+    def forward(self, delta_g):
+        """
+        Method to convert a predicted delta G value into a pIC50 value.
+
+        Parameters
+        ----------
+        delta_g : torch.Tensor
+            Input delta G value.
+
+        Returns
+        -------
+        float
+            Calculated pIC50 value.
+        """
+        pic50 = -delta_g / torch.log(torch.tensor(10, dtype=delta_g.dtype))
+        # Using Cheng-Prusoff
+        if self.cp_val:
+            pic50 -= torch.log10(torch.tensor(self.cp_val, dtype=delta_g.dtype))
+
+        return pic50
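A usage sketch for `PIC50Readout` (the substrate and Km values below are made up for illustration; both must share units): with both given, the Cheng-Prusoff correction is applied, otherwise pIC50 = -deltaG/ln(10):

```python
import torch
from mtenn.readout import PIC50Readout

readout = PIC50Readout(substrate=0.375, Km=0.04)  # cp_val = 1 + [S]/Km
pic50 = readout(torch.tensor(-10.0))  # -(-10)/ln(10) - log10(cp_val)

readout_ic50 = PIC50Readout()  # no Cheng-Prusoff term; Ki ~ IC50 approximation
```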
diff --git a/mtenn/representation.py b/mtenn/representation.py
new file mode 100644
index 0000000..46b84c0
--- /dev/null
+++ b/mtenn/representation.py
@@ -0,0 +1,6 @@
+import abc
+import torch
+
+
+class Representation(torch.nn.Module, abc.ABC):
+    pass
diff --git a/mtenn/strategy.py b/mtenn/strategy.py
new file mode 100644
index 0000000..42e4d09
--- /dev/null
+++ b/mtenn/strategy.py
@@ -0,0 +1,95 @@
+import abc
+from itertools import permutations
+import torch
+
+
+class Strategy(torch.nn.Module, abc.ABC):
+    pass
+
+
+class DeltaStrategy(Strategy):
+    """
+    Simple strategy for subtracting the sum of the individual component energies
+    from the complex energy.
+    """
+
+    def __init__(self, energy_func, pic50=True):
+        super(DeltaStrategy, self).__init__()
+        self.energy_func: torch.nn.Module = energy_func
+        self.pic50 = pic50
+
+    def forward(self, comp, *parts):
+        ## Calculate delta G
+        complex_pred = self.energy_func(comp)
+        parts_preds = [self.energy_func(p) for p in parts]
+        # Replace any empty part predictions with zeros so the sum is well-defined
+        parts_preds = [
+            p if len(p.flatten()) > 0 else torch.zeros_like(complex_pred)
+            for p in parts_preds
+        ]
+        dG_pred = complex_pred - sum(parts_preds)
+        return dG_pred
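A toy illustration of `DeltaStrategy` (the `torch.nn.Linear` stands in for a real per-structure energy model and is purely hypothetical): dG = E(complex) - sum(E(parts)), with a part whose prediction comes back empty now contributing zero instead of breaking the sum:

```python
import torch
from mtenn.strategy import DeltaStrategy

energy = torch.nn.Linear(4, 1)  # hypothetical stand-in energy function
strat = DeltaStrategy(energy)
comp, prot, lig = torch.rand(4), torch.rand(4), torch.rand(4)
dG = strat(comp, prot, lig)  # energy(comp) - (energy(prot) + energy(lig))
```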
+ """ + + def __init__(self, extract_key=None): + """ + Parameters + ---------- + extract_key : str, optional + Key to use to extract representation from a dict + """ + super(ConcatStrategy, self).__init__() + self.reduce_nn: torch.nn.Module = None + self.extract_key = extract_key + + def forward(self, comp, *parts): + ## Extract representation from dict + if self.extract_key: + comp = comp[self.extract_key] + parts = [p[self.extract_key] for p in parts] + + ## Flatten tensors + comp = comp.flatten() + parts = [p.flatten() for p in parts] + + parts_size = sum([len(p) for p in parts]) + if self.reduce_nn is None: + ## These should already by representations, so initialize a Linear + ## module with appropriate input size + input_size = len(comp) + parts_size + self.reduce_nn = torch.nn.Linear(input_size, 1) + + ## Move self.reduce_nn to appropriate torch device + self.reduce_nn = self.reduce_nn.to(comp.device) + + ## Enumerate all possible permutations of parts and add together + parts_cat = torch.zeros((parts_size), device=comp.device) + for idxs in permutations(range(len(parts)), len(parts)): + parts_cat += torch.cat([parts[i] for i in idxs]) + + ## Concat comp w/ permut-invariant parts representation + full_embedded = torch.cat([comp, parts_cat]) + + return self.reduce_nn(full_embedded) + + +class ComplexOnlyStrategy(Strategy): + """ + Strategy to only return prediction for the complex. This is useful if you want to + make a prediction on just the ligand or just the protein, and essentially just + reduces to a standard version of whatever your underlying model is. + """ + + def __init__(self, energy_func): + super().__init__() + self.energy_func: torch.nn.Module = energy_func + + def forward(self, comp, *parts): + complex_pred = self.energy_func(comp) + return complex_pred diff --git a/mtenn/tests/test_combination.py b/mtenn/tests/test_combination.py new file mode 100644 index 0000000..13355c3 --- /dev/null +++ b/mtenn/tests/test_combination.py @@ -0,0 +1,122 @@ +from copy import deepcopy +import numpy as np +import pytest +import torch +from torch_geometric.nn import SchNet as PygSchNet + +from mtenn.combination import MeanCombination, MaxCombination, BoltzmannCombination +from mtenn.conversion_utils import SchNet + + +@pytest.fixture() +def models_and_inputs(): + model_test = SchNet( + PygSchNet(hidden_channels=2, num_filters=2, num_interactions=2, num_gaussians=2) + ) + model_ref = deepcopy(model_test) + model_ref = SchNet.get_model(model_ref, strategy="complex") + + elem_list = torch.randint(11, size=(10,)) + inp_list = [ + { + "z": elem_list, + "pos": torch.rand((10, 3)) * 10, + "lig": torch.ones(10, dtype=bool), + } + for _ in range(5) + ] + target = torch.rand(1) + loss_func = torch.nn.MSELoss() + + return model_test, model_ref, inp_list, target, loss_func + + +def test_mean_combination(models_and_inputs): + model_test, model_ref, inp_list, target, loss_func = models_and_inputs + + # Ref calc + pred_list = [model_ref(X)[0] for X in inp_list] + pred_ref = torch.stack(pred_list).mean(axis=0) + loss = loss_func(pred_ref, target) + loss.backward() + + # Finish setting up GroupedModel + model_test = SchNet.get_model( + model_test, grouped=True, strategy="complex", combination=MeanCombination() + ) + + # Test GroupedModel + pred_test, _ = model_test(inp_list) + loss = loss_func(pred_test, target) + loss.backward() + + # Compare + ref_param_dict = dict(model_ref.named_parameters()) + assert all( + [ + np.allclose(p.grad, ref_param_dict[n].grad, atol=5e-7) + for n, p in 
diff --git a/mtenn/tests/test_combination.py b/mtenn/tests/test_combination.py
new file mode 100644
index 0000000..13355c3
--- /dev/null
+++ b/mtenn/tests/test_combination.py
@@ -0,0 +1,122 @@
+from copy import deepcopy
+import numpy as np
+import pytest
+import torch
+from torch_geometric.nn import SchNet as PygSchNet
+
+from mtenn.combination import MeanCombination, MaxCombination, BoltzmannCombination
+from mtenn.conversion_utils import SchNet
+
+
+@pytest.fixture()
+def models_and_inputs():
+    model_test = SchNet(
+        PygSchNet(hidden_channels=2, num_filters=2, num_interactions=2, num_gaussians=2)
+    )
+    model_ref = deepcopy(model_test)
+    model_ref = SchNet.get_model(model_ref, strategy="complex")
+
+    elem_list = torch.randint(11, size=(10,))
+    inp_list = [
+        {
+            "z": elem_list,
+            "pos": torch.rand((10, 3)) * 10,
+            "lig": torch.ones(10, dtype=bool),
+        }
+        for _ in range(5)
+    ]
+    target = torch.rand(1)
+    loss_func = torch.nn.MSELoss()
+
+    return model_test, model_ref, inp_list, target, loss_func
+
+
+def test_mean_combination(models_and_inputs):
+    model_test, model_ref, inp_list, target, loss_func = models_and_inputs
+
+    # Ref calc
+    pred_list = [model_ref(X)[0] for X in inp_list]
+    pred_ref = torch.stack(pred_list).mean(axis=0)
+    loss = loss_func(pred_ref, target)
+    loss.backward()
+
+    # Finish setting up GroupedModel
+    model_test = SchNet.get_model(
+        model_test, grouped=True, strategy="complex", combination=MeanCombination()
+    )
+
+    # Test GroupedModel
+    pred_test, _ = model_test(inp_list)
+    loss = loss_func(pred_test, target)
+    loss.backward()
+
+    # Compare
+    ref_param_dict = dict(model_ref.named_parameters())
+    assert all(
+        [
+            np.allclose(p.grad, ref_param_dict[n].grad, atol=5e-7)
+            for n, p in model_test.named_parameters()
+        ]
+    )
+
+
+def test_max_combination(models_and_inputs):
+    model_test, model_ref, inp_list, target, loss_func = models_and_inputs
+
+    # Ref calc
+    pred_list = [model_ref(X)[0] for X in inp_list]
+    pred = torch.logsumexp(torch.stack(pred_list), axis=0)
+    loss = loss_func(pred, target)
+    loss.backward()
+
+    # Finish setting up GroupedModel
+    model_test = SchNet.get_model(
+        model_test,
+        grouped=True,
+        strategy="complex",
+        combination=MaxCombination(neg=False, scale=1.0),
+    )
+
+    # Test GroupedModel
+    pred, _ = model_test(inp_list)
+    loss = loss_func(pred, target)
+    loss.backward()
+
+    # Compare
+    ref_param_dict = dict(model_ref.named_parameters())
+    assert all(
+        [
+            np.allclose(p.grad, ref_param_dict[n].grad, atol=5e-7)
+            for n, p in model_test.named_parameters()
+        ]
+    )
+
+
+def test_boltzmann_combination(models_and_inputs):
+    model_test, model_ref, inp_list, target, loss_func = models_and_inputs
+
+    # Ref calc
+    pred_list = torch.stack([model_ref(X)[0] for X in inp_list])
+    w = torch.exp(-pred_list - torch.logsumexp(-pred_list, axis=0))
+    pred_ref = torch.dot(w.flatten(), pred_list.flatten())
+    loss = loss_func(pred_ref, target)
+    loss.backward()
+
+    # Finish setting up GroupedModel
+    model_test = SchNet.get_model(
+        model_test, grouped=True, strategy="complex", combination=BoltzmannCombination()
+    )
+
+    # Test GroupedModel
+    pred_test, _ = model_test(inp_list)
+    loss = loss_func(pred_test, target)
+    loss.backward()
+
+    # Compare
+    ref_param_dict = dict(model_ref.named_parameters())
+    assert all(
+        [
+            np.allclose(p.grad, ref_param_dict[n].grad, atol=5e-7)
+            for n, p in model_test.named_parameters()
+        ]
+    )