diff --git a/mtenn/conversion_utils/e3nn.py b/mtenn/conversion_utils/e3nn.py index 920fd08..2fe30f1 100644 --- a/mtenn/conversion_utils/e3nn.py +++ b/mtenn/conversion_utils/e3nn.py @@ -6,13 +6,21 @@ from e3nn import o3 from e3nn.nn.models.gate_points_2101 import Network -from ..model import ConcatStrategy, DeltaStrategy, Model +from ..model import ( + BoltzmannCombination, + ConcatStrategy, + DeltaStrategy, + GroupedModel, + MeanCombination, + Model, + PIC50Readout, +) class E3NN(Network): - def __init__(self, model_kwargs, model=None): - ## If no model is passed, construct E3NN model with model_kwargs, otherwise copy - ## all parameters and weights over + def __init__(self, model=None, model_kwargs=None): + ## If no model is passed, construct E3NN model with model_kwargs, + ## otherwise copy all parameters and weights over if model is None: super(E3NN, self).__init__(**model_kwargs) self.model_parameters = model_kwargs @@ -42,13 +50,16 @@ def forward(self, data): copy["x"] = torch.clone(x) return copy - def _get_representation(self): + def _get_representation(self, reduce_output=False): """ Input model, remove last layer. Parameters ---------- - model: E3NN - e3nn model + reduce_output: bool, default=False + Whether to reduce output across nodes. This should be set to True + if you want a uniform size tensor for every input size (eg when + using a ConcatStrategy). + Returns ------- E3NN @@ -59,7 +70,7 @@ def _get_representation(self): model_copy = deepcopy(self) ## Remove last layer model_copy.layers = model_copy.layers[:-1] - model_copy.reduce_output = False + model_copy.reduce_output = reduce_output return model_copy @@ -89,24 +100,44 @@ def _get_energy_func(self): def _get_delta_strategy(self): """ - Build a DeltaStrategy object based on the passed model. - Parameters - ---------- - model: E3NN - e3nn model + Build a DeltaStrategy object based on the calling model. + Returns ------- DeltaStrategy - DeltaStrategy built from `model` + DeltaStrategy built from the model """ return DeltaStrategy(self._get_energy_func()) + def _get_concat_strategy(self): + """ + Build a ConcatStrategy object using the key "x" to extract the tensor + representation from the data dict. + + Returns + ------- + ConcatStrategy + ConcatStrategy for the model + """ + + return ConcatStrategy(extract_key="x") + @staticmethod - def get_model(model=None, model_kwargs=None, strategy: str = "delta"): + def get_model( + model=None, + model_kwargs=None, + grouped=False, + fix_device=False, + strategy: str = "delta", + combination=None, + pred_readout=None, + comb_readout=None, + ): """ Exposed function to build a Model object from a E3NN object. If none is provided, a default model is initialized. + Parameters ---------- model: E3NN, optional @@ -114,9 +145,24 @@ def get_model(model=None, model_kwargs=None, strategy: str = "delta"): default model will be initialized and used model_kwargs: dict, optional Dictionary used to initialize E3NN model if model is not passed in + grouped: bool, default=False + Whether this model should accept groups of inputs or one input at a + time. + fix_device: bool, default=False + If True, make sure the input is on the same device as the model, + copying over as necessary. strategy: str, default='delta' Strategy to use to combine representation of the different parts. Options are ['delta', 'concat'] + combination: Combination, optional + Combination object to use to combine predictions in a group. A value + must be passed if `grouped` is `True`. + pred_readout : Readout + Readout object for the energy predictions. If `grouped` is `False`, + this option will still be used in the construction of the `Model` + object. + comb_readout : Readout + Readout object for the combination output. Returns ------- Model @@ -127,14 +173,35 @@ def get_model(model=None, model_kwargs=None, strategy: str = "delta"): if model is None: model = E3NN(model_kwargs) - ## First get representation module - representation = model._get_representation() - ## Construct strategy module based on model and ## representation (if necessary) + strategy = strategy.lower() if strategy == "delta": strategy = model._get_delta_strategy() + reduce_output = False elif strategy == "concat": - strategy = ConcatStrategy() + strategy = model._get_concat_strategy() + reduce_output = True + else: + raise ValueError(f"Unknown strategy: {strategy}") - return Model(representation, strategy) + ## Get representation module + representation = model._get_representation(reduce_output=reduce_output) + + ## Check on `combination` + if grouped and (combination is None): + raise ValueError( + f"Must pass a value for `combination` if `grouped` is `True`." + ) + + if grouped: + return GroupedModel( + representation, + strategy, + combination, + pred_readout, + comb_readout, + fix_device, + ) + else: + return Model(representation, strategy, pred_readout, fix_device) diff --git a/mtenn/conversion_utils/schnet.py b/mtenn/conversion_utils/schnet.py index fbcb5bc..45c7c62 100644 --- a/mtenn/conversion_utils/schnet.py +++ b/mtenn/conversion_utils/schnet.py @@ -5,7 +5,15 @@ import torch from torch_geometric.nn.models import SchNet as PygSchNet -from ..model import ConcatStrategy, DeltaStrategy, Model +from ..model import ( + BoltzmannCombination, + ConcatStrategy, + DeltaStrategy, + GroupedModel, + MeanCombination, + Model, + PIC50Readout, +) class SchNet(PygSchNet): @@ -15,14 +23,18 @@ def __init__(self, model=None): if model is None: super(SchNet, self).__init__() else: - atomref = model.atomref.weight.detach().clone() + try: + atomref = model.atomref.weight.detach().clone() + except AttributeError: + atomref = None model_params = ( model.hidden_channels, model.num_filters, model.num_interactions, model.num_gaussians, model.cutoff, - model.max_num_neighbors, + model.interaction_graph, + model.interaction_graph.max_num_neighbors, model.readout, model.dipole, model.mean, @@ -92,7 +104,15 @@ def _get_delta_strategy(self): return DeltaStrategy(self._get_energy_func()) @staticmethod - def get_model(model=None, strategy: str = "delta"): + def get_model( + model=None, + grouped=False, + fix_device=False, + strategy: str = "delta", + combination=None, + pred_readout=None, + comb_readout=None, + ): """ Exposed function to build a Model object from a SchNet object. If none is provided, a default model is initialized. @@ -102,9 +122,24 @@ def get_model(model=None, strategy: str = "delta"): model: SchNet, optional SchNet model to use to build the Model object. If left as none, a default model will be initialized and used + grouped: bool, default=False + Whether this model should accept groups of inputs or one input at a + time. + fix_device: bool, default=False + If True, make sure the input is on the same device as the model, + copying over as necessary. strategy: str, default='delta' Strategy to use to combine representation of the different parts. Options are ['delta', 'concat'] + combination: Combination, optional + Combination object to use to combine predictions in a group. A value + must be passed if `grouped` is `True`. + pred_readout : Readout + Readout object for the energy predictions. If `grouped` is `False`, + this option will still be used in the construction of the `Model` + object. + comb_readout : Readout + Readout object for the combination output. Returns ------- @@ -119,9 +154,28 @@ def get_model(model=None, strategy: str = "delta"): ## Construct strategy module based on model and ## representation (if necessary) + strategy = strategy.lower() if strategy == "delta": strategy = model._get_delta_strategy() elif strategy == "concat": strategy = ConcatStrategy() + else: + raise ValueError(f"Unknown strategy: {strategy}") + + ## Check on `combination` + if grouped and (combination is None): + raise ValueError( + f"Must pass a value for `combination` if `grouped` is `True`." + ) - return Model(representation, strategy) + if grouped: + return GroupedModel( + representation, + strategy, + combination, + pred_readout, + comb_readout, + fix_device, + ) + else: + return Model(representation, strategy, pred_readout, fix_device) diff --git a/mtenn/model.py b/mtenn/model.py index 912ef99..471399a 100644 --- a/mtenn/model.py +++ b/mtenn/model.py @@ -1,5 +1,6 @@ from copy import deepcopy from itertools import permutations +import os import torch @@ -11,10 +12,22 @@ class Model(torch.nn.Module): representations, and convert to a final scalar value. """ - def __init__(self, representation, strategy): + def __init__( + self, representation, strategy, readout=None, fix_device=False + ): + """ + Parameters + ---------- + fix_device: bool, default=False + If True, make sure the input is on the same device as the model, + copying over as necessary. + """ super(Model, self).__init__() self.representation: Representation = representation self.strategy: Strategy = strategy + self.readout: Readout = readout + + self.fix_device = fix_device def get_representation(self, *args, **kwargs): """ @@ -33,13 +46,36 @@ def get_representation(self, *args, **kwargs): def forward(self, comp, *parts): ## This implementation of the forward function assumes the ## get_representation function takes a single data object - complex_rep = self.get_representation(comp) + tmp_comp = self._fix_device(comp) + complex_rep = self.get_representation(tmp_comp) if len(parts) == 0: - parts = Model._split_parts(comp) - parts_rep = [self.get_representation(p) for p in parts] + parts = Model._split_parts(tmp_comp) + parts_rep = [ + self.get_representation(self._fix_device(p)) for p in parts + ] + + energy_val = self.strategy(complex_rep, *parts_rep) + if self.readout: + return self.readout(energy_val) + else: + return energy_val - return self.strategy(complex_rep, *parts_rep) + def _fix_device(self, data): + ## We'll call this on everything for uniformity, but if we fix_deivec is + ## False we can just return + if not self.fix_device: + return data + + device = next(self.parameters()).device + tmp_data = {} + for k, v in data.items(): + try: + tmp_data[k] = v.to(device) + except AttributeError: + tmp_data[k] = v + + return tmp_data @staticmethod def _split_parts(comp): @@ -72,16 +108,103 @@ def _split_parts(comp): prot_rep[k] = v lig_rep[k] = v else: - # prot_idx = torch.arange(len(idx))[~idx].to(v.device) - # lig_idx = torch.arange(len(idx))[idx].to(v.device) - # prot_rep[k] = torch.index_select(v, 0, prot_idx) - # lig_rep[k] = torch.index_select(v, 0, lig_idx) prot_rep[k] = v[~idx] lig_rep[k] = v[idx] return prot_rep, lig_rep +class GroupedModel(Model): + """ + Subclass of the above `Model` for use with grouped data, eg multiple docked + poses of the same molecule with the same protein. In addition to the + `Representation` and `Strategy` modules in the `Model` class, `GroupedModel` + also has a `Comination` module, that dictates how the `Model` predictions + for each item in the group of data are combined. + """ + + def __init__( + self, + representation, + strategy, + combination, + pred_readout=None, + comb_readout=None, + fix_device=False, + ): + """ + The `representation`, `strategy`, and `pred_readout` options will be used + to initialize the underlying `Model` object, while the `combination` and + `comb_readout` modules will be applied to the output of the `Model` preds. + + Parameters + ---------- + representation : Representation + Representation object to get the representation of the input data. + strategy : Strategy + Strategy object to get convert the representations into energy preds. + combination : Combination + Combination object for combining the energy predictions. + pred_readout : Readout, optional + Readout object for the energy predictions. + comb_readout : Readout, optional + Readout object for the combination output. + fix_device: bool, default=False + If True, make sure the input is on the same device as the model, + copying over as necessary. + """ + super(GroupedModel, self).__init__( + representation, strategy, pred_readout, fix_device + ) + self.combination = combination + self.readout = comb_readout + + def forward(self, input_list): + """ + Forward method for `GroupedModel` class. Will call the `forward` method + of `Model` for each entry in `input_list`. + + Parameters + ---------- + input_list : List[Tuple[Dict]] + List of tuples of (complex representation, part representations) + + Returns + ------- + torch.Tensor + Combination of all predictions + """ + ## Get predictions for all inputs in the list, and combine them in a + ## tensor (while keeping track of gradients) + all_reps = [] + orig_dev = None + for i, inp in enumerate(input_list): + if "MTENN_VERBOSE" in os.environ: + print(f"pose {i}", flush=True) + print( + "size", + ", ".join( + [ + f"{k}: {v.shape} ({v.dtype})" + for k, v in inp.items() + if type(v) is torch.Tensor + ] + ), + sum([len(p.flatten()) for p in self.parameters()]), + f"{torch.cuda.memory_allocated():,}", + flush=True, + ) + all_reps.append(super(GroupedModel, self).forward(inp)) + all_reps = torch.stack(all_reps).flatten() + + ## Combine each prediction according to `self.combination` + comb_pred = self.combination(all_reps) + if self.readout: + return self.readout(comb_pred) + else: + return comb_pred + + class Representation(torch.nn.Module): pass @@ -90,17 +213,27 @@ class Strategy(torch.nn.Module): pass +class Combination(torch.nn.Module): + pass + + +class Readout(torch.nn.Module): + pass + + class DeltaStrategy(Strategy): """ Simple strategy for subtracting the sum of the individual component energies from the complex energy. """ - def __init__(self, energy_func): + def __init__(self, energy_func, pic50=True): super(DeltaStrategy, self).__init__() self.energy_func: torch.nn.Module = energy_func + self.pic50 = pic50 def forward(self, comp, *parts): + ## Calculat delta G return self.energy_func(comp) - sum( [self.energy_func(p) for p in parts] ) @@ -113,27 +246,129 @@ class ConcatStrategy(Strategy): of the parts. """ - def __init__(self): + def __init__(self, extract_key=None): + """ + Parameters + ---------- + extract_key : str, optional + Key to use to extract representation from a dict + """ super(ConcatStrategy, self).__init__() self.reduce_nn: torch.nn.Module = None + self.extract_key = extract_key def forward(self, comp, *parts): - parts_size = sum([p.shape[1] for p in parts]) + ## Extract representation from dict + if self.extract_key: + comp = comp[self.extract_key] + parts = [p[self.extract_key] for p in parts] + + ## Flatten tensors + comp = comp.flatten() + parts = [p.flatten() for p in parts] + + parts_size = sum([len(p) for p in parts]) if self.reduce_nn is None: ## These should already by representations, so initialize a Linear ## module with appropriate input size - input_size = comp.shape[1] + parts_size + input_size = len(comp) + parts_size self.reduce_nn = torch.nn.Linear(input_size, 1) ## Move self.reduce_nn to appropriate torch device self.reduce_nn = self.reduce_nn.to(comp.device) ## Enumerate all possible permutations of parts and add together - parts_cat = torch.zeros((1, parts_size), device=comp.device) + parts_cat = torch.zeros((parts_size), device=comp.device) for idxs in permutations(range(len(parts)), len(parts)): - parts_cat += torch.cat([parts[i] for i in idxs], dim=1) + parts_cat += torch.cat([parts[i] for i in idxs]) ## Concat comp w/ permut-invariant parts representation - full_embedded = torch.cat([comp, parts_cat], dim=1) + full_embedded = torch.cat([comp, parts_cat]) return self.reduce_nn(full_embedded) + + +class MeanCombination(Combination): + """ + Combine a list of predictions by taking the mean. + """ + + def __init__(self): + super(MeanCombination, self).__init__() + + def forward(self, predictions: torch.Tensor): + return torch.mean(predictions) + + +class BoltzmannCombination(Combination): + """ + Combine a list of deltaG predictions according to their Boltzmann weight. + """ + + def __init__(self): + super(BoltzmannCombination, self).__init__() + + from simtk.unit import ( + BOLTZMANN_CONSTANT_kB as kB, + elementary_charge, + coulomb, + ) + + ## Convert kB to eV (calibrate to SchNet predictions) + electron_volt = elementary_charge.conversion_factor_to(coulomb) + + self.kT = (kB / electron_volt * 298.0)._value + + def forward(self, predictions: torch.Tensor): + return -self.kT * torch.logsumexp(-predictions, dim=0) + + +class PIC50Readout(Readout): + """ + Readout implementation to convert delta G values to pIC50 values. + """ + + def __init__(self, T=298.0): + """ + Initialize conversion with specified T (assume 298 K). + + Parameters + ---------- + T : float, default=298 + Temperature for conversion. + """ + super(PIC50Readout, self).__init__() + + from simtk.unit import ( + BOLTZMANN_CONSTANT_kB as kB, + elementary_charge, + coulomb, + ) + + ## Convert kB to eV (calibrate to SchNet predictions) + electron_volt = elementary_charge.conversion_factor_to(coulomb) + + self.kT = (kB / electron_volt * T)._value + + def forward(self, delta_g): + """ + Method to convert a predicted delta G value into a pIC50 value. + + Parameters + ---------- + delta_g : torch.Tensor + Input delta G value. + + Returns + ------- + float + Calculated pIC50 value. + """ + ## IC50 value = exp(dG/kT) => pic50 = -log10(exp(dg/kT)) + ## Rearrange a bit more to avoid disappearing floats: + ## pic50 = -dg/kT / ln(10) + return ( + -delta_g + / self.kT + / torch.log(torch.tensor(10, dtype=delta_g.dtype)) + )