From c12d4c2e62479b84e1f27bc53aa9bd9ba586e06c Mon Sep 17 00:00:00 2001 From: jloveric Date: Mon, 13 May 2024 19:04:51 -0700 Subject: [PATCH] Working on runner --- examples/text_to_image.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/text_to_image.py b/examples/text_to_image.py index 9ecd8fb..2f6312a 100644 --- a/examples/text_to_image.py +++ b/examples/text_to_image.py @@ -6,12 +6,12 @@ from high_order_layers_torch.networks import * from pytorch_lightning import Trainer import matplotlib.pyplot as plt -from high_order_implicit_representation.networks import Net +from high_order_implicit_representation.networks import GenerativeNetwork from pytorch_lightning.callbacks import LearningRateMonitor from high_order_implicit_representation.rendering import ImageGenerator from high_order_implicit_representation.single_image_dataset import ( image_to_dataset, - ImageDataModule, + Text2ImageDataModule ) import logging @@ -31,7 +31,7 @@ def run_implicit_images(cfg: DictConfig): if cfg.train is True: full_path = [f"{root_dir}/{path}" for path in cfg.images] - data_module = ImageDataModule( + data_module = Text2ImageDataModule( filenames=full_path, batch_size=cfg.batch_size, rotations=cfg.rotations ) image_generator = ImageGenerator( @@ -44,7 +44,7 @@ def run_implicit_images(cfg: DictConfig): accelerator=cfg.accelerator, callbacks=[lr_monitor, image_generator], ) - model = Net(cfg) + model = GenerativeNetwork(cfg) trainer.fit(model, datamodule=data_module) logger.info("testing") @@ -58,7 +58,7 @@ def run_implicit_images(cfg: DictConfig): checkpoint_path = f"{hydra.utils.get_original_cwd()}/{cfg.checkpoint}" logger.info(f"checkpoint_path {checkpoint_path}") - model = Net.load_from_checkpoint(checkpoint_path) + model = GenerativeNetwork.load_from_checkpoint(checkpoint_path) model.eval() image_dir = f"{hydra.utils.get_original_cwd()}/{cfg.images[0]}"