Skip to content

Commit

Permalink
Merge pull request #7 from choderalab/add-grouped
Browse files Browse the repository at this point in the history
Add grouped models
  • Loading branch information
kaminow authored Mar 8, 2023
2 parents 2349e4e + 3b4e872 commit 67a8c5d
Show file tree
Hide file tree
Showing 3 changed files with 397 additions and 41 deletions.
107 changes: 87 additions & 20 deletions mtenn/conversion_utils/e3nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,21 @@
from e3nn import o3
from e3nn.nn.models.gate_points_2101 import Network

from ..model import ConcatStrategy, DeltaStrategy, Model
from ..model import (
BoltzmannCombination,
ConcatStrategy,
DeltaStrategy,
GroupedModel,
MeanCombination,
Model,
PIC50Readout,
)


class E3NN(Network):
def __init__(self, model_kwargs, model=None):
## If no model is passed, construct E3NN model with model_kwargs, otherwise copy
## all parameters and weights over
def __init__(self, model=None, model_kwargs=None):
## If no model is passed, construct E3NN model with model_kwargs,
## otherwise copy all parameters and weights over
if model is None:
super(E3NN, self).__init__(**model_kwargs)
self.model_parameters = model_kwargs
Expand Down Expand Up @@ -42,13 +50,16 @@ def forward(self, data):
copy["x"] = torch.clone(x)
return copy

def _get_representation(self):
def _get_representation(self, reduce_output=False):
"""
Input model, remove last layer.
Parameters
----------
model: E3NN
e3nn model
reduce_output: bool, default=False
Whether to reduce output across nodes. This should be set to True
if you want a uniform size tensor for every input size (eg when
using a ConcatStrategy).
Returns
-------
E3NN
Expand All @@ -59,7 +70,7 @@ def _get_representation(self):
model_copy = deepcopy(self)
## Remove last layer
model_copy.layers = model_copy.layers[:-1]
model_copy.reduce_output = False
model_copy.reduce_output = reduce_output

return model_copy

Expand Down Expand Up @@ -89,34 +100,69 @@ def _get_energy_func(self):

def _get_delta_strategy(self):
"""
Build a DeltaStrategy object based on the passed model.
Parameters
----------
model: E3NN
e3nn model
Build a DeltaStrategy object based on the calling model.
Returns
-------
DeltaStrategy
DeltaStrategy built from `model`
DeltaStrategy built from the model
"""

return DeltaStrategy(self._get_energy_func())

def _get_concat_strategy(self):
"""
Build a ConcatStrategy object using the key "x" to extract the tensor
representation from the data dict.
Returns
-------
ConcatStrategy
ConcatStrategy for the model
"""

return ConcatStrategy(extract_key="x")

@staticmethod
def get_model(model=None, model_kwargs=None, strategy: str = "delta"):
def get_model(
model=None,
model_kwargs=None,
grouped=False,
fix_device=False,
strategy: str = "delta",
combination=None,
pred_readout=None,
comb_readout=None,
):
"""
Exposed function to build a Model object from a E3NN object. If none
is provided, a default model is initialized.
Parameters
----------
model: E3NN, optional
E3NN model to use to build the Model object. If left as none, a
default model will be initialized and used
model_kwargs: dict, optional
Dictionary used to initialize E3NN model if model is not passed in
grouped: bool, default=False
Whether this model should accept groups of inputs or one input at a
time.
fix_device: bool, default=False
If True, make sure the input is on the same device as the model,
copying over as necessary.
strategy: str, default='delta'
Strategy to use to combine representation of the different parts.
Options are ['delta', 'concat']
combination: Combination, optional
Combination object to use to combine predictions in a group. A value
must be passed if `grouped` is `True`.
pred_readout : Readout
Readout object for the energy predictions. If `grouped` is `False`,
this option will still be used in the construction of the `Model`
object.
comb_readout : Readout
Readout object for the combination output.
Returns
-------
Model
Expand All @@ -127,14 +173,35 @@ def get_model(model=None, model_kwargs=None, strategy: str = "delta"):
if model is None:
model = E3NN(model_kwargs)

## First get representation module
representation = model._get_representation()

## Construct strategy module based on model and
## representation (if necessary)
strategy = strategy.lower()
if strategy == "delta":
strategy = model._get_delta_strategy()
reduce_output = False
elif strategy == "concat":
strategy = ConcatStrategy()
strategy = model._get_concat_strategy()
reduce_output = True
else:
raise ValueError(f"Unknown strategy: {strategy}")

return Model(representation, strategy)
## Get representation module
representation = model._get_representation(reduce_output=reduce_output)

## Check on `combination`
if grouped and (combination is None):
raise ValueError(
f"Must pass a value for `combination` if `grouped` is `True`."
)

if grouped:
return GroupedModel(
representation,
strategy,
combination,
pred_readout,
comb_readout,
fix_device,
)
else:
return Model(representation, strategy, pred_readout, fix_device)
64 changes: 59 additions & 5 deletions mtenn/conversion_utils/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
import torch
from torch_geometric.nn.models import SchNet as PygSchNet

from ..model import ConcatStrategy, DeltaStrategy, Model
from ..model import (
BoltzmannCombination,
ConcatStrategy,
DeltaStrategy,
GroupedModel,
MeanCombination,
Model,
PIC50Readout,
)


class SchNet(PygSchNet):
Expand All @@ -15,14 +23,18 @@ def __init__(self, model=None):
if model is None:
super(SchNet, self).__init__()
else:
atomref = model.atomref.weight.detach().clone()
try:
atomref = model.atomref.weight.detach().clone()
except AttributeError:
atomref = None
model_params = (
model.hidden_channels,
model.num_filters,
model.num_interactions,
model.num_gaussians,
model.cutoff,
model.max_num_neighbors,
model.interaction_graph,
model.interaction_graph.max_num_neighbors,
model.readout,
model.dipole,
model.mean,
Expand Down Expand Up @@ -92,7 +104,15 @@ def _get_delta_strategy(self):
return DeltaStrategy(self._get_energy_func())

@staticmethod
def get_model(model=None, strategy: str = "delta"):
def get_model(
model=None,
grouped=False,
fix_device=False,
strategy: str = "delta",
combination=None,
pred_readout=None,
comb_readout=None,
):
"""
Exposed function to build a Model object from a SchNet object. If none
is provided, a default model is initialized.
Expand All @@ -102,9 +122,24 @@ def get_model(model=None, strategy: str = "delta"):
model: SchNet, optional
SchNet model to use to build the Model object. If left as none, a
default model will be initialized and used
grouped: bool, default=False
Whether this model should accept groups of inputs or one input at a
time.
fix_device: bool, default=False
If True, make sure the input is on the same device as the model,
copying over as necessary.
strategy: str, default='delta'
Strategy to use to combine representation of the different parts.
Options are ['delta', 'concat']
combination: Combination, optional
Combination object to use to combine predictions in a group. A value
must be passed if `grouped` is `True`.
pred_readout : Readout
Readout object for the energy predictions. If `grouped` is `False`,
this option will still be used in the construction of the `Model`
object.
comb_readout : Readout
Readout object for the combination output.
Returns
-------
Expand All @@ -119,9 +154,28 @@ def get_model(model=None, strategy: str = "delta"):

## Construct strategy module based on model and
## representation (if necessary)
strategy = strategy.lower()
if strategy == "delta":
strategy = model._get_delta_strategy()
elif strategy == "concat":
strategy = ConcatStrategy()
else:
raise ValueError(f"Unknown strategy: {strategy}")

## Check on `combination`
if grouped and (combination is None):
raise ValueError(
f"Must pass a value for `combination` if `grouped` is `True`."
)

return Model(representation, strategy)
if grouped:
return GroupedModel(
representation,
strategy,
combination,
pred_readout,
comb_readout,
fix_device,
)
else:
return Model(representation, strategy, pred_readout, fix_device)
Loading

0 comments on commit 67a8c5d

Please sign in to comment.