Skip to content

Commit

Permalink
Fixed tensor shape/size issues
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed May 19, 2024
1 parent 85b0fbf commit d0cd231
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 16 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ Applying to random noise produces
![Random Noise]()

# Associative Dictionary
This is an attempt to store more than one image in a network based on text embedding and associated image. In principle it could also be a generative model if you ask for something not in the dictionary, but we'll see what happens.
This is an attempt to store more than one image in a network based on a text embedding and its associated image. In principle it could also be a generative model if you ask for something not in the dictionary, but we'll see what happens. I'm using the [Pick-a-Pic](https://stability.ai/research/pick-a-pic) dataset, so you'll need to download those parquet files. The idea here is not (yet) to train on an enormous dataset, but to use maybe tens to hundreds of images.
```
python3 examples/text_to_image.py batch_size=2 optimizer=sparse_lion
```

# Random Interpolation (A Generative Model)
Expand Down
2 changes: 1 addition & 1 deletion config/generative_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ mlp:
segments: 2

embedding_size: 384 # Text embedding size
input_size: 10 # Actually the input size of mlp
input_size: 2 # x,y
output_size: 3 # rgb
input_segments: 20 # number of segments for x,y position
layer_type: "continuous"
Expand Down
3 changes: 3 additions & 0 deletions examples/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
image_to_dataset,
Text2ImageDataModule
)
import torch
import logging

logging.basicConfig()
Expand All @@ -23,6 +24,8 @@
@hydra.main(config_path="../config", config_name="generative_config", version_base="1.3")
def run_implicit_images(cfg: DictConfig):

torch.multiprocessing.set_start_method('spawn')

logger.info(OmegaConf.to_yaml(cfg))
logger.info(f"Working directory {os.getcwd()}")
logger.info(f"Orig working directory {hydra.utils.get_original_cwd()}")
Expand Down
10 changes: 5 additions & 5 deletions high_order_implicit_representation/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,13 @@ def __init__(self, cfg: DictConfig):

self.loss = nn.MSELoss()

def forward(self, x):
return self.model(x)
def forward(self, caption,x):
return self.model(caption,x)

def eval_step(self, batch: Tensor, name: str):
x, y = batch
y_hat = self(x.flatten(1))
loss = self.loss(y_hat.flatten(), y.flatten())
caption, x, color = batch
y_hat = self(caption, x.flatten(1))
loss = self.loss(y_hat.flatten(), color.flatten())

self.log(f"{name}_loss", loss, prog_bar=True)
return loss
Expand Down
13 changes: 4 additions & 9 deletions high_order_implicit_representation/single_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,17 +339,14 @@ def __call__(self):
data = pd.read_parquet(file)

for index, row in data.iterrows():
print("index", index)
caption = row["caption"]
print("caption", caption)
jpg_0 = row["jpg_0"]
img = Image.open(io.BytesIO(jpg_0))
arr = np.asarray(img)
arr = np.copy(np.asarray(img))
yield caption, torch.from_numpy(arr)
print("again")
jpg_1 = row["jpg_1"]
img = Image.open(io.BytesIO(jpg_1))
arr = np.asarray(img)
arr = np.copy(np.asarray(img))
yield caption, torch.from_numpy(arr)


Expand All @@ -360,14 +357,13 @@ def __init__(self, filenames: List[str]):
self.sentence_model = SentenceTransformer("all-MiniLM-L6-v2")

def __len__(self):
return int(1e12)
return int(1e6)

def gen_data(self):

caption, image = next(self.dataset())
caption_embedding = self.sentence_model.encode(caption)

print("got here")
flattened_image, flattened_position, image = simple_image_to_dataset(image)
for index, rgb in enumerate(flattened_image):
yield caption_embedding, flattened_position[index], rgb
Expand Down Expand Up @@ -398,7 +394,6 @@ def gen_data(self):
caption, image = next(self.dataset())
caption_embedding = self.sentence_model.encode(caption)

print("got here")
flattened_image, flattened_position, image = simple_image_to_dataset(image)
return caption_embedding, flattened_image, flattened_position

Expand All @@ -412,7 +407,7 @@ def __getitem__(self, idx):


class Text2ImageDataModule(LightningDataModule):
def __init__(self, filenames: List[str], num_workers:int=10, batch_size:int=32, pin_memory:bool=False):
def __init__(self, filenames: List[str], num_workers:int=0, batch_size:int=32, pin_memory:bool=False):
super().__init__()
self._filenames = filenames
self._shuffle = False
Expand Down

0 comments on commit d0cd231

Please sign in to comment.