diff --git a/tests/test_train_mnist.py b/tests/test_train_mnist.py index 5383b39..d3a851c 100644 --- a/tests/test_train_mnist.py +++ b/tests/test_train_mnist.py @@ -1,3 +1,4 @@ +import pytest import torch from tests.utils.mnist import MnistModel, MnistModelConfig @@ -22,6 +23,7 @@ def test_train_mnist(tmp_path): # Without parsing command line args args = TrainerArgs() + args.small_run = 4 trainer2 = Trainer( args, @@ -48,3 +50,6 @@ def test_train_mnist(tmp_path): loss4 = trainer3.keep_avg_train["avg_loss"] assert loss3 > loss4 + + with pytest.raises(ValueError, match="cannot both be None"): + Trainer(args, MnistModelConfig(), output_path=tmp_path, model=None)