-
Notifications
You must be signed in to change notification settings - Fork 974
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
49 changed files
with
2,033 additions
and
228 deletions.
There are no files selected for viewing
14 changes: 14 additions & 0 deletions
14
lavis/configs/datasets/blip_diffusion_datasets/defaults.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Copyright (c) 2022, salesforce.com, inc. | ||
# All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | ||
|
||
datasets: | ||
blip_diffusion_finetune: # name of the dataset builder | ||
# data_dir: ${env.data_dir}/datasets | ||
data_type: images # [images|videos|features] | ||
|
||
build_info: | ||
# Be careful not to append minus sign (-) before split to avoid itemizing | ||
images: | ||
storage: "" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
39 changes: 39 additions & 0 deletions
39
lavis/datasets/builders/text_to_image_generation_builder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
""" | ||
Copyright (c) 2022, salesforce.com, inc. | ||
All rights reserved. | ||
SPDX-License-Identifier: BSD-3-Clause | ||
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | ||
""" | ||
|
||
from lavis.common.registry import registry | ||
from lavis.datasets.datasets.subject_driven_t2i_dataset import ( | ||
SubjectDrivenTextToImageDataset, | ||
) | ||
from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder | ||
|
||
|
||
@registry.register_builder("blip_diffusion_finetune") | ||
class BlipDiffusionFinetuneBuilder(BaseDatasetBuilder): | ||
train_dataset_cls = SubjectDrivenTextToImageDataset | ||
|
||
DATASET_CONFIG_DICT = { | ||
"default": "configs/datasets/blip_diffusion_datasets/defaults.yaml" | ||
} | ||
|
||
def _download_ann(self): | ||
pass | ||
|
||
def build(self): | ||
self.build_processors() | ||
|
||
build_info = self.config.build_info | ||
|
||
dataset = self.train_dataset_cls( | ||
image_dir=build_info.images.storage, | ||
subject_text=build_info.subject_text, | ||
inp_image_processor=self.kw_processors["inp_vis_processor"], | ||
tgt_image_processor=self.kw_processors["tgt_vis_processor"], | ||
txt_processor=self.text_processors["eval"], | ||
) | ||
|
||
return {"train": dataset} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
""" | ||
Copyright (c) 2022, salesforce.com, inc. | ||
All rights reserved. | ||
SPDX-License-Identifier: BSD-3-Clause | ||
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | ||
""" | ||
|
||
import os | ||
|
||
from PIL import Image | ||
from torch.utils.data import Dataset | ||
from torch.utils.data.dataloader import default_collate | ||
|
||
|
||
class SubjectDrivenTextToImageDataset(Dataset): | ||
def __init__( | ||
self, | ||
image_dir, | ||
subject_text, | ||
inp_image_processor, | ||
tgt_image_processor, | ||
txt_processor, | ||
repetition=100000, | ||
): | ||
self.subject = txt_processor(subject_text.lower()) | ||
self.image_dir = image_dir | ||
|
||
self.inp_image_transform = inp_image_processor | ||
self.tgt_image_transform = tgt_image_processor | ||
|
||
self.text_processor = txt_processor | ||
|
||
image_paths = os.listdir(image_dir) | ||
# image paths are jpg png webp | ||
image_paths = [ | ||
os.path.join(image_dir, imp) | ||
for imp in image_paths | ||
if os.path.splitext(imp)[1][1:] | ||
in ["jpg", "png", "webp", "jpeg", "JPG", "PNG", "WEBP", "JPEG"] | ||
] | ||
# make absolute path | ||
self.image_paths = [os.path.abspath(imp) for imp in image_paths] | ||
self.repetition = repetition | ||
|
||
def __len__(self): | ||
return len(self.image_paths) * self.repetition | ||
|
||
@property | ||
def len_without_repeat(self): | ||
return len(self.image_paths) | ||
|
||
def collater(self, samples): | ||
return default_collate(samples) | ||
|
||
def __getitem__(self, index): | ||
image_path = self.image_paths[index % len(self.image_paths)] | ||
image = Image.open(image_path).convert("RGB") | ||
|
||
# For fine-tuning, we use the same caption for all images | ||
# maybe worth trying different captions for different images | ||
caption = f"a {self.subject}" | ||
caption = self.text_processor(caption) | ||
|
||
inp_image = self.inp_image_transform(image) | ||
tgt_image = self.tgt_image_transform(image) | ||
|
||
return { | ||
"inp_image": inp_image, | ||
"tgt_image": tgt_image, | ||
"caption": caption, | ||
"subject_text": self.subject, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.