Skip to content

Commit

Permalink
Fix optional postprocess layer (#156)
Browse files Browse the repository at this point in the history
* add failing test

* fix optional postprocess
  • Loading branch information
lilyminium authored Nov 5, 2024
1 parent 087212d commit b8531c8
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 0 deletions.
1 change: 1 addition & 0 deletions openff/nagl/nn/_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def from_config(
layer_dropout = [
layer.dropout for layer in readout_config.layers
]
postprocess_layer = None
if readout_config.postprocess is not None:
postprocess_layer = _PostprocessLayerMeta._get_object(readout_config.postprocess)
hidden_feature_sizes.append(postprocess_layer.n_features)
Expand Down
114 changes: 114 additions & 0 deletions openff/nagl/tests/training/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@
import numpy as np
import pytorch_lightning as pl

from openff.nagl.config.data import DatasetConfig, DataConfig
from openff.nagl.config.optimizer import OptimizerConfig
from openff.nagl.config.training import TrainingConfig
from openff.nagl.config.model import (
ForwardLayer,
ReadoutModule,
ModelConfig,
ConvolutionLayer,
ConvolutionModule
)
from openff.nagl.features import atoms
from openff.nagl.training.metrics import RMSEMetric
from openff.nagl.training.loss import ReadoutTarget
from openff.nagl.training.training import DGLMoleculeDataModule, DataHash, TrainingGNNModel
from openff.nagl.nn._models import GNNModel
from openff.nagl.nn._dataset import (
Expand Down Expand Up @@ -377,3 +390,104 @@ def test_train_model_no_error(example_training_config, tmpdir):
accelerator="gpu", devices=1, max_epochs=2,
)
trainer.fit(model, datamodule=data_module)



@pytest.fixture()
def forward_layer():
single_readout_layer = ForwardLayer(
hidden_feature_size=128, # 128 features per hidden convolution layer
activation_function="ReLU", # max(0, x) activation function for layer
dropout=0.0, # no dropout
)
return single_readout_layer

@pytest.fixture()
def convolution_layer():
single_convolution_layer = ConvolutionLayer(
hidden_feature_size=128, # 128 features per hidden convolution layer
aggregator_type="mean", # aggregate atom representations with mean
activation_function="ReLU", # max(0, x) activation function for layer
dropout=0.0, # no dropout
)
return single_convolution_layer

@pytest.fixture()
def convolution_module(convolution_layer):
convolution_module = ConvolutionModule(
architecture="SAGEConv", # GraphSAGE GCN
layers=[convolution_layer] * 3, # 3 hidden convolution layers
)
return convolution_module


def test_no_postprocess_layer(
convolution_module,
forward_layer,
tmpdir
):

atom_features = [atoms.AtomicElement(categories=["C", "H"])]

readout_module = ReadoutModule(
pooling="atoms",
layers=[forward_layer] * 4, # 4 internal readout layers
postprocess=None
)

model_config = ModelConfig(
version="0.1",
atom_features=atom_features,
bond_features=[],
convolution=convolution_module,
readouts={
"predicted-am1bcc-charges": readout_module
}
)

with tmpdir.as_cwd():
# copy over the data
shutil.copytree(
EXAMPLE_UNFEATURIZED_PARQUET_DATASET_SHORT.resolve(),
EXAMPLE_UNFEATURIZED_PARQUET_DATASET_SHORT.stem
)

dataset_name = "example-data-labelled-unfeaturized-short"

charge_rmse_target = ReadoutTarget(
metric=RMSEMetric(), # use RMSE to calculate loss
target_label="am1bcc_charges", # column to use from data as reference target
prediction_label="predicted-am1bcc-charges", # readout value to compare to target
denominator=1.0, # denominator to normalise loss -- important for multi-target objectives
weight=1.0, # how much to weight the loss -- important for multi-target objectives
)

training_dataset_config = DatasetConfig(
sources=[dataset_name],
targets=[charge_rmse_target],
batch_size=1000,
)

test_dataset_config = validation_dataset_config = DatasetConfig(
sources=[dataset_name],
targets=[charge_rmse_target],
batch_size=1000,
)

data_config = DataConfig(
training=training_dataset_config,
validation=validation_dataset_config,
test=test_dataset_config
)

optimizer_config = OptimizerConfig(
optimizer="Adam",
learning_rate=0.001,
)

training_config = TrainingConfig(
model=model_config,
data=data_config,
optimizer=optimizer_config
)
training_model = TrainingGNNModel(training_config)

0 comments on commit b8531c8

Please sign in to comment.