Skip to content

Commit

Permalink
Fix CUDA launch error (#83)
Browse files Browse the repository at this point in the history
* not sure i can reproduce this on github

* cuda might help

* update trainer

* fix yaml?

* add skip condition

* remove num_workers kwarg

* update conda environment
  • Loading branch information
lilyminium authored Feb 14, 2024
1 parent 5f771ce commit 4c6953b
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 2 deletions.
50 changes: 50 additions & 0 deletions devtools/conda-envs/test_cuda_env.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Environment for testing CUDA
name: openff-nagl-test-cuda
channels:
- openeye
- dglteam/label/cu117
- pytorch
- nvidia
- conda-forge
- defaults
dependencies:
# Base dependencies
- python
- pip

# UI
- click
- click-option-group
- tqdm
- rich

# chemistry
- openff-recharge
- openff-toolkit-base >=0.11.1
- openff-units
- pydantic <3
- rdkit
- openeye-toolkits

# database
- pyarrow

# gcn
- cudatoolkit
- dgl ==1.1.2
- pytorch >=2.0
- pytorch-lightning
- pytorch-cuda ==11.7

# parallelism
- dask-jobqueue

# Testing
- pytest
- pytest-cov
- pytest-xdist
- codecov

# Pip-only installs
- pip:
- rich
2 changes: 0 additions & 2 deletions openff/nagl/nn/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,9 +623,7 @@ def __init__(
super().__init__(
dataset=dataset,
batch_size=batch_size,
num_workers=1, # otherwise shared memory issues
collate_fn=self._collate,
# pin_memory=True,
**kwargs,
)

Expand Down
43 changes: 43 additions & 0 deletions openff/nagl/tests/training/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch
import numpy as np
import pytorch_lightning as pl

from openff.nagl.training.training import DGLMoleculeDataModule, DataHash, TrainingGNNModel
from openff.nagl.nn._models import GNNModel
Expand Down Expand Up @@ -334,3 +335,45 @@ def test_weighted_mixed_training_step(self, mock_training_model, dgl_methane):
loss = loss["loss"]
assert torch.isclose(loss, torch.tensor([expected_loss], dtype=torch.float32))
assert torch.isclose(loss, torch.tensor([123.534743]))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_train_model_no_error(example_training_config, tmpdir):
    """Smoke-test fitting a ``TrainingGNNModel`` on a single GPU.

    The example datasets are copied into ``tmpdir``, every stage's dataset
    config is switched to cached data stored in the working directory, and
    a ``pl.Trainer`` is run for two epochs. The test passes if nothing
    raises; it is skipped when CUDA is not available.
    """
    data_module = DGLMoleculeDataModule(example_training_config)

    # NOTE(review): block nesting was reconstructed -- everything below is
    # assumed to run inside the temporary working directory because the
    # cache directory is the relative path "."; confirm against the repo.
    with tmpdir.as_cwd():
        # Single-file example datasets are copied straight into the cwd.
        for lazy_file in (
            EXAMPLE_FEATURIZED_LAZY_DATA,
            EXAMPLE_FEATURIZED_LAZY_DATA_SHORT,
        ):
            shutil.copy(lazy_file.resolve(), ".")

        # Directory datasets are copied to a directory named after their stem.
        for parquet_dir in (
            EXAMPLE_UNFEATURIZED_PARQUET_DATASET,
            EXAMPLE_UNFEATURIZED_PARQUET_DATASET_SHORT,
        ):
            shutil.copytree(parquet_dir.resolve(), parquet_dir.stem)

        # Replace each stage's config with a copy that caches in the cwd.
        for stage in ("train", "val", "test"):
            data_module._dataset_configs[stage] = data_module._dataset_configs[
                stage
            ].copy(
                update={
                    "use_cached_data": True,
                    "cache_directory": ".",
                }
            )

        data_module.prepare_data()
        train_loader = data_module.train_dataloader()
        assert isinstance(train_loader, DGLMoleculeDataLoader)

        model = TrainingGNNModel(example_training_config)
        trainer = pl.Trainer(
            accelerator="gpu",
            devices=1,
            max_epochs=2,
        )
        trainer.fit(model, datamodule=data_module)

0 comments on commit 4c6953b

Please sign in to comment.