Skip to content

Commit

Permalink
Name changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed May 17, 2024
1 parent f97d053 commit 3a4c2b2
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions examples/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import matplotlib.pyplot as plt
from high_order_implicit_representation.networks import GenNet
from pytorch_lightning.callbacks import LearningRateMonitor
from high_order_implicit_representation.rendering import Text2ImageGenerator
from high_order_implicit_representation.rendering import Text2ImageSampler
from high_order_implicit_representation.single_image_dataset import (
image_to_dataset,
Text2ImageDataModule
Expand All @@ -34,15 +34,15 @@ def run_implicit_images(cfg: DictConfig):
data_module = Text2ImageDataModule(
filenames=full_path, batch_size=cfg.batch_size, rotations=cfg.rotations
)
image_generator = Text2ImageGenerator(
image_generator = Text2ImageSampler(
filename=full_path[0], rotations=cfg.rotations, batch_size=cfg.batch_size
)
lr_monitor = LearningRateMonitor(logging_interval="epoch")
trainer = Trainer(
max_epochs=cfg.max_epochs,
devices=cfg.gpus,
accelerator=cfg.accelerator,
callbacks=[lr_monitor, image_generator],
callbacks=[lr_monitor],
)
model = GenNet(cfg)
trainer.fit(model, datamodule=data_module)
Expand Down
2 changes: 1 addition & 1 deletion high_order_implicit_representation/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def on_train_epoch_end(
f"image", image, global_step=trainer.global_step
)

class Text2ImageGenerator(Callback):
class Text2ImageSampler(Callback):
def __init__(self, filename, rotations, batch_size):
self._dataset = Text2ImageRenderDataset(filename, rotations=rotations)
self._dataloader = DataLoader(self._dataset, batch_size=batch_size, shuffle=False)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_single_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,4 @@ def test_text_to_image_sampler_dataloader():
print("element", element[1][0].shape)
assert element[0].shape[0] ==2
assert len(element) == 3
break
break

0 comments on commit 3a4c2b2

Please sign in to comment.