Skip to content

Commit

Permalink
One text2Image dataset is tested
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed May 17, 2024
1 parent 66f112c commit f97d053
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion high_order_implicit_representation/single_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def __getitem__(self, idx):
# ans = self.dataset()
# print('ans', ans)

return next(self.gen_data())
return self.gen_data()


class Text2ImageDataModule(LightningDataModule):
Expand Down
18 changes: 12 additions & 6 deletions tests/test_single_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ def test_image_neighborhood_reader():


def test_parquet_dataset():
dataset = Text2ImageDataset(
filenames=["test_data/test.parquet"]
)
dataset = Text2ImageDataset(filenames=["test_data/test.parquet"])
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)

# caption, position, rgb = next(iter(dataloader))
Expand All @@ -51,6 +49,14 @@ def test_parquet_dataset():


def test_text_to_image_sampler_dataloader():
dataloader = Text2ImageRenderDataset(
filenames=["test_data/test.parquet"]
)
dataset = Text2ImageRenderDataset(filenames=["test_data/test.parquet"])

dataloader = DataLoader(dataset, batch_size=2, shuffle=False)

for element in dataloader:
print("element", element[0][0].shape)
print("element", element[0][1].shape)
print("element", element[1][0].shape)
assert element[0].shape[0] ==2
assert len(element) == 3
break

0 comments on commit f97d053

Please sign in to comment.