Commit b235f18: Update config
jloveric committed May 19, 2024 (1 parent: c104d41)
Showing 3 changed files with 2 additions and 20 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -72,7 +72,7 @@ Applying to random noise produces
 # Associative Dictionary
 This is an attempt to store more than one image in a network based on a text embedding and its associated image. In principle it could also serve as a generative model if you ask for something not in the dictionary, but we'll see what happens. I'm using the [Pick-a-Pic](https://stability.ai/research/pick-a-pic) dataset, so you'll need to download its parquet files - the idea here is not (yet) to train on an enormous dataset, but to use perhaps 10s to 100s of images.
 ```
-python3 examples/text_to_image.py batch_size=2 optimizer=sparse_lion
+python3 examples/text_to_image.py batch_size=2000 optimizer=sparse_lion max_epochs=10
 ```

 # Random Interpolation (A Generative Model)
2 changes: 1 addition & 1 deletion config/generative_config.yaml
@@ -5,7 +5,7 @@ mlp:
   periodicity: null
   rescale_output: False
   scale: 2.0
-  width: 10
+  width: 100
   layers: 2
   segments: 2
18 changes: 0 additions & 18 deletions high_order_implicit_representation/single_image_dataset.py
@@ -390,10 +390,6 @@ def gen_data(self):
             self.count += 1

     def __getitem__(self, idx):
-        # I'm totally ignoring the index
-        # ans = self.dataset()
-        # print('ans', ans)
-
         return next(self.generator)


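The deleted comments flag a real quirk of this class: `__getitem__` ignores `idx` and simply pulls the next sample from an internal generator. A minimal, self-contained sketch of that pattern (the class name, `reset` semantics, and tensor shapes here are illustrative assumptions, not the repository's actual code):

```python
import torch
from torch.utils.data import Dataset


class GeneratorBackedDataset(Dataset):
    """Illustrative sketch: a map-style Dataset whose __getitem__ ignores
    the index and draws samples from an internal generator instead."""

    def __init__(self, num_samples: int):
        self.num_samples = num_samples
        self.reset()

    def reset(self):
        # Re-create the generator so each epoch starts from the beginning;
        # this mirrors why the dataloader calls dataset.reset() per epoch.
        self.generator = self.gen_data()

    def gen_data(self):
        for _ in range(self.num_samples):
            yield torch.rand(3)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # idx is intentionally ignored: samples arrive in generator order,
        # so pair this with a DataLoader that uses shuffle=False.
        return next(self.generator)


ds = GeneratorBackedDataset(num_samples=4)
samples = [ds[i] for i in range(len(ds))]
```

Because the index is ignored, shuffling at the DataLoader level has no effect on sample order, and a second pass requires calling `reset()` again; otherwise `next()` raises `StopIteration`.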
@@ -423,9 +419,6 @@ def gen_data(self):


     def __getitem__(self, idx):
-        # I'm totally ignoring the index
-        # ans = self.dataset()
-        # print('ans', ans)

         return self.gen_data()
@@ -452,17 +445,6 @@ def train_dataset(self) -> Dataset:
     def test_dataset(self) -> Dataset:
         return self._test_dataset

-    def collate_fn_reset(self, dataset):
-        def collate(batch):
-            text = torch.vstack([row[0] for row in batch])
-            pos = torch.vstack([row[1] for row in batch])
-            rgb = torch.vstack([row[2] for row in batch])
-
-            # print('batch', batch)
-            return text, pos, rgb  # torch.vstack(torch.from_numpy(batch[0])), torch.vstack(batch[1]), torch.vstack(batch[2])
-        dataset.reset()
-        return collate
-
     def train_dataloader(self) -> DataLoader:
         self._train_dataset.reset()
         return DataLoader(
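The deleted `collate_fn_reset` built a closure that stacks each field of the batch with `torch.vstack`. A standalone sketch of that collate pattern (the embedding and point-sample shapes below are illustrative assumptions, not the repository's actual dimensions):

```python
import torch


def make_collate():
    # Each batch element is a (text, pos, rgb) tuple of tensors; the collate
    # function stacks every field across the batch with torch.vstack,
    # concatenating along dim 0.
    def collate(batch):
        text = torch.vstack([row[0] for row in batch])
        pos = torch.vstack([row[1] for row in batch])
        rgb = torch.vstack([row[2] for row in batch])
        return text, pos, rgb

    return collate


collate = make_collate()
# Illustrative shapes: one text-embedding row, ten (x, y) positions,
# and ten RGB targets per sample.
batch = [(torch.rand(1, 768), torch.rand(10, 2), torch.rand(10, 3)) for _ in range(4)]
text, pos, rgb = collate(batch)
# text: (4, 768), pos: (40, 2), rgb: (40, 3)
```

Note that the old helper also called `dataset.reset()` as a side effect of building the collate function; after this commit, `train_dataloader` performs that reset directly via `self._train_dataset.reset()`.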
