
Commit

Working on new generative sampler, not finished yet
jloveric committed May 15, 2024
1 parent 6f15e5e commit 3371495
Showing 2 changed files with 66 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/text_to_image.py
@@ -8,7 +8,7 @@
import matplotlib.pyplot as plt
from high_order_implicit_representation.networks import GenNet
from pytorch_lightning.callbacks import LearningRateMonitor
- from high_order_implicit_representation.rendering import ImageGenerator
+ from high_order_implicit_representation.rendering import Text2ImageGenerator
from high_order_implicit_representation.single_image_dataset import (
    image_to_dataset,
    Text2ImageDataModule
@@ -34,7 +34,7 @@ def run_implicit_images(cfg: DictConfig):
    data_module = Text2ImageDataModule(
        filenames=full_path, batch_size=cfg.batch_size, rotations=cfg.rotations
    )
-   image_generator = ImageGenerator(
+   image_generator = Text2ImageGenerator(
        filename=full_path[0], rotations=cfg.rotations, batch_size=cfg.batch_size
    )
    lr_monitor = LearningRateMonitor(logging_interval="epoch")
@@ -58,7 +58,7 @@ def run_implicit_images(cfg: DictConfig):
    checkpoint_path = f"{hydra.utils.get_original_cwd()}/{cfg.checkpoint}"

    logger.info(f"checkpoint_path {checkpoint_path}")
-   model = GenerativeNetwork.load_from_checkpoint(checkpoint_path)
+   model = GenNet.load_from_checkpoint(checkpoint_path)

    model.eval()
    image_dir = f"{hydra.utils.get_original_cwd()}/{cfg.images[0]}"
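For orientation, the sketch below shows how the renamed pieces above would plug into a PyTorch Lightning run: Text2ImageDataModule supplies the data, the new Text2ImageGenerator callback logs sample images, and GenNet is the model being trained. The constructor arguments mirror the hunks above; the Trainer settings (such as cfg.max_epochs) are assumptions, not taken from this commit.

# Sketch only: Trainer wiring assumed from the diff above, not the full example script.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from high_order_implicit_representation.networks import GenNet
from high_order_implicit_representation.rendering import Text2ImageGenerator
from high_order_implicit_representation.single_image_dataset import (
    Text2ImageDataModule,
)


def train(cfg, full_path, model: GenNet):
    data_module = Text2ImageDataModule(
        filenames=full_path, batch_size=cfg.batch_size, rotations=cfg.rotations
    )
    image_generator = Text2ImageGenerator(
        filename=full_path[0], rotations=cfg.rotations, batch_size=cfg.batch_size
    )
    lr_monitor = LearningRateMonitor(logging_interval="epoch")
    trainer = pl.Trainer(
        max_epochs=cfg.max_epochs,  # placeholder; the real script reads its Hydra config
        callbacks=[lr_monitor, image_generator],
    )
    trainer.fit(model, datamodule=data_module)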
63 changes: 63 additions & 0 deletions high_order_implicit_representation/rendering.py
@@ -11,12 +11,15 @@
from high_order_implicit_representation.single_image_dataset import (
    image_neighborhood_dataset,
    image_to_dataset,
    Text2ImageDataset
)
import math
import matplotlib.pyplot as plt
import io
import PIL
from torchvision import transforms
from torch.utils.data import DataLoader


logger = logging.getLogger(__name__)
default_size = [64, 64]
@@ -238,3 +241,63 @@ def on_train_epoch_end(
            trainer.logger.experiment.add_image(
                f"image", image, global_step=trainer.global_step
            )

class Text2ImageGenerator(Callback):
    """Log a model-generated image next to its target image at the end of each training epoch."""

    def __init__(self, filename, rotations, batch_size):
        self._dataset = Text2ImageDataset(filename, rotations=rotations)
        self._dataloader = DataLoader(
            self._dataset, batch_size=batch_size, shuffle=False
        )
        self._batch_size = batch_size

    @rank_zero_only
    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        pl_module.eval()
        with torch.no_grad():
            # TODO: self._inputs and self._image are not populated anywhere yet;
            # they are expected to hold the flattened network inputs and the
            # target image for the file passed to the constructor.
            self._inputs = self._inputs.to(device=pl_module.device)

            y_hat_list = []

            # Evaluate the model one batch of inputs at a time so a large image
            # does not exhaust device memory.
            for batch_index in range(len(self._dataloader)):
                res = pl_module(
                    self._inputs[
                        batch_index * self._batch_size : (batch_index + 1) * self._batch_size
                    ]
                )
                y_hat_list.append(res.detach().cpu())
            y_hat = torch.cat(y_hat_list)

            # Reshape the flat predictions into image form and map the network
            # output range [-1, 1] to [0, 1] for display.
            ans = y_hat.reshape(
                self._image.shape[0], self._image.shape[1], self._image.shape[2]
            )
            ans = 0.5 * (ans + 1.0)

            # Plot the generated ("fit") image next to the original.
            f, axarr = plt.subplots(1, 2)
            axarr[0].imshow(ans.detach().cpu().numpy())
            axarr[0].set_title("fit")
            axarr[1].imshow(self._image.cpu())
            axarr[1].set_title("original")

            for i in range(2):
                axarr[i].axes.get_xaxis().set_visible(False)
                axarr[i].axes.get_yaxis().set_visible(False)

            # Render the figure to an in-memory PNG and convert it to a tensor
            # so it can be logged to TensorBoard.
            buf = io.BytesIO()
            plt.savefig(
                buf,
                dpi="figure",
                format="png",
                bbox_inches=None,
                pad_inches=0.1,
                facecolor="auto",
                edgecolor="auto",
            )
            plt.close(f)
            buf.seek(0)
            image = PIL.Image.open(buf)
            image = transforms.ToTensor()(image)

            trainer.logger.experiment.add_image(
                "image", image, global_step=trainer.global_step
            )
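The epoch-end hook above relies on self._inputs and self._image, which this commit does not populate yet. The shape-level sketch below is inferred only from how those tensors are used (the batched slicing, the reshape back to the image, and the side-by-side plot); the 64x64 size matches default_size earlier in this file, while C and F (channels and features per input row) are assumed placeholders.

# Shape sketch inferred from the hook above; C and F are assumed placeholders.
import torch

H, W, C, F = 64, 64, 3, 10       # 64x64 from default_size; C and F assumed
image = torch.rand(H, W, C)      # stands in for self._image, the original image shown
inputs = torch.rand(H * W, F)    # stands in for self._inputs, sliced in batches for the model

# The hook concatenates per-batch predictions and reshapes them to the image:
y_hat = torch.rand(H * W, C)
ans = y_hat.reshape(image.shape[0], image.shape[1], image.shape[2])
ans = 0.5 * (ans + 1.0)          # map model output from [-1, 1] to [0, 1]
assert ans.shape == image.shape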
