From d0cd231287d86de0a1719015ca3880421f6989a0 Mon Sep 17 00:00:00 2001
From: jloveric
Date: Sun, 19 May 2024 08:26:21 -0700
Subject: [PATCH] Fixed tensor shape/size issues

---
 README.md                                      |  3 ++-
 config/generative_config.yaml                  |  2 +-
 examples/text_to_image.py                      |  3 +++
 high_order_implicit_representation/networks.py | 10 +++++-----
 .../single_image_dataset.py                    | 13 ++++---------
 5 files changed, 15 insertions(+), 16 deletions(-)

diff --git a/README.md b/README.md
index ec16e86..edc74a0 100644
--- a/README.md
+++ b/README.md
@@ -70,8 +70,9 @@ Applying to random noise produces
 ![Random Noise]()
 
 # Associative Dictionary
-This is an attempt to store more than one image in a network based on text embedding and associated image. In principle it could also be a generative model if you ask for something not in the dictionary, but we'll see what happens.
+This is an attempt to store more than one image in a network based on text embedding and associated image. In principle it could also be a generative model if you ask for something not in the dictionary, but we'll see what happens. I'm using the [Pick-a-Pic](https://stability.ai/research/pick-a-pic) dataset, so you'll need to download its parquet files - the idea here is not (yet) to train on an enormous dataset, but to use maybe tens to hundreds of images.
 ```
+python3 examples/text_to_image.py batch_size=2 optimizer=sparse_lion
 ```
 
 # Random Interpolation (A Generative Model)
diff --git a/config/generative_config.yaml b/config/generative_config.yaml
index 46da6a1..bf8c600 100644
--- a/config/generative_config.yaml
+++ b/config/generative_config.yaml
@@ -9,7 +9,7 @@ mlp:
   segments: 2
 
 embedding_size: 384 # Text embedding size
-input_size: 10 # Actually the input size of mlp
+input_size: 2 # x,y
 output_size: 3 # rgb
 input_segments: 20 # number of segments for x,y position
 layer_type: "continuous"
diff --git a/examples/text_to_image.py b/examples/text_to_image.py
index 3f19cd9..8ea2ebd 100644
--- a/examples/text_to_image.py
+++ b/examples/text_to_image.py
@@ -13,6 +13,7 @@
     image_to_dataset,
     Text2ImageDataModule
 )
+import torch
 import logging
 
 logging.basicConfig()
@@ -23,6 +24,8 @@
 
 @hydra.main(config_path="../config", config_name="generative_config", version_base="1.3")
 def run_implicit_images(cfg: DictConfig):
+    torch.multiprocessing.set_start_method('spawn')
+
     logger.info(OmegaConf.to_yaml(cfg))
     logger.info(f"Working directory {os.getcwd()}")
     logger.info(f"Orig working directory {hydra.utils.get_original_cwd()}")
diff --git a/high_order_implicit_representation/networks.py b/high_order_implicit_representation/networks.py
index ae4bd3b..5a4136f 100644
--- a/high_order_implicit_representation/networks.py
+++ b/high_order_implicit_representation/networks.py
@@ -206,13 +206,13 @@ def __init__(self, cfg: DictConfig):
 
         self.loss = nn.MSELoss()
 
-    def forward(self, x):
-        return self.model(x)
+    def forward(self, caption, x):
+        return self.model(caption, x)
 
     def eval_step(self, batch: Tensor, name: str):
-        x, y = batch
-        y_hat = self(x.flatten(1))
-        loss = self.loss(y_hat.flatten(), y.flatten())
+        caption, x, color = batch
+        y_hat = self(caption, x.flatten(1))
+        loss = self.loss(y_hat.flatten(), color.flatten())
         self.log(f"{name}_loss", loss, prog_bar=True)
         return loss
 
diff --git a/high_order_implicit_representation/single_image_dataset.py b/high_order_implicit_representation/single_image_dataset.py
index 585c859..9dc6e89 100644
--- a/high_order_implicit_representation/single_image_dataset.py
+++ b/high_order_implicit_representation/single_image_dataset.py
@@ -339,17 +339,14 @@ def __call__(self):
         data = pd.read_parquet(file)
         for index, row in data.iterrows():
-            print("index", index)
             caption = row["caption"]
-            print("caption", caption)
             jpg_0 = row["jpg_0"]
             img = Image.open(io.BytesIO(jpg_0))
-            arr = np.asarray(img)
+            arr = np.copy(np.asarray(img))
             yield caption, torch.from_numpy(arr)
-            print("again")
 
             jpg_1 = row["jpg_1"]
             img = Image.open(io.BytesIO(jpg_1))
-            arr = np.asarray(img)
+            arr = np.copy(np.asarray(img))
             yield caption, torch.from_numpy(arr)
 
 
@@ -360,14 +357,13 @@ def __init__(self, filenames: List[str]):
         self.sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
 
     def __len__(self):
-        return int(1e12)
+        return int(1e6)
 
     def gen_data(self):
         caption, image = next(self.dataset())
         caption_embedding = self.sentence_model.encode(caption)
-        print("got here")
 
         flattened_image, flattened_position, image = simple_image_to_dataset(image)
 
         for index, rgb in enumerate(flattened_image):
             yield caption_embedding, flattened_position[index], rgb
@@ -398,7 +394,6 @@ def gen_data(self):
         caption, image = next(self.dataset())
         caption_embedding = self.sentence_model.encode(caption)
-        print("got here")
 
         flattened_image, flattened_position, image = simple_image_to_dataset(image)
 
         return caption_embedding, flattened_image, flattened_position
@@ -412,7 +407,7 @@ def __getitem__(self, idx):
 
 
 class Text2ImageDataModule(LightningDataModule):
-    def __init__(self, filenames: List[str], num_workers:int=10, batch_size:int=32, pin_memory:bool=False):
+    def __init__(self, filenames: List[str], num_workers:int=0, batch_size:int=32, pin_memory:bool=False):
         super().__init__()
         self._filenames = filenames
         self._shuffle = False
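For reference (not part of the commit), here is a minimal sketch of the batch convention the patched `eval_step` assumes: each batch is a `(caption_embedding, position, color)` triple, with a 384-dim caption embedding (all-MiniLM-L6-v2, `embedding_size: 384`), 2-dim x,y positions (`input_size: 2`), and 3-dim rgb targets (`output_size: 3`). The `ToyTextToImage` class below is a hypothetical placeholder used only to make the sketch runnable; the real model is the high-order network configured in generative_config.yaml.

```python
import torch
from torch import nn


# Hypothetical stand-in with the same call signature as the patched forward(caption, x).
class ToyTextToImage(nn.Module):
    def __init__(self, embedding_size=384, input_size=2, output_size=3):
        super().__init__()
        self.net = nn.Linear(embedding_size + input_size, output_size)

    def forward(self, caption, x):
        # Concatenate the text embedding with the x,y position and predict rgb.
        return self.net(torch.cat([caption, x], dim=1))


batch_size = 2
caption = torch.randn(batch_size, 384)  # sentence-transformer embedding per sample
position = torch.rand(batch_size, 2)    # x,y position pair (values here are arbitrary)
color = torch.rand(batch_size, 3)       # target rgb values

model = ToyTextToImage()
y_hat = model(caption, position.flatten(1))             # mirrors self(caption, x.flatten(1))
loss = nn.MSELoss()(y_hat.flatten(), color.flatten())   # mirrors eval_step's loss
print(loss.item())
```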