From 5b145966eb87691c25f5357929d9b169df2be99e Mon Sep 17 00:00:00 2001 From: belsten Date: Tue, 19 Nov 2024 19:52:15 -0800 Subject: [PATCH 1/6] consolidate priors into one file and change imports to reflect change --- sparsecoding/data/datasets/bars.py | 2 +- sparsecoding/priors/__init__.py | 1 - sparsecoding/priors/common.py | 34 -------------- sparsecoding/priors/l0.py | 65 --------------------------- sparsecoding/priors/spike_slab.py | 72 ------------------------------ 5 files changed, 1 insertion(+), 173 deletions(-) delete mode 100644 sparsecoding/priors/__init__.py delete mode 100644 sparsecoding/priors/common.py delete mode 100644 sparsecoding/priors/l0.py delete mode 100644 sparsecoding/priors/spike_slab.py diff --git a/sparsecoding/data/datasets/bars.py b/sparsecoding/data/datasets/bars.py index 16f877d..242e1d5 100644 --- a/sparsecoding/data/datasets/bars.py +++ b/sparsecoding/data/datasets/bars.py @@ -1,7 +1,7 @@ import torch from torch.utils.data import Dataset -from sparsecoding.priors.common import Prior +from sparsecoding.priors import Prior class BarsDataset(Dataset): diff --git a/sparsecoding/priors/__init__.py b/sparsecoding/priors/__init__.py deleted file mode 100644 index 0d04df9..0000000 --- a/sparsecoding/priors/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Module for sparse priors.""" diff --git a/sparsecoding/priors/common.py b/sparsecoding/priors/common.py deleted file mode 100644 index 4940661..0000000 --- a/sparsecoding/priors/common.py +++ /dev/null @@ -1,34 +0,0 @@ -from abc import ABC, abstractmethod - - -class Prior(ABC): - """A distribution over weights. - - Parameters - ---------- - weights_dim : int - Number of weights for each sample. - """ - @abstractmethod - def D(self): - """ - Number of weights per sample. - """ - - @abstractmethod - def sample( - self, - num_samples: int = 1, - ): - """Sample weights from the prior. - - Parameters - ---------- - num_samples : int, default=1 - Number of samples. 
- - Returns - ------- - samples : Tensor, shape [num_samples, self.D] - Sampled weights. - """ diff --git a/sparsecoding/priors/l0.py b/sparsecoding/priors/l0.py deleted file mode 100644 index 6fcb333..0000000 --- a/sparsecoding/priors/l0.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch - -from sparsecoding.priors.common import Prior - - -class L0Prior(Prior): - """Prior with a distribution over the l0-norm of the weights. - - A class of priors where the weights are binary; - the distribution is over the l0-norm of the weight vector - (how many weights are active). - - Parameters - ---------- - prob_distr : Tensor, shape [D], dtype float32 - Probability distribution over the l0-norm of the weights. - """ - - def __init__( - self, - prob_distr: torch.Tensor, - ): - if prob_distr.dim() != 1: - raise ValueError(f"`prob_distr` shape must be (D,), got {prob_distr.shape}.") - if prob_distr.dtype != torch.float32: - raise ValueError(f"`prob_distr` dtype must be torch.float32, got {prob_distr.dtype}.") - if not torch.allclose(torch.sum(prob_distr), torch.ones_like(prob_distr)): - raise ValueError(f"`torch.sum(prob_distr)` must be 1., got {torch.sum(prob_distr)}.") - - self.prob_distr = prob_distr - - @property - def D(self): - return self.prob_distr.shape[0] - - def sample( - self, - num_samples: int - ): - N = num_samples - - num_active_weights = 1 + torch.multinomial( - input=self.prob_distr, - num_samples=num_samples, - replacement=True, - ) # [N] - - d_idxs = torch.arange(self.D) - active_idx_mask = ( - d_idxs.reshape(1, self.D) - < num_active_weights.reshape(N, 1) - ) # [N, self.D] - - n_idxs = torch.arange(N).reshape(N, 1).expand(N, self.D) # [N, D] - # Need to shuffle here so that it's not always the first weights that are active. 
- shuffled_d_idxs = [torch.randperm(self.D) for _ in range(N)] - shuffled_d_idxs = torch.stack(shuffled_d_idxs, dim=0) # [N, D] - - # [num_active_weights], [num_active_weights] - active_weight_idxs = n_idxs[active_idx_mask], shuffled_d_idxs[active_idx_mask] - - weights = torch.zeros((N, self.D), dtype=torch.float32) - weights[active_weight_idxs] += 1. - - return weights diff --git a/sparsecoding/priors/spike_slab.py b/sparsecoding/priors/spike_slab.py deleted file mode 100644 index ad1b316..0000000 --- a/sparsecoding/priors/spike_slab.py +++ /dev/null @@ -1,72 +0,0 @@ -import torch -from torch.distributions.laplace import Laplace - -from sparsecoding.priors.common import Prior - - -class SpikeSlabPrior(Prior): - """Prior where weights are drawn from a "spike-and-slab" distribution. - - The "spike" is at 0 and the "slab" is Laplacian. - - See: - https://wesselb.github.io/assets/write-ups/Bruinsma,%20Spike%20and%20Slab%20Priors.pdf - for a good review of the spike-and-slab model. - - Parameters - ---------- - dim : int - Number of weights per sample. - p_spike : float - The probability of the weight being 0. - scale : float - The "scale" of the Laplacian distribution (larger is wider). - positive_only : bool - Ensure that the weights are positive by taking the absolute value - of weights sampled from the Laplacian. 
- """ - - def __init__( - self, - dim: int, - p_spike: float, - scale: float, - positive_only: bool = True, - ): - if dim < 0: - raise ValueError(f"`dim` should be nonnegative, got {dim}.") - if p_spike < 0 or p_spike > 1: - raise ValueError(f"Must have 0 <= `p_spike` <= 1, got `p_spike`={p_spike}.") - if scale <= 0: - raise ValueError(f"`scale` must be positive, got {scale}.") - - self.dim = dim - self.p_spike = p_spike - self.scale = scale - self.positive_only = positive_only - - @property - def D(self): - return self.dim - - def sample(self, num_samples: int): - N = num_samples - - zero_weights = torch.zeros((N, self.D), dtype=torch.float32) - slab_weights = Laplace( - loc=zero_weights, - scale=torch.full((N, self.D), self.scale, dtype=torch.float32), - ).sample() # [N, D] - - if self.positive_only: - slab_weights = torch.abs(slab_weights) - - spike_over_slab = torch.rand(N, self.D, dtype=torch.float32) < self.p_spike - - weights = torch.where( - spike_over_slab, - zero_weights, - slab_weights, - ) - - return weights From 828dc9d63b75873ad5ce35d1b07ca304acea7ed2 Mon Sep 17 00:00:00 2001 From: belsten Date: Tue, 19 Nov 2024 19:53:15 -0800 Subject: [PATCH 2/6] add priors file and update tests to reflect new import --- sparsecoding/priors.py | 167 ++++++++++++++++++++++++++++++++ tests/inference/common.py | 3 +- tests/priors/test_l0.py | 2 +- tests/priors/test_spike_slab.py | 2 +- 4 files changed, 170 insertions(+), 4 deletions(-) create mode 100644 sparsecoding/priors.py diff --git a/sparsecoding/priors.py b/sparsecoding/priors.py new file mode 100644 index 0000000..026ac43 --- /dev/null +++ b/sparsecoding/priors.py @@ -0,0 +1,167 @@ +import torch +from torch.distributions.laplace import Laplace + +from abc import ABC, abstractmethod + + +class Prior(ABC): + """A distribution over weights. + + Parameters + ---------- + weights_dim : int + Number of weights for each sample. + """ + @abstractmethod + def D(self): + """ + Number of weights per sample. 
+ """ + + @abstractmethod + def sample( + self, + num_samples: int = 1, + ): + """Sample weights from the prior. + + Parameters + ---------- + num_samples : int, default=1 + Number of samples. + + Returns + ------- + samples : Tensor, shape [num_samples, self.D] + Sampled weights. + """ + + +class SpikeSlabPrior(Prior): + """Prior where weights are drawn from a "spike-and-slab" distribution. + + The "spike" is at 0 and the "slab" is Laplacian. + + See: + https://wesselb.github.io/assets/write-ups/Bruinsma,%20Spike%20and%20Slab%20Priors.pdf + for a good review of the spike-and-slab model. + + Parameters + ---------- + dim : int + Number of weights per sample. + p_spike : float + The probability of the weight being 0. + scale : float + The "scale" of the Laplacian distribution (larger is wider). + positive_only : bool + Ensure that the weights are positive by taking the absolute value + of weights sampled from the Laplacian. + """ + + def __init__( + self, + dim: int, + p_spike: float, + scale: float, + positive_only: bool = True, + ): + if dim < 0: + raise ValueError(f"`dim` should be nonnegative, got {dim}.") + if p_spike < 0 or p_spike > 1: + raise ValueError(f"Must have 0 <= `p_spike` <= 1, got `p_spike`={p_spike}.") + if scale <= 0: + raise ValueError(f"`scale` must be positive, got {scale}.") + + self.dim = dim + self.p_spike = p_spike + self.scale = scale + self.positive_only = positive_only + + @property + def D(self): + return self.dim + + def sample(self, num_samples: int): + N = num_samples + + zero_weights = torch.zeros((N, self.D), dtype=torch.float32) + slab_weights = Laplace( + loc=zero_weights, + scale=torch.full((N, self.D), self.scale, dtype=torch.float32), + ).sample() # [N, D] + + if self.positive_only: + slab_weights = torch.abs(slab_weights) + + spike_over_slab = torch.rand(N, self.D, dtype=torch.float32) < self.p_spike + + weights = torch.where( + spike_over_slab, + zero_weights, + slab_weights, + ) + + return weights + + +class L0Prior(Prior): 
+ """Prior with a distribution over the l0-norm of the weights. + + A class of priors where the weights are binary; + the distribution is over the l0-norm of the weight vector + (how many weights are active). + + Parameters + ---------- + prob_distr : Tensor, shape [D], dtype float32 + Probability distribution over the l0-norm of the weights. + """ + + def __init__( + self, + prob_distr: torch.Tensor, + ): + if prob_distr.dim() != 1: + raise ValueError(f"`prob_distr` shape must be (D,), got {prob_distr.shape}.") + if prob_distr.dtype != torch.float32: + raise ValueError(f"`prob_distr` dtype must be torch.float32, got {prob_distr.dtype}.") + if not torch.allclose(torch.sum(prob_distr), torch.ones_like(prob_distr)): + raise ValueError(f"`torch.sum(prob_distr)` must be 1., got {torch.sum(prob_distr)}.") + + self.prob_distr = prob_distr + + @property + def D(self): + return self.prob_distr.shape[0] + + def sample( + self, + num_samples: int + ): + N = num_samples + + num_active_weights = 1 + torch.multinomial( + input=self.prob_distr, + num_samples=num_samples, + replacement=True, + ) # [N] + + d_idxs = torch.arange(self.D) + active_idx_mask = ( + d_idxs.reshape(1, self.D) + < num_active_weights.reshape(N, 1) + ) # [N, self.D] + + n_idxs = torch.arange(N).reshape(N, 1).expand(N, self.D) # [N, D] + # Need to shuffle here so that it's not always the first weights that are active. + shuffled_d_idxs = [torch.randperm(self.D) for _ in range(N)] + shuffled_d_idxs = torch.stack(shuffled_d_idxs, dim=0) # [N, D] + + # [num_active_weights], [num_active_weights] + active_weight_idxs = n_idxs[active_idx_mask], shuffled_d_idxs[active_idx_mask] + + weights = torch.zeros((N, self.D), dtype=torch.float32) + weights[active_weight_idxs] += 1. 
+ + return weights diff --git a/tests/inference/common.py b/tests/inference/common.py index 2ff14d2..eb5d574 100644 --- a/tests/inference/common.py +++ b/tests/inference/common.py @@ -1,7 +1,6 @@ import torch -from sparsecoding.priors.l0 import L0Prior -from sparsecoding.priors.spike_slab import SpikeSlabPrior +from sparsecoding.priors import L0Prior, SpikeSlabPrior from sparsecoding.data.datasets.bars import BarsDataset torch.manual_seed(1997) diff --git a/tests/priors/test_l0.py b/tests/priors/test_l0.py index 7518ad5..12dcce6 100644 --- a/tests/priors/test_l0.py +++ b/tests/priors/test_l0.py @@ -1,7 +1,7 @@ import torch import unittest -from sparsecoding.priors.l0 import L0Prior +from sparsecoding.priors import L0Prior class TestL0Prior(unittest.TestCase): diff --git a/tests/priors/test_spike_slab.py b/tests/priors/test_spike_slab.py index df4ec31..20804fe 100644 --- a/tests/priors/test_spike_slab.py +++ b/tests/priors/test_spike_slab.py @@ -1,7 +1,7 @@ import torch import unittest -from sparsecoding.priors.spike_slab import SpikeSlabPrior +from sparsecoding.priors import SpikeSlabPrior class TestSpikeSlabPrior(unittest.TestCase): From 32848c30ed42ed6cd1e740fbafa36ba82a3e44da Mon Sep 17 00:00:00 2001 From: belsten Date: Fri, 22 Nov 2024 15:00:30 -0800 Subject: [PATCH 3/6] restructure repo so that transforms, datasets, and dictionaries are all at the base level, not in the data dir --- sparsecoding/data/datasets/field.py | 63 ------------------- .../{data/datasets/bars.py => datasets.py} | 56 +++++++++++++++++ sparsecoding/dictionaries.py | 26 ++++++++ sparsecoding/{data => }/transforms/patch.py | 0 sparsecoding/{data => }/transforms/whiten.py | 0 5 files changed, 82 insertions(+), 63 deletions(-) delete mode 100644 sparsecoding/data/datasets/field.py rename sparsecoding/{data/datasets/bars.py => datasets.py} (51%) create mode 100644 sparsecoding/dictionaries.py rename sparsecoding/{data => }/transforms/patch.py (100%) rename sparsecoding/{data => }/transforms/whiten.py (100%) 
diff --git a/sparsecoding/data/datasets/field.py b/sparsecoding/data/datasets/field.py deleted file mode 100644 index 4a0aae8..0000000 --- a/sparsecoding/data/datasets/field.py +++ /dev/null @@ -1,63 +0,0 @@ -import os - -from scipy.io import loadmat -import torch -from torch.utils.data import Dataset - -from sparsecoding.data.transforms.patch import patchify - - -class FieldDataset(Dataset): - """Dataset used in Olshausen & Field (1996). - - Paper: - https://courses.cs.washington.edu/courses/cse528/11sp/Olshausen-nature-paper.pdf - Emergence of simple-cell receptive field properties - by learning a sparse code for natural images. - - Parameters - ---------- - root : str - Location to download the dataset to. - patch_size : int - Side length of patches for sparse dictionary learning. - stride : int, optional - Stride for sampling patches. If not specified, set to `patch_size` - (non-overlapping patches). - """ - - B = 10 - C = 1 - H = 512 - W = 512 - - def __init__( - self, - root: str, - patch_size: int = 8, - stride: int = None, - ): - self.P = patch_size - if stride is None: - stride = patch_size - - root = os.path.expanduser(root) - os.system(f"mkdir -p {root}") - if not os.path.exists(f"{root}/field.mat"): - os.system("wget https://rctn.org/bruno/sparsenet/IMAGES.mat") - os.system(f"mv IMAGES.mat {root}/field.mat") - - self.images = torch.tensor(loadmat(f"{root}/field.mat")["IMAGES"]) # [H, W, B] - assert self.images.shape == (self.H, self.W, self.B) - - self.images = torch.permute(self.images, (2, 0, 1)) # [B, H, W] - self.images = torch.reshape(self.images, (self.B, self.C, self.H, self.W)) # [B, C, H, W] - - self.patches = patchify(patch_size, self.images, stride) # [B, N, C, P, P] - self.patches = torch.reshape(self.patches, (-1, self.C, self.P, self.P)) # [B*N, C, P, P] - - def __len__(self): - return self.patches.shape[0] - - def __getitem__(self, idx): - return self.patches[idx] diff --git a/sparsecoding/data/datasets/bars.py b/sparsecoding/datasets.py 
similarity index 51% rename from sparsecoding/data/datasets/bars.py rename to sparsecoding/datasets.py index 242e1d5..f25ed51 100644 --- a/sparsecoding/data/datasets/bars.py +++ b/sparsecoding/datasets.py @@ -61,3 +61,59 @@ def __len__(self): def __getitem__(self, idx: int): return self.data[idx] + + +class FieldDataset(Dataset): + """Dataset used in Olshausen & Field (1996). + + Paper: + https://courses.cs.washington.edu/courses/cse528/11sp/Olshausen-nature-paper.pdf + Emergence of simple-cell receptive field properties + by learning a sparse code for natural images. + + Parameters + ---------- + root : str + Location to download the dataset to. + patch_size : int + Side length of patches for sparse dictionary learning. + stride : int, optional + Stride for sampling patches. If not specified, set to `patch_size` + (non-overlapping patches). + """ + + B = 10 + C = 1 + H = 512 + W = 512 + + def __init__( + self, + root: str, + patch_size: int = 8, + stride: int = None, + ): + self.P = patch_size + if stride is None: + stride = patch_size + + root = os.path.expanduser(root) + os.system(f"mkdir -p {root}") + if not os.path.exists(f"{root}/field.mat"): + os.system("wget https://rctn.org/bruno/sparsenet/IMAGES.mat") + os.system(f"mv IMAGES.mat {root}/field.mat") + + self.images = torch.tensor(loadmat(f"{root}/field.mat")["IMAGES"]) # [H, W, B] + assert self.images.shape == (self.H, self.W, self.B) + + self.images = torch.permute(self.images, (2, 0, 1)) # [B, H, W] + self.images = torch.reshape(self.images, (self.B, self.C, self.H, self.W)) # [B, C, H, W] + + self.patches = patchify(patch_size, self.images, stride) # [B, N, C, P, P] + self.patches = torch.reshape(self.patches, (-1, self.C, self.P, self.P)) # [B*N, C, P, P] + + def __len__(self): + return self.patches.shape[0] + + def __getitem__(self, idx): + return self.patches[idx] diff --git a/sparsecoding/dictionaries.py b/sparsecoding/dictionaries.py new file mode 100644 index 0000000..2547e3d --- /dev/null +++ 
b/sparsecoding/dictionaries.py @@ -0,0 +1,26 @@ +import os +import torch +import numpy as np +import pickle as pkl + +MODULE_PATH = os.path.dirname(__file__) +DICTIONARY_PATH = os.path.join(MODULE_PATH, "data/dictionaries") + + +def load_dictionary_from_pickle(path): + dictionary_file = open(path, 'rb') + numpy_dictionary = pkl.load(dictionary_file) + dictionary_file.close() + dictionary = torch.tensor(numpy_dictionary.astype(np.float32)) + return dictionary + + +def load_bars_dictionary(): + path = os.path.join(DICTIONARY_PATH, "bars", "bars-16_by_16.p") + return load_dictionary_from_pickle(path) + + +def load_olshausen_dictionary(): + path = os.path.join(DICTIONARY_PATH, "olshausen", "olshausen-1.5x_overcomplete.p") + return load_dictionary_from_pickle(path) + \ No newline at end of file diff --git a/sparsecoding/data/transforms/patch.py b/sparsecoding/transforms/patch.py similarity index 100% rename from sparsecoding/data/transforms/patch.py rename to sparsecoding/transforms/patch.py diff --git a/sparsecoding/data/transforms/whiten.py b/sparsecoding/transforms/whiten.py similarity index 100% rename from sparsecoding/data/transforms/whiten.py rename to sparsecoding/transforms/whiten.py From 61cf6c2129402c8c009670fd470ebdfedb62bc0c Mon Sep 17 00:00:00 2001 From: belsten Date: Fri, 22 Nov 2024 15:05:59 -0800 Subject: [PATCH 4/6] datasets load necessary libs and change imports in tests --- sparsecoding/datasets.py | 3 +++ tests/inference/common.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sparsecoding/datasets.py b/sparsecoding/datasets.py index f25ed51..0dc4549 100644 --- a/sparsecoding/datasets.py +++ b/sparsecoding/datasets.py @@ -1,4 +1,7 @@ import torch +import os +from scipy.io import loadmat +from sparsecoding.transforms.patch import patchify from torch.utils.data import Dataset from sparsecoding.priors import Prior diff --git a/tests/inference/common.py b/tests/inference/common.py index eb5d574..2305ed7 100644 --- 
a/tests/inference/common.py +++ b/tests/inference/common.py @@ -1,7 +1,7 @@ import torch from sparsecoding.priors import L0Prior, SpikeSlabPrior -from sparsecoding.data.datasets.bars import BarsDataset +from sparsecoding.datasets import BarsDataset torch.manual_seed(1997) From eb7727e26cf129968e8540703ff148e082f62d77 Mon Sep 17 00:00:00 2001 From: belsten Date: Fri, 22 Nov 2024 15:09:12 -0800 Subject: [PATCH 5/6] flake8 fix --- sparsecoding/dictionaries.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sparsecoding/dictionaries.py b/sparsecoding/dictionaries.py index 2547e3d..1405f60 100644 --- a/sparsecoding/dictionaries.py +++ b/sparsecoding/dictionaries.py @@ -14,13 +14,10 @@ def load_dictionary_from_pickle(path): dictionary = torch.tensor(numpy_dictionary.astype(np.float32)) return dictionary - def load_bars_dictionary(): path = os.path.join(DICTIONARY_PATH, "bars", "bars-16_by_16.p") return load_dictionary_from_pickle(path) - def load_olshausen_dictionary(): path = os.path.join(DICTIONARY_PATH, "olshausen", "olshausen-1.5x_overcomplete.p") - return load_dictionary_from_pickle(path) - \ No newline at end of file + return load_dictionary_from_pickle(path) \ No newline at end of file From 694d62c4d5370d42a23d142f383226cb7bc6c10b Mon Sep 17 00:00:00 2001 From: belsten Date: Fri, 22 Nov 2024 15:18:41 -0800 Subject: [PATCH 6/6] flake8 (again) --- sparsecoding/dictionaries.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sparsecoding/dictionaries.py b/sparsecoding/dictionaries.py index 1405f60..2c934d2 100644 --- a/sparsecoding/dictionaries.py +++ b/sparsecoding/dictionaries.py @@ -14,10 +14,12 @@ def load_dictionary_from_pickle(path): dictionary = torch.tensor(numpy_dictionary.astype(np.float32)) return dictionary + def load_bars_dictionary(): path = os.path.join(DICTIONARY_PATH, "bars", "bars-16_by_16.p") return load_dictionary_from_pickle(path) + def load_olshausen_dictionary(): path = os.path.join(DICTIONARY_PATH, 
"olshausen", "olshausen-1.5x_overcomplete.p") - return load_dictionary_from_pickle(path) \ No newline at end of file + return load_dictionary_from_pickle(path)