diff --git a/devtools/conda-envs/test_cuda_env.yaml b/devtools/conda-envs/test_cuda_env.yaml new file mode 100644 index 00000000..3b854646 --- /dev/null +++ b/devtools/conda-envs/test_cuda_env.yaml @@ -0,0 +1,50 @@ +# environment for testing cuda +name: openff-nagl-test-cuda +channels: + - openeye + - dglteam/label/cu117 + - pytorch + - nvidia + - conda-forge + - defaults +dependencies: + # Base depends + - python + - pip + + # UI + - click + - click-option-group + - tqdm + - rich + + # chemistry + - openff-recharge + - openff-toolkit-base >=0.11.1 + - openff-units + - pydantic <3 + - rdkit + - openeye-toolkits + + # database + - pyarrow + + # gcn + - cudatoolkit + - dgl ==1.1.2 + - pytorch >=2.0 + - pytorch-lightning + - pytorch-cuda ==11.7 + + # parallelism + - dask-jobqueue + + # Testing + - pytest + - pytest-cov + - pytest-xdist + - codecov + + # Pip-only installs + - pip: + - rich diff --git a/openff/nagl/nn/_dataset.py b/openff/nagl/nn/_dataset.py index deeae0fa..b14d9836 100644 --- a/openff/nagl/nn/_dataset.py +++ b/openff/nagl/nn/_dataset.py @@ -623,9 +623,7 @@ def __init__( super().__init__( dataset=dataset, batch_size=batch_size, - num_workers=1, # otherwise shared memory issues collate_fn=self._collate, - # pin_memory=True, **kwargs, ) diff --git a/openff/nagl/tests/training/test_training.py b/openff/nagl/tests/training/test_training.py index c92ece21..603561ee 100644 --- a/openff/nagl/tests/training/test_training.py +++ b/openff/nagl/tests/training/test_training.py @@ -5,6 +5,7 @@ import torch import numpy as np +import pytorch_lightning as pl from openff.nagl.training.training import DGLMoleculeDataModule, DataHash, TrainingGNNModel from openff.nagl.nn._models import GNNModel @@ -334,3 +335,45 @@ def test_weighted_mixed_training_step(self, mock_training_model, dgl_methane): loss = loss["loss"] assert torch.isclose(loss, torch.tensor([expected_loss], dtype=torch.float32)) assert torch.isclose(loss, torch.tensor([123.534743])) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_train_model_no_error(example_training_config, tmpdir): + + data_module = DGLMoleculeDataModule(example_training_config) + + with tmpdir.as_cwd(): + shutil.copy( + EXAMPLE_FEATURIZED_LAZY_DATA.resolve(), + "." + ) + shutil.copy( + EXAMPLE_FEATURIZED_LAZY_DATA_SHORT.resolve(), + "." + ) + shutil.copytree( + EXAMPLE_UNFEATURIZED_PARQUET_DATASET.resolve(), + EXAMPLE_UNFEATURIZED_PARQUET_DATASET.stem + ) + shutil.copytree( + EXAMPLE_UNFEATURIZED_PARQUET_DATASET_SHORT.resolve(), + EXAMPLE_UNFEATURIZED_PARQUET_DATASET_SHORT.stem + ) + for stage in ["train", "val", "test"]: + config = data_module._dataset_configs[stage] + config = config.copy( + update={ + "use_cached_data": True, + "cache_directory": ".", + } + ) + data_module._dataset_configs[stage] = config + + data_module.prepare_data() + assert isinstance(data_module.train_dataloader(), DGLMoleculeDataLoader) + + model = TrainingGNNModel(example_training_config) + trainer = pl.Trainer( + accelerator="gpu", devices=1, max_epochs=2, + ) + trainer.fit(model, datamodule=data_module)