Skip to content

Commit

Permalink
Working on runner
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed May 14, 2024
1 parent a3e083b commit c12d4c2
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions examples/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from high_order_layers_torch.networks import *
from pytorch_lightning import Trainer
import matplotlib.pyplot as plt
from high_order_implicit_representation.networks import Net
from high_order_implicit_representation.networks import GenerativeNetwork
from pytorch_lightning.callbacks import LearningRateMonitor
from high_order_implicit_representation.rendering import ImageGenerator
from high_order_implicit_representation.single_image_dataset import (
image_to_dataset,
ImageDataModule,
Text2ImageDataModule
)
import logging

Expand All @@ -31,7 +31,7 @@ def run_implicit_images(cfg: DictConfig):

if cfg.train is True:
full_path = [f"{root_dir}/{path}" for path in cfg.images]
data_module = ImageDataModule(
data_module = Text2ImageDataModule(
filenames=full_path, batch_size=cfg.batch_size, rotations=cfg.rotations
)
image_generator = ImageGenerator(
Expand All @@ -44,7 +44,7 @@ def run_implicit_images(cfg: DictConfig):
accelerator=cfg.accelerator,
callbacks=[lr_monitor, image_generator],
)
model = Net(cfg)
model = GenerativeNetwork(cfg)
trainer.fit(model, datamodule=data_module)
logger.info("testing")

Expand All @@ -58,7 +58,7 @@ def run_implicit_images(cfg: DictConfig):
checkpoint_path = f"{hydra.utils.get_original_cwd()}/{cfg.checkpoint}"

logger.info(f"checkpoint_path {checkpoint_path}")
model = Net.load_from_checkpoint(checkpoint_path)
model = GenerativeNetwork.load_from_checkpoint(checkpoint_path)

model.eval()
image_dir = f"{hydra.utils.get_original_cwd()}/{cfg.images[0]}"
Expand Down

0 comments on commit c12d4c2

Please sign in to comment.