From bc62a2f844e0f506187540684ad00bc9976a932a Mon Sep 17 00:00:00 2001
From: jloveric
Date: Mon, 20 May 2024 20:42:06 -0700
Subject: [PATCH] Sampler seems to work

---
 .../rendering.py            | 10 ++--
 .../single_image_dataset.py | 15 +++---
 tests/test_rendering.py     | 52 ++++++++++++++++++-
 3 files changed, 65 insertions(+), 12 deletions(-)

diff --git a/high_order_implicit_representation/rendering.py b/high_order_implicit_representation/rendering.py
index 5d9c22b..caaa1c0 100644
--- a/high_order_implicit_representation/rendering.py
+++ b/high_order_implicit_representation/rendering.py
@@ -255,16 +255,17 @@ def __init__(self, filenames, batch_size):
     def on_train_epoch_end(
         self, trainer: pl.Trainer, pl_module: pl.LightningModule
     ) -> None:
+        self._dataset.reset()
         pl_module.eval()
         with torch.no_grad():
-            print("We are calling this")
             image_count=0
             for caption_embedding, flattened_position, image in self._dataloader:
+                image_count+=1
+                print('image_count', image_count)
                 flattened_position=flattened_position[0]
                 size = len(flattened_position)
                 y_hat_list = []
                 for i in range(0, size, self._batch_size):
-
                     embed_single = caption_embedding.to(pl_module.device)
@@ -284,14 +285,15 @@ def on_train_epoch_end(
                     image.shape[1], image.shape[2], 3
                 )
-                print('ans.shape', ans.shape)
+                #ans = ans.permute(2,0,1)
+                ans = ans.squeeze(0)
                 ans = 0.5 * (ans + 1.0)

                 f, axarr = plt.subplots(1, 2)
                 axarr[0].imshow(ans.detach().cpu().numpy())
                 axarr[0].set_title("fit")
-                axarr[1].imshow(image.cpu())
+                axarr[1].imshow(image.squeeze(0).cpu())
                 axarr[1].set_title("original")

                 for i in range(2):
diff --git a/high_order_implicit_representation/single_image_dataset.py b/high_order_implicit_representation/single_image_dataset.py
index 7e8a7dd..05cf164 100644
--- a/high_order_implicit_representation/single_image_dataset.py
+++ b/high_order_implicit_representation/single_image_dataset.py
@@ -402,25 +402,26 @@ def __init__(self, filenames: List[str]):
         super().__init__()
         self.dataset = PickAPic(files=filenames)
         self.sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
-
+        self.reset()
+
     def reset(self) :
-        self.dataset.reset()
+        self.generator = self.gen_data()

     def __len__(self):
         return int(1e12)

     def gen_data(self):
-        caption, image = next(self.dataset())
-        caption_embedding = self.sentence_model.encode(caption)
+        for caption, image in self.dataset():
+            caption_embedding = self.sentence_model.encode(caption)

-        flattened_image, flattened_position, image = simple_image_to_dataset(image)
-        return caption_embedding, flattened_position, image
+            flattened_image, flattened_position, image = simple_image_to_dataset(image)
+            yield caption_embedding, flattened_position, image

     def __getitem__(self, idx):
-        return self.gen_data()
+        return next(self.generator)


 class Text2ImageDataModule(LightningDataModule):
diff --git a/tests/test_rendering.py b/tests/test_rendering.py
index a5c3086..2e5d78d 100644
--- a/tests/test_rendering.py
+++ b/tests/test_rendering.py
@@ -2,10 +2,11 @@
 from high_order_implicit_representation.rendering import (
     neighborhood_sample_generator,
     NeighborGenerator,
+    Text2ImageSampler
 )
 import torch
 from omegaconf import DictConfig
-from high_order_implicit_representation.networks import Net
+from high_order_implicit_representation.networks import Net, GenNet
 from pytorch_lightning import Trainer


@@ -133,3 +134,52 @@ def test_neighbor_generator():
         samples=2, frames=2, output_size=[64, 64], width=3, outside=3
     )
     generator.on_train_epoch_end(trainer=trainer, pl_module=model)
+
+
+def test_text2image_sampler() :
+    width = 3
+    outside = 3
+
+    input_features = ((width + 2 * outside) * (width + 2 * outside) - width * width) * 3
+    output_size = width * width * 3
+
+    cfg = DictConfig(
+        content={
+            "max_epochs": 1,
+            "gpus": 0,
+            "lr": 1e-4,
+            "batch_size": 16,
+            "segments": 2,
+            "optimizer": {
+                "name": "adam",
+                "lr": 1.0e-3,
+                "scheduler": "plateau",
+                "patience": 10,
+                "factor": 0.1,
+            },
+            "mlp": {
+                "layers": 2,
+                "segments": 2,
+                "scale": 2.0,
+                "width": 4,
+                "periodicity": None,
+                "rescale_output": False,
+            },
+            "embedding_size": 384,
+            "input_size": 2,
+            "output_size": 3,
+            "input_segments": 20,
+            "layer_type": "continuous",
+            "n": 3,
+        }
+    )
+
+    model = GenNet(cfg)
+    trainer = Trainer(
+        max_epochs=cfg.max_epochs,
+        accelerator='cpu',
+    )
+
+    # Just make sure this runs
+    generator = Text2ImageSampler(filenames=["test_data/test.parquet"], batch_size=2000)
+    generator.on_train_epoch_end(trainer=trainer, pl_module=model)
\ No newline at end of file
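Note on the dataset change: `__getitem__` now pulls from a persistent generator instead of calling `next(self.dataset())` fresh each time, which is why the sampler calls `self._dataset.reset()` at the top of `on_train_epoch_end`. A minimal sketch of that pattern, with a hypothetical `GeneratorBackedDataset` standing in for `Text2ImageDataset` and plain integers standing in for the (caption_embedding, flattened_position, image) tuples:

```python
from torch.utils.data import Dataset


class GeneratorBackedDataset(Dataset):
    """Hypothetical stand-in for Text2ImageDataset: __getitem__ draws from a
    persistent generator, and reset() rebuilds it to restart the stream."""

    def __init__(self, items):
        self.items = items  # stand-in for the PickAPic parquet stream
        self.reset()

    def reset(self):
        # Rebuild the generator. Without this per-epoch call, next() would
        # raise StopIteration once the underlying stream is exhausted.
        self.generator = self.gen_data()

    def gen_data(self):
        for item in self.items:
            # The real gen_data yields (caption_embedding, flattened_position, image).
            yield item

    def __len__(self):
        # Effectively unbounded, mirroring the int(1e12) in the patch.
        return int(1e12)

    def __getitem__(self, idx):
        # idx is ignored; samples arrive in stream order.
        return next(self.generator)


ds = GeneratorBackedDataset([10, 20, 30])
print([ds[i] for i in range(3)])  # [10, 20, 30]
ds.reset()                        # start of a new "epoch"
print(ds[0])                      # 10 again
```

The design choice here is that the generator, not the index, drives sampling order; `reset()` is the only way to rewind, so any callback that re-iterates the dataloader (like the sampler above) must call it first.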