From 0934e5ccd09ccd877d462250354db9c63e46f46c Mon Sep 17 00:00:00 2001
From: belsten
Date: Tue, 16 Apr 2024 11:47:22 -0700
Subject: [PATCH] Iterate through all of dataloader each epoch. Compute epoch
 energy as the average of the batch energies.

---
 sparsecoding/models.py | 35 ++++++++++++-----------------------
 1 file changed, 12 insertions(+), 23 deletions(-)

diff --git a/sparsecoding/models.py b/sparsecoding/models.py
index af31288..714e789 100644
--- a/sparsecoding/models.py
+++ b/sparsecoding/models.py
@@ -104,29 +104,18 @@ def learn_dictionary(self, dataset, n_epoch, batch_size):
         losses = []
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
-        iterloader = iter(dataloader)
-        for i in range(n_epoch):
-            try:
-                batch = next(iterloader)
-            except StopIteration:
-                dataloader = DataLoader(dataset, batch_size=batch_size,
-                                        shuffle=True)
-                iterloader = iter(dataloader)
-                batch = next(iterloader)
-
-            # infer coefficients
-            a = self.inference_method.infer(batch, self.dictionary)
-
-            # update dictionary
-            self.update_dictionary(batch, a)
-
-            # normalize dictionary
-            self.normalize_dictionary()
-
-            # compute current loss
-            loss = self.compute_loss(batch, a)
-
-            losses.append(loss)
+        for _ in range(n_epoch):
+            loss = 0.0
+            for batch in dataloader:
+                # infer coefficients
+                a = self.inference_method.infer(batch, self.dictionary)
+                # update dictionary
+                self.update_dictionary(batch, a)
+                # normalize dictionary
+                self.normalize_dictionary()
+                # compute current loss
+                loss += self.compute_loss(batch, a)
+            losses.append(loss/len(dataloader))
         return np.asarray(losses)
 
     def compute_loss(self, data, a):
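
Below is a standalone sketch of the loop structure this patch introduces: every epoch
now walks the full DataLoader and records the mean of the per-batch energies. The
tensor data, the placeholder per-batch energy, and all variable names are illustrative
only; nothing here is part of sparsecoding/models.py.

    import torch
    from torch.utils.data import DataLoader

    # Illustrative data: a plain tensor indexes by row and has a length,
    # so DataLoader can batch it directly.
    data = torch.randn(1000, 64)
    dataloader = DataLoader(data, batch_size=100, shuffle=True)

    n_epoch = 10
    losses = []
    for _ in range(n_epoch):
        loss = 0.0
        for batch in dataloader:
            # ...infer coefficients, update and normalize the dictionary here...
            loss += float(batch.pow(2).mean())  # placeholder batch energy
        # one recorded value per epoch: the average of the batch energies
        losses.append(loss / len(dataloader))

    assert len(losses) == n_epoch

Dividing by len(dataloader) keeps the recorded epoch energy comparable across batch
sizes; note that with the default drop_last=False, a smaller final batch is weighted
the same as the full batches in this average.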