Skip to content

Commit

Permalink
Sampler seems to work
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed May 21, 2024
1 parent 462747e commit bc62a2f
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 12 deletions.
10 changes: 6 additions & 4 deletions high_order_implicit_representation/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,16 +255,17 @@ def __init__(self, filenames, batch_size):
def on_train_epoch_end(
self, trainer: pl.Trainer, pl_module: pl.LightningModule
) -> None:
self._dataset.reset()
pl_module.eval()
with torch.no_grad():
print("We are calling this")
image_count=0
for caption_embedding, flattened_position, image in self._dataloader:
image_count+=1
print('image_count', image_count)
flattened_position=flattened_position[0]
size = len(flattened_position)
y_hat_list = []
for i in range(0, size, self._batch_size):


embed_single = caption_embedding.to(pl_module.device)

Expand All @@ -284,14 +285,15 @@ def on_train_epoch_end(
image.shape[1], image.shape[2], 3
)

print('ans.shape', ans.shape)
#ans = ans.permute(2,0,1)
ans = ans.squeeze(0)

ans = 0.5 * (ans + 1.0)

f, axarr = plt.subplots(1, 2)
axarr[0].imshow(ans.detach().cpu().numpy())
axarr[0].set_title("fit")
axarr[1].imshow(image.cpu())
axarr[1].imshow(image.squeeze(0).cpu())
axarr[1].set_title("original")

for i in range(2):
Expand Down
15 changes: 8 additions & 7 deletions high_order_implicit_representation/single_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,25 +402,26 @@ def __init__(self, filenames: List[str]):
super().__init__()
self.dataset = PickAPic(files=filenames)
self.sentence_model = SentenceTransformer("all-MiniLM-L6-v2")

self.reset()

def reset(self) :
self.dataset.reset()
self.generator = self.gen_data()

def __len__(self):
return int(1e12)

def gen_data(self):

caption, image = next(self.dataset())
caption_embedding = self.sentence_model.encode(caption)
for caption, image in self.dataset():
caption_embedding = self.sentence_model.encode(caption)

flattened_image, flattened_position, image = simple_image_to_dataset(image)
return caption_embedding, flattened_position, image
flattened_image, flattened_position, image = simple_image_to_dataset(image)
yield caption_embedding, flattened_position, image


def __getitem__(self, idx):

return self.gen_data()
return next(self.generator)


class Text2ImageDataModule(LightningDataModule):
Expand Down
52 changes: 51 additions & 1 deletion tests/test_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from high_order_implicit_representation.rendering import (
neighborhood_sample_generator,
NeighborGenerator,
Text2ImageSampler
)
import torch
from omegaconf import DictConfig
from high_order_implicit_representation.networks import Net
from high_order_implicit_representation.networks import Net, GenNet
from pytorch_lightning import Trainer


Expand Down Expand Up @@ -133,3 +134,52 @@ def test_neighbor_generator():
samples=2, frames=2, output_size=[64, 64], width=3, outside=3
)
generator.on_train_epoch_end(trainer=trainer, pl_module=model)


def test_text2image_sampler() :
width = 3
outside = 3

input_features = ((width + 2 * outside) * (width + 2 * outside) - width * width) * 3
output_size = width * width * 3

cfg = DictConfig(
content={
"max_epochs": 1,
"gpus": 0,
"lr": 1e-4,
"batch_size": 16,
"segments": 2,
"optimizer": {
"name": "adam",
"lr": 1.0e-3,
"scheduler": "plateau",
"patience": 10,
"factor": 0.1,
},
"mlp": {
"layers": 2,
"segments": 2,
"scale": 2.0,
"width": 4,
"periodicity": None,
"rescale_output": False,
},
"embedding_size": 384,
"input_size": 2,
"output_size": 3,
"input_segments": 20,
"layer_type": "continuous",
"n": 3,
}
)

model = GenNet(cfg)
trainer = Trainer(
max_epochs=cfg.max_epochs,
accelerator='cpu',
)

# Just make sure this runs
generator = Text2ImageSampler(filenames=["test_data/test.parquet"], batch_size=2000)
generator.on_train_epoch_end(trainer=trainer, pl_module=model)

0 comments on commit bc62a2f

Please sign in to comment.