Skip to content

Commit

Permalink
More work on the image sampler, not yet complete
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed May 20, 2024
1 parent b235f18 commit 196e5b7
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 46 deletions.
6 changes: 3 additions & 3 deletions examples/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def run_implicit_images(cfg: DictConfig):
data_module = Text2ImageDataModule(
filenames=full_path, batch_size=cfg.batch_size
)
image_generator = Text2ImageSampler(
filename=full_path[0], batch_size=cfg.batch_size
image_sampler = Text2ImageSampler(
filenames=[full_path[0]], batch_size=cfg.batch_size
)
lr_monitor = LearningRateMonitor(logging_interval="epoch")
checkpoint = ModelCheckpoint(
Expand All @@ -49,7 +49,7 @@ def run_implicit_images(cfg: DictConfig):
max_epochs=cfg.max_epochs,
devices=cfg.gpus,
accelerator=cfg.accelerator,
callbacks=[lr_monitor, checkpoint],
callbacks=[lr_monitor, checkpoint, image_sampler],
reload_dataloaders_every_n_epochs=1
)
model = GenNet(cfg)
Expand Down
97 changes: 55 additions & 42 deletions high_order_implicit_representation/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from high_order_implicit_representation.single_image_dataset import (
image_neighborhood_dataset,
image_to_dataset,
Text2ImageRenderDataset
Text2ImageRenderDataset,
)
import math
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -242,10 +242,13 @@ def on_train_epoch_end(
f"image", image, global_step=trainer.global_step
)


class Text2ImageSampler(Callback):
def __init__(self, filename, batch_size):
self._dataset = Text2ImageRenderDataset(filename)
self._dataloader = DataLoader(self._dataset, batch_size=batch_size, shuffle=False)
def __init__(self, filenames, batch_size):
self._dataset = Text2ImageRenderDataset(filenames)
self._dataloader = DataLoader(
self._dataset, batch_size=batch_size, shuffle=False
)
self._batch_size = batch_size

@rank_zero_only
Expand All @@ -254,50 +257,60 @@ def on_train_epoch_end(
) -> None:
pl_module.eval()
with torch.no_grad():
self._inputs = self._inputs.to(device=pl_module.device)

y_hat_list = []

for caption_embedding, flattened_image, flattened_position in self._dataloader:
for index, rgb in enumerate(flattened_image):
image_count=0
for caption_embedding, flattened_position, image in self._dataloader:
size = len(flattened_position)
for i in range(0, size, self._batch_size):
# embed = caption_embedding[i:(i+self._batch_size)]
# rgb = flattened_image[i:(i+self._batch_size)]
pos = flattened_position[i : (i + self._batch_size)]

for i in range(size//self._batch_size):
res = pl_module(
caption_embedding, flattened_position[index], rgb
caption_embedding,
flattened_position[
i * self._batch_size : (i + 1) * self._batch_size
],
)

y_hat_list.append(res.detach().cpu())
y_hat = torch.cat(y_hat_list)

ans = y_hat.reshape(
self._image.shape[0], self._image.shape[1], self._image.shape[2]
)
ans = 0.5 * (ans + 1.0)
y_hat_list.append(res.detach().cpu())

f, axarr = plt.subplots(1, 2)
axarr[0].imshow(ans.detach().cpu().numpy())
axarr[0].set_title("fit")
axarr[1].imshow(self._image.cpu())
axarr[1].set_title("original")

for i in range(2):
axarr[i].axes.get_xaxis().set_visible(False)
axarr[i].axes.get_yaxis().set_visible(False)
y_hat = torch.vstack(y_hat_list)

buf = io.BytesIO()
plt.savefig(
buf,
dpi="figure",
format=None,
metadata=None,
bbox_inches=None,
pad_inches=0.1,
facecolor="auto",
edgecolor="auto",
backend=None,
)
buf.seek(0)
image = PIL.Image.open(buf)
image = transforms.ToTensor()(image)
ans = y_hat.reshape(
self._image.shape[0], self._image.shape[1], self._image.shape[2]
)
ans = 0.5 * (ans + 1.0)

f, axarr = plt.subplots(1, 2)
axarr[0].imshow(ans.detach().cpu().numpy())
axarr[0].set_title("fit")
axarr[1].imshow(self._image.cpu())
axarr[1].set_title("original")

for i in range(2):
axarr[i].axes.get_xaxis().set_visible(False)
axarr[i].axes.get_yaxis().set_visible(False)

buf = io.BytesIO()
plt.savefig(
buf,
dpi="figure",
format=None,
metadata=None,
bbox_inches=None,
pad_inches=0.1,
facecolor="auto",
edgecolor="auto",
backend=None,
)
buf.seek(0)
image = PIL.Image.open(buf)
image = transforms.ToTensor()(image)

trainer.logger.experiment.add_image(
f"image", image, global_step=trainer.global_step
)
trainer.logger.experiment.add_image(
f"image_{image_count}", image, global_step=trainer.global_step
)
2 changes: 1 addition & 1 deletion high_order_implicit_representation/single_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def gen_data(self):
caption_embedding = self.sentence_model.encode(caption)

flattened_image, flattened_position, image = simple_image_to_dataset(image)
return caption_embedding, flattened_image, flattened_position
return caption_embedding, flattened_position, image


def __getitem__(self, idx):
Expand Down

0 comments on commit 196e5b7

Please sign in to comment.