diff --git a/sparsecoding/datasets.py b/sparsecoding/datasets.py index f25ed51..0dc4549 100644 --- a/sparsecoding/datasets.py +++ b/sparsecoding/datasets.py @@ -1,4 +1,7 @@ import torch +import os +from scipy.io import loadmat +from sparsecoding.transforms.patch import patchify from torch.utils.data import Dataset from sparsecoding.priors import Prior diff --git a/tests/inference/common.py b/tests/inference/common.py index eb5d574..2305ed7 100644 --- a/tests/inference/common.py +++ b/tests/inference/common.py @@ -1,7 +1,7 @@ import torch from sparsecoding.priors import L0Prior, SpikeSlabPrior -from sparsecoding.data.datasets.bars import BarsDataset +from sparsecoding.datasets import BarsDataset torch.manual_seed(1997)