diff --git a/sparsecoding/models.py b/sparsecoding/models.py index af31288..714e789 100644 --- a/sparsecoding/models.py +++ b/sparsecoding/models.py @@ -104,29 +104,18 @@ def learn_dictionary(self, dataset, n_epoch, batch_size): losses = [] dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) - iterloader = iter(dataloader) - for i in range(n_epoch): - try: - batch = next(iterloader) - except StopIteration: - dataloader = DataLoader(dataset, batch_size=batch_size, - shuffle=True) - iterloader = iter(dataloader) - batch = next(iterloader) - - # infer coefficients - a = self.inference_method.infer(batch, self.dictionary) - - # update dictionary - self.update_dictionary(batch, a) - - # normalize dictionary - self.normalize_dictionary() - - # compute current loss - loss = self.compute_loss(batch, a) - - losses.append(loss) + for _ in range(n_epoch): + loss = 0.0 + for batch in dataloader: + # infer coefficients + a = self.inference_method.infer(batch, self.dictionary) + # update dictionary + self.update_dictionary(batch, a) + # normalize dictionary + self.normalize_dictionary() + # compute current loss + loss += self.compute_loss(batch, a) + losses.append(loss/len(dataloader)) return np.asarray(losses) def compute_loss(self, data, a):