From d4647471c16e5b25287f9607b691e4a9f60a1f89 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Thu, 10 Oct 2024 13:23:08 -0400 Subject: [PATCH] Correct tests --- tests/unit_tests/test_lightning_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/test_lightning_module.py b/tests/unit_tests/test_lightning_module.py index e2269d7..516c906 100644 --- a/tests/unit_tests/test_lightning_module.py +++ b/tests/unit_tests/test_lightning_module.py @@ -27,7 +27,7 @@ def create_fake_qusi_lightning_module() -> QusiLightningModule: def create_fake_metric_group() -> MetricGroup: - fake_metric_group = MetricGroup(loss_metric=Mock(return_value=torch.tensor([1])), + fake_metric_group = MetricGroup(loss_metric=Mock(return_value=torch.tensor(1)), state_based_logging_metrics=ModuleList([MockStateBasedMetric()]), functional_logging_metrics=ModuleList([MockFunctionalMetric()])) return fake_metric_group