Skip to content

Commit

Permalink
finetuning code.
Browse files Browse the repository at this point in the history
  • Loading branch information
dxli94 committed Jun 22, 2023
1 parent 74ae950 commit 9dd7d3d
Show file tree
Hide file tree
Showing 49 changed files with 2,033 additions and 228 deletions.
14 changes: 14 additions & 0 deletions lavis/configs/datasets/blip_diffusion_datasets/defaults.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

datasets:
blip_diffusion_finetune: # name of the dataset builder
# data_dir: ${env.data_dir}/datasets
data_type: images # [images|videos|features]

build_info:
# Be careful not to put a minus sign (-) before the split name; YAML would treat it as a list item.
images:
storage: ""
4 changes: 2 additions & 2 deletions lavis/configs/models/blip-diffusion/blip_diffusion_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ model:
preprocess:
vis_processor:
train:
name: "blip_diffusion_image_eval"
name: "blip_diffusion_inp_image_eval"
eval:
name: "blip_diffusion_image_eval"
name: "blip_diffusion_inp_image_eval"
text_processor:
train:
name: "blip_caption"
Expand Down
2 changes: 2 additions & 0 deletions lavis/datasets/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@
Flickr30kBuilder,
)
from lavis.datasets.builders.dialogue_builder import AVSDDialBuilder
from lavis.datasets.builders.text_to_image_generation_builder import BlipDiffusionFinetuneBuilder

from lavis.common.registry import registry

__all__ = [
"BlipDiffusionFinetuneBuilder",
"COCOCapBuilder",
"COCORetrievalBuilder",
"COCOVQABuilder",
Expand Down
10 changes: 9 additions & 1 deletion lavis/datasets/builders/base_dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def __init__(self, cfg=None):
self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}

# additional processors, each specified by a name in string.
self.kw_processors = {}

def build_datasets(self):
# download, split, etc...
# only called on 1 GPU/TPU in distributed
Expand Down Expand Up @@ -73,7 +76,12 @@ def build_processors(self):

self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)


kw_proc_cfg = self.config.get("kw_processor")
if kw_proc_cfg is not None:
for name, cfg in kw_proc_cfg.items():
self.kw_processors[name] = self._build_proc_from_cfg(cfg)

@staticmethod
def _build_proc_from_cfg(cfg):
return (
Expand Down
39 changes: 39 additions & 0 deletions lavis/datasets/builders/text_to_image_generation_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from lavis.common.registry import registry
from lavis.datasets.datasets.subject_driven_t2i_dataset import (
SubjectDrivenTextToImageDataset,
)
from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder


@registry.register_builder("blip_diffusion_finetune")
class BlipDiffusionFinetuneBuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``blip_diffusion_finetune``.

    ``build`` produces a single ``"train"`` split backed by
    ``train_dataset_cls``; no other split is created.
    """

    # Dataset class instantiated by ``build`` for the train split.
    train_dataset_cls = SubjectDrivenTextToImageDataset

    # Maps a config flavour name to its yaml file.
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/blip_diffusion_datasets/defaults.yaml"
    }

    def _download_ann(self):
        """Do nothing; this builder performs no annotation download."""
        return None

    def build(self):
        """Build the processors, then construct the training dataset.

        Reads ``images.storage`` and ``subject_text`` from
        ``self.config.build_info`` and the ``"inp_vis_processor"`` /
        ``"tgt_vis_processor"`` entries from ``self.kw_processors``.

        Returns:
            dict: ``{"train": <train_dataset_cls instance>}``.

        Raises:
            KeyError: if either keyword processor is missing from
                ``self.kw_processors``.
        """
        self.build_processors()

        info = self.config.build_info

        # Collected in the same order the original call evaluated them.
        dataset_kwargs = {
            "image_dir": info.images.storage,
            "subject_text": info.subject_text,
            "inp_image_processor": self.kw_processors["inp_vis_processor"],
            "tgt_image_processor": self.kw_processors["tgt_vis_processor"],
            # NOTE(review): the "eval" text processor is used for the train
            # split -- looks deliberate, confirm with the author.
            "txt_processor": self.text_processors["eval"],
        }

        return {"train": self.train_dataset_cls(**dataset_kwargs)}
72 changes: 72 additions & 0 deletions lavis/datasets/datasets/subject_driven_t2i_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os

from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate


class SubjectDrivenTextToImageDataset(Dataset):
    """Dataset pairing every image in a directory with one fixed subject caption.

    The image files found in ``image_dir`` are virtually repeated
    ``repetition`` times: ``len(dataset)`` is ``len(image_paths) * repetition``
    and ``dataset[i]`` maps to image ``i % len(image_paths)``.

    Args:
        image_dir: directory scanned (non-recursively) for image files.
        subject_text: subject name; it is lower-cased and passed through
            ``txt_processor`` once at construction time.
        inp_image_processor: callable applied to the RGB image to produce
            the ``"inp_image"`` entry of a sample.
        tgt_image_processor: callable applied to the same RGB image to
            produce the ``"tgt_image"`` entry of a sample.
        txt_processor: callable applied to the subject text and to the
            caption of every sample.
        repetition: number of times the image list is virtually repeated.
    """

    # Lower-cased file extensions (without the dot) accepted as images.
    _IMAGE_EXTENSIONS = frozenset({"jpg", "jpeg", "png", "webp"})

    def __init__(
        self,
        image_dir,
        subject_text,
        inp_image_processor,
        tgt_image_processor,
        txt_processor,
        repetition=100000,
    ):
        self.subject = txt_processor(subject_text.lower())
        self.image_dir = image_dir

        self.inp_image_transform = inp_image_processor
        self.tgt_image_transform = tgt_image_processor

        self.text_processor = txt_processor

        # os.listdir returns entries in arbitrary order; sort them so that the
        # index -> image mapping is identical across runs and processes.
        # The extension is compared case-insensitively (".JPG", ".Jpeg", ...).
        self.image_paths = [
            os.path.abspath(os.path.join(image_dir, name))
            for name in sorted(os.listdir(image_dir))
            if os.path.splitext(name)[1][1:].lower() in self._IMAGE_EXTENSIONS
        ]
        self.repetition = repetition

    def __len__(self):
        """Return the number of images multiplied by ``repetition``."""
        return len(self.image_paths) * self.repetition

    @property
    def len_without_repeat(self):
        """Number of distinct image files, ignoring ``repetition``."""
        return len(self.image_paths)

    def collater(self, samples):
        """Batch a list of samples with torch's ``default_collate``."""
        return default_collate(samples)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            dict with keys ``"inp_image"``, ``"tgt_image"``, ``"caption"``
            and ``"subject_text"``.

        Raises:
            IndexError: if no image file was found in ``image_dir``.
        """
        if not self.image_paths:
            # Without this guard the modulo below fails with an opaque
            # ZeroDivisionError.
            raise IndexError(f"no image files found in {self.image_dir!r}")

        image_path = self.image_paths[index % len(self.image_paths)]
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed right away instead of waiting for garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # For fine-tuning, we use the same caption for all images
        # maybe worth trying different captions for different images
        caption = f"a {self.subject}"
        caption = self.text_processor(caption)

        inp_image = self.inp_image_transform(image)
        tgt_image = self.tgt_image_transform(image)

        return {
            "inp_image": inp_image,
            "tgt_image": tgt_image,
            "caption": caption,
            "subject_text": self.subject,
        }
2 changes: 2 additions & 0 deletions lavis/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def load_checkpoint_from_config(self, cfg, **kwargs):
assert "Found load_finetuned is False, but pretrain_path is None."
self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)

def before_training(self, **kwargs):
    """Hook with an empty default body; any keyword arguments are ignored.

    NOTE(review): presumably called before training starts and meant to be
    overridden by subclasses -- confirm against the callers.
    """
    return None

def before_evaluation(self, **kwargs):
pass
Expand Down
Loading

0 comments on commit 9dd7d3d

Please sign in to comment.