diff --git a/sparsecoding/data/datasets/bars.py b/sparsecoding/data/datasets/bars.py
deleted file mode 100644
index 16f877d..0000000
--- a/sparsecoding/data/datasets/bars.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import torch
-from torch.utils.data import Dataset
-
-from sparsecoding.priors.common import Prior
-
-
-class BarsDataset(Dataset):
-    """Toy dataset where the dictionary elements are horizontal and vertical bars.
-
-    Dataset elements are formed by taking linear combinations of the dictionary elements,
-    where the weights are sampled according to the input Prior.
-
-    Parameters
-    ----------
-    patch_size : int
-        Side length for elements of the dataset.
-    dataset_size : int
-        Number of dataset elements to generate.
-    prior : Prior
-        Prior distribution on the weights. Should be sparse.
-
-    Attributes
-    ----------
-    basis : Tensor, shape [2 * patch_size, patch_size, patch_size]
-        Dictionary elements (horizontal and vertical bars).
-    weights : Tensor, shape [dataset_size, 2 * patch_size]
-        Weights for each of the dataset elements.
-    data : Tensor, shape [dataset_size, patch_size, patch_size]
-        Weighted linear combinations of the basis elements.
-    """
-
-    def __init__(
-        self,
-        patch_size: int,
-        dataset_size: int,
-        prior: Prior,
-    ):
-        self.P = patch_size
-        self.N = dataset_size
-
-        one_hots = torch.nn.functional.one_hot(torch.arange(self.P))  # [P, P]
-        one_hots = one_hots.type(torch.float32)  # [P, P]
-
-        h_bars = one_hots.reshape(self.P, self.P, 1)
-        v_bars = one_hots.reshape(self.P, 1, self.P)
-
-        h_bars = h_bars.expand(self.P, self.P, self.P)
-        v_bars = v_bars.expand(self.P, self.P, self.P)
-        self.basis = torch.cat((h_bars, v_bars), dim=0)  # [2*P, P, P]
-
-        self.weights = prior.sample(self.N)  # [N, 2*P]
-
-        self.data = torch.einsum(
-            "nd,dhw->nhw",
-            self.weights,
-            self.basis,
-        )
-
-    def __len__(self):
-        return self.N
-
-    def __getitem__(self, idx: int):
-        return self.data[idx]
diff --git a/sparsecoding/data/datasets/field.py b/sparsecoding/data/datasets/field.py
deleted file mode 100644
index 4a0aae8..0000000
--- a/sparsecoding/data/datasets/field.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import os
-
-from scipy.io import loadmat
-import torch
-from torch.utils.data import Dataset
-
-from sparsecoding.data.transforms.patch import patchify
-
-
-class FieldDataset(Dataset):
-    """Dataset used in Olshausen & Field (1996).
-
-    Paper:
-        https://courses.cs.washington.edu/courses/cse528/11sp/Olshausen-nature-paper.pdf
-        Emergence of simple-cell receptive field properties
-        by learning a sparse code for natural images.
-
-    Parameters
-    ----------
-    root : str
-        Location to download the dataset to.
-    patch_size : int
-        Side length of patches for sparse dictionary learning.
-    stride : int, optional
-        Stride for sampling patches. If not specified, set to `patch_size`
-        (non-overlapping patches).
-    """
-
-    B = 10
-    C = 1
-    H = 512
-    W = 512
-
-    def __init__(
-        self,
-        root: str,
-        patch_size: int = 8,
-        stride: int = None,
-    ):
-        self.P = patch_size
-        if stride is None:
-            stride = patch_size
-
-        root = os.path.expanduser(root)
-        os.system(f"mkdir -p {root}")
-        if not os.path.exists(f"{root}/field.mat"):
-            os.system("wget https://rctn.org/bruno/sparsenet/IMAGES.mat")
-            os.system(f"mv IMAGES.mat {root}/field.mat")
-
-        self.images = torch.tensor(loadmat(f"{root}/field.mat")["IMAGES"])  # [H, W, B]
-        assert self.images.shape == (self.H, self.W, self.B)
-
-        self.images = torch.permute(self.images, (2, 0, 1))  # [B, H, W]
-        self.images = torch.reshape(self.images, (self.B, self.C, self.H, self.W))  # [B, C, H, W]
-
-        self.patches = patchify(patch_size, self.images, stride)  # [B, N, C, P, P]
-        self.patches = torch.reshape(self.patches, (-1, self.C, self.P, self.P))  # [B*N, C, P, P]
-
-    def __len__(self):
-        return self.patches.shape[0]
-
-    def __getitem__(self, idx):
-        return self.patches[idx]
diff --git a/sparsecoding/datasets.py b/sparsecoding/datasets.py
new file mode 100644
index 0000000..0dc4549
--- /dev/null
+++ b/sparsecoding/datasets.py
@@ -0,0 +1,122 @@
+import torch
+import os
+from scipy.io import loadmat
+from sparsecoding.transforms.patch import patchify
+from torch.utils.data import Dataset
+
+from sparsecoding.priors import Prior
+
+
+class BarsDataset(Dataset):
+    """Toy dataset where the dictionary elements are horizontal and vertical bars.
+
+    Dataset elements are formed by taking linear combinations of the dictionary elements,
+    where the weights are sampled according to the input Prior.
+
+    Parameters
+    ----------
+    patch_size : int
+        Side length for elements of the dataset.
+    dataset_size : int
+        Number of dataset elements to generate.
+    prior : Prior
+        Prior distribution on the weights. Should be sparse.
+
+    Attributes
+    ----------
+    basis : Tensor, shape [2 * patch_size, patch_size, patch_size]
+        Dictionary elements (horizontal and vertical bars).
+    weights : Tensor, shape [dataset_size, 2 * patch_size]
+        Weights for each of the dataset elements.
+    data : Tensor, shape [dataset_size, patch_size, patch_size]
+        Weighted linear combinations of the basis elements.
+    """
+
+    def __init__(
+        self,
+        patch_size: int,
+        dataset_size: int,
+        prior: Prior,
+    ):
+        self.P = patch_size
+        self.N = dataset_size
+
+        one_hots = torch.nn.functional.one_hot(torch.arange(self.P))  # [P, P]
+        one_hots = one_hots.type(torch.float32)  # [P, P]
+
+        h_bars = one_hots.reshape(self.P, self.P, 1)
+        v_bars = one_hots.reshape(self.P, 1, self.P)
+
+        h_bars = h_bars.expand(self.P, self.P, self.P)
+        v_bars = v_bars.expand(self.P, self.P, self.P)
+        self.basis = torch.cat((h_bars, v_bars), dim=0)  # [2*P, P, P]
+
+        self.weights = prior.sample(self.N)  # [N, 2*P]
+
+        self.data = torch.einsum(
+            "nd,dhw->nhw",
+            self.weights,
+            self.basis,
+        )
+
+    def __len__(self):
+        return self.N
+
+    def __getitem__(self, idx: int):
+        return self.data[idx]
+
+
+class FieldDataset(Dataset):
+    """Dataset used in Olshausen & Field (1996).
+
+    Paper:
+        https://courses.cs.washington.edu/courses/cse528/11sp/Olshausen-nature-paper.pdf
+        Emergence of simple-cell receptive field properties
+        by learning a sparse code for natural images.
+
+    Parameters
+    ----------
+    root : str
+        Location to download the dataset to.
+    patch_size : int
+        Side length of patches for sparse dictionary learning.
+    stride : int, optional
+        Stride for sampling patches. If not specified, set to `patch_size`
+        (non-overlapping patches).
+    """
+
+    B = 10
+    C = 1
+    H = 512
+    W = 512
+
+    def __init__(
+        self,
+        root: str,
+        patch_size: int = 8,
+        stride: int = None,
+    ):
+        self.P = patch_size
+        if stride is None:
+            stride = patch_size
+
+        root = os.path.expanduser(root)
+        os.system(f"mkdir -p {root}")
+        if not os.path.exists(f"{root}/field.mat"):
+            os.system("wget https://rctn.org/bruno/sparsenet/IMAGES.mat")
+            os.system(f"mv IMAGES.mat {root}/field.mat")
+
+        self.images = torch.tensor(loadmat(f"{root}/field.mat")["IMAGES"])  # [H, W, B]
+        assert self.images.shape == (self.H, self.W, self.B)
+
+        self.images = torch.permute(self.images, (2, 0, 1))  # [B, H, W]
+        self.images = torch.reshape(self.images, (self.B, self.C, self.H, self.W))  # [B, C, H, W]
+
+        self.patches = patchify(patch_size, self.images, stride)  # [B, N, C, P, P]
+        self.patches = torch.reshape(self.patches, (-1, self.C, self.P, self.P))  # [B*N, C, P, P]
+
+    def __len__(self):
+        return self.patches.shape[0]
+
+    def __getitem__(self, idx):
+        return self.patches[idx]
diff --git a/sparsecoding/dictionaries.py b/sparsecoding/dictionaries.py
new file mode 100644
index 0000000..2c934d2
--- /dev/null
+++ b/sparsecoding/dictionaries.py
@@ -0,0 +1,25 @@
+import os
+import torch
+import numpy as np
+import pickle as pkl
+
+MODULE_PATH = os.path.dirname(__file__)
+DICTIONARY_PATH = os.path.join(MODULE_PATH, "data/dictionaries")
+
+
+def load_dictionary_from_pickle(path):
+    dictionary_file = open(path, 'rb')
+    numpy_dictionary = pkl.load(dictionary_file)
+    dictionary_file.close()
+    dictionary = torch.tensor(numpy_dictionary.astype(np.float32))
+    return dictionary
+
+
+def load_bars_dictionary():
+    path = os.path.join(DICTIONARY_PATH, "bars", "bars-16_by_16.p")
+    return load_dictionary_from_pickle(path)
+
+
+def load_olshausen_dictionary():
+    path = os.path.join(DICTIONARY_PATH, "olshausen", "olshausen-1.5x_overcomplete.p")
+    return load_dictionary_from_pickle(path)
diff --git a/sparsecoding/priors.py b/sparsecoding/priors.py
new file mode 100644
index 0000000..026ac43
--- /dev/null
+++ b/sparsecoding/priors.py
@@ -0,0 +1,167 @@
+import torch
+from torch.distributions.laplace import Laplace
+
+from abc import ABC, abstractmethod
+
+
+class Prior(ABC):
+    """A distribution over weights.
+
+    Parameters
+    ----------
+    weights_dim : int
+        Number of weights for each sample.
+    """
+    @abstractmethod
+    def D(self):
+        """
+        Number of weights per sample.
+        """
+
+    @abstractmethod
+    def sample(
+        self,
+        num_samples: int = 1,
+    ):
+        """Sample weights from the prior.
+
+        Parameters
+        ----------
+        num_samples : int, default=1
+            Number of samples.
+
+        Returns
+        -------
+        samples : Tensor, shape [num_samples, self.D]
+            Sampled weights.
+        """
+
+
+class SpikeSlabPrior(Prior):
+    """Prior where weights are drawn from a "spike-and-slab" distribution.
+
+    The "spike" is at 0 and the "slab" is Laplacian.
+
+    See:
+        https://wesselb.github.io/assets/write-ups/Bruinsma,%20Spike%20and%20Slab%20Priors.pdf
+    for a good review of the spike-and-slab model.
+
+    Parameters
+    ----------
+    dim : int
+        Number of weights per sample.
+    p_spike : float
+        The probability of the weight being 0.
+    scale : float
+        The "scale" of the Laplacian distribution (larger is wider).
+    positive_only : bool
+        Ensure that the weights are positive by taking the absolute value
+        of weights sampled from the Laplacian.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        p_spike: float,
+        scale: float,
+        positive_only: bool = True,
+    ):
+        if dim < 0:
+            raise ValueError(f"`dim` should be nonnegative, got {dim}.")
+        if p_spike < 0 or p_spike > 1:
+            raise ValueError(f"Must have 0 <= `p_spike` <= 1, got `p_spike`={p_spike}.")
+        if scale <= 0:
+            raise ValueError(f"`scale` must be positive, got {scale}.")
+
+        self.dim = dim
+        self.p_spike = p_spike
+        self.scale = scale
+        self.positive_only = positive_only
+
+    @property
+    def D(self):
+        return self.dim
+
+    def sample(self, num_samples: int):
+        N = num_samples
+
+        zero_weights = torch.zeros((N, self.D), dtype=torch.float32)
+        slab_weights = Laplace(
+            loc=zero_weights,
+            scale=torch.full((N, self.D), self.scale, dtype=torch.float32),
+        ).sample()  # [N, D]
+
+        if self.positive_only:
+            slab_weights = torch.abs(slab_weights)
+
+        spike_over_slab = torch.rand(N, self.D, dtype=torch.float32) < self.p_spike
+
+        weights = torch.where(
+            spike_over_slab,
+            zero_weights,
+            slab_weights,
+        )
+
+        return weights
+
+
+class L0Prior(Prior):
+    """Prior with a distribution over the l0-norm of the weights.
+
+    A class of priors where the weights are binary;
+    the distribution is over the l0-norm of the weight vector
+    (how many weights are active).
+
+    Parameters
+    ----------
+    prob_distr : Tensor, shape [D], dtype float32
+        Probability distribution over the l0-norm of the weights.
+    """
+
+    def __init__(
+        self,
+        prob_distr: torch.Tensor,
+    ):
+        if prob_distr.dim() != 1:
+            raise ValueError(f"`prob_distr` shape must be (D,), got {prob_distr.shape}.")
+        if prob_distr.dtype != torch.float32:
+            raise ValueError(f"`prob_distr` dtype must be torch.float32, got {prob_distr.dtype}.")
+        if not torch.allclose(torch.sum(prob_distr), torch.ones_like(prob_distr)):
+            raise ValueError(f"`torch.sum(prob_distr)` must be 1., got {torch.sum(prob_distr)}.")
+
+        self.prob_distr = prob_distr
+
+    @property
+    def D(self):
+        return self.prob_distr.shape[0]
+
+    def sample(
+        self,
+        num_samples: int
+    ):
+        N = num_samples
+
+        num_active_weights = 1 + torch.multinomial(
+            input=self.prob_distr,
+            num_samples=num_samples,
+            replacement=True,
+        )  # [N]
+
+        d_idxs = torch.arange(self.D)
+        active_idx_mask = (
+            d_idxs.reshape(1, self.D)
+            < num_active_weights.reshape(N, 1)
+        )  # [N, self.D]
+
+        n_idxs = torch.arange(N).reshape(N, 1).expand(N, self.D)  # [N, D]
+        # Need to shuffle here so that it's not always the first weights that are active.
+        shuffled_d_idxs = [torch.randperm(self.D) for _ in range(N)]
+        shuffled_d_idxs = torch.stack(shuffled_d_idxs, dim=0)  # [N, D]
+
+        # [num_active_weights], [num_active_weights]
+        active_weight_idxs = n_idxs[active_idx_mask], shuffled_d_idxs[active_idx_mask]
+
+        weights = torch.zeros((N, self.D), dtype=torch.float32)
+        weights[active_weight_idxs] += 1.
+
+        return weights
diff --git a/sparsecoding/priors/__init__.py b/sparsecoding/priors/__init__.py
deleted file mode 100644
index 0d04df9..0000000
--- a/sparsecoding/priors/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Module for sparse priors."""
diff --git a/sparsecoding/priors/common.py b/sparsecoding/priors/common.py
deleted file mode 100644
index 4940661..0000000
--- a/sparsecoding/priors/common.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from abc import ABC, abstractmethod
-
-
-class Prior(ABC):
-    """A distribution over weights.
-
-    Parameters
-    ----------
-    weights_dim : int
-        Number of weights for each sample.
-    """
-    @abstractmethod
-    def D(self):
-        """
-        Number of weights per sample.
-        """
-
-    @abstractmethod
-    def sample(
-        self,
-        num_samples: int = 1,
-    ):
-        """Sample weights from the prior.
-
-        Parameters
-        ----------
-        num_samples : int, default=1
-            Number of samples.
-
-        Returns
-        -------
-        samples : Tensor, shape [num_samples, self.D]
-            Sampled weights.
-        """
diff --git a/sparsecoding/priors/l0.py b/sparsecoding/priors/l0.py
deleted file mode 100644
index 6fcb333..0000000
--- a/sparsecoding/priors/l0.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import torch
-
-from sparsecoding.priors.common import Prior
-
-
-class L0Prior(Prior):
-    """Prior with a distribution over the l0-norm of the weights.
-
-    A class of priors where the weights are binary;
-    the distribution is over the l0-norm of the weight vector
-    (how many weights are active).
-
-    Parameters
-    ----------
-    prob_distr : Tensor, shape [D], dtype float32
-        Probability distribution over the l0-norm of the weights.
-    """
-
-    def __init__(
-        self,
-        prob_distr: torch.Tensor,
-    ):
-        if prob_distr.dim() != 1:
-            raise ValueError(f"`prob_distr` shape must be (D,), got {prob_distr.shape}.")
-        if prob_distr.dtype != torch.float32:
-            raise ValueError(f"`prob_distr` dtype must be torch.float32, got {prob_distr.dtype}.")
-        if not torch.allclose(torch.sum(prob_distr), torch.ones_like(prob_distr)):
-            raise ValueError(f"`torch.sum(prob_distr)` must be 1., got {torch.sum(prob_distr)}.")
-
-        self.prob_distr = prob_distr
-
-    @property
-    def D(self):
-        return self.prob_distr.shape[0]
-
-    def sample(
-        self,
-        num_samples: int
-    ):
-        N = num_samples
-
-        num_active_weights = 1 + torch.multinomial(
-            input=self.prob_distr,
-            num_samples=num_samples,
-            replacement=True,
-        )  # [N]
-
-        d_idxs = torch.arange(self.D)
-        active_idx_mask = (
-            d_idxs.reshape(1, self.D)
-            < num_active_weights.reshape(N, 1)
-        )  # [N, self.D]
-
-        n_idxs = torch.arange(N).reshape(N, 1).expand(N, self.D)  # [N, D]
-        # Need to shuffle here so that it's not always the first weights that are active.
-        shuffled_d_idxs = [torch.randperm(self.D) for _ in range(N)]
-        shuffled_d_idxs = torch.stack(shuffled_d_idxs, dim=0)  # [N, D]
-
-        # [num_active_weights], [num_active_weights]
-        active_weight_idxs = n_idxs[active_idx_mask], shuffled_d_idxs[active_idx_mask]
-
-        weights = torch.zeros((N, self.D), dtype=torch.float32)
-        weights[active_weight_idxs] += 1.
-
-        return weights
diff --git a/sparsecoding/priors/spike_slab.py b/sparsecoding/priors/spike_slab.py
deleted file mode 100644
index ad1b316..0000000
--- a/sparsecoding/priors/spike_slab.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import torch
-from torch.distributions.laplace import Laplace
-
-from sparsecoding.priors.common import Prior
-
-
-class SpikeSlabPrior(Prior):
-    """Prior where weights are drawn from a "spike-and-slab" distribution.
-
-    The "spike" is at 0 and the "slab" is Laplacian.
-
-    See:
-        https://wesselb.github.io/assets/write-ups/Bruinsma,%20Spike%20and%20Slab%20Priors.pdf
-    for a good review of the spike-and-slab model.
-
-    Parameters
-    ----------
-    dim : int
-        Number of weights per sample.
-    p_spike : float
-        The probability of the weight being 0.
-    scale : float
-        The "scale" of the Laplacian distribution (larger is wider).
-    positive_only : bool
-        Ensure that the weights are positive by taking the absolute value
-        of weights sampled from the Laplacian.
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        p_spike: float,
-        scale: float,
-        positive_only: bool = True,
-    ):
-        if dim < 0:
-            raise ValueError(f"`dim` should be nonnegative, got {dim}.")
-        if p_spike < 0 or p_spike > 1:
-            raise ValueError(f"Must have 0 <= `p_spike` <= 1, got `p_spike`={p_spike}.")
-        if scale <= 0:
-            raise ValueError(f"`scale` must be positive, got {scale}.")
-
-        self.dim = dim
-        self.p_spike = p_spike
-        self.scale = scale
-        self.positive_only = positive_only
-
-    @property
-    def D(self):
-        return self.dim
-
-    def sample(self, num_samples: int):
-        N = num_samples
-
-        zero_weights = torch.zeros((N, self.D), dtype=torch.float32)
-        slab_weights = Laplace(
-            loc=zero_weights,
-            scale=torch.full((N, self.D), self.scale, dtype=torch.float32),
-        ).sample()  # [N, D]
-
-        if self.positive_only:
-            slab_weights = torch.abs(slab_weights)
-
-        spike_over_slab = torch.rand(N, self.D, dtype=torch.float32) < self.p_spike
-
-        weights = torch.where(
-            spike_over_slab,
-            zero_weights,
-            slab_weights,
-        )
-
-        return weights
diff --git a/sparsecoding/data/transforms/patch.py b/sparsecoding/transforms/patch.py
similarity index 100%
rename from sparsecoding/data/transforms/patch.py
rename to sparsecoding/transforms/patch.py
diff --git a/sparsecoding/data/transforms/whiten.py b/sparsecoding/transforms/whiten.py
similarity index 100%
rename from sparsecoding/data/transforms/whiten.py
rename to sparsecoding/transforms/whiten.py
diff --git a/tests/inference/common.py b/tests/inference/common.py
index 2ff14d2..2305ed7 100644
--- a/tests/inference/common.py
+++ b/tests/inference/common.py
@@ -1,8 +1,7 @@
 import torch
 
-from sparsecoding.priors.l0 import L0Prior
-from sparsecoding.priors.spike_slab import SpikeSlabPrior
-from sparsecoding.data.datasets.bars import BarsDataset
+from sparsecoding.priors import L0Prior, SpikeSlabPrior
+from sparsecoding.datasets import BarsDataset
 
 
 torch.manual_seed(1997)
diff --git a/tests/priors/test_l0.py b/tests/priors/test_l0.py
index 7518ad5..12dcce6 100644
--- a/tests/priors/test_l0.py
+++ b/tests/priors/test_l0.py
@@ -1,7 +1,7 @@
 import torch
 import unittest
 
-from sparsecoding.priors.l0 import L0Prior
+from sparsecoding.priors import L0Prior
 
 
 class TestL0Prior(unittest.TestCase):
diff --git a/tests/priors/test_spike_slab.py b/tests/priors/test_spike_slab.py
index df4ec31..20804fe 100644
--- a/tests/priors/test_spike_slab.py
+++ b/tests/priors/test_spike_slab.py
@@ -1,7 +1,7 @@
 import torch
 import unittest
 
-from sparsecoding.priors.spike_slab import SpikeSlabPrior
+from sparsecoding.priors import SpikeSlabPrior
 
 
 class TestSpikeSlabPrior(unittest.TestCase):