forked from lllyasviel/ControlNet
-
Notifications
You must be signed in to change notification settings - Fork 2
/
tutorial_dataset.py
92 lines (66 loc) · 2.87 KB
/
tutorial_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import json # We need to use the JSON package to load the data, since the data is stored in JSON format
import cv2
import numpy as np
from torch.utils.data import Dataset
import os
from annotator.util import HWC3, resize_image, resize_image_square
# Base directory that the dataset classes in this file join with their
# 'data/...' relative paths when opening annotation files.
# NOTE(review): this is a hard-coded absolute path (looks like a cluster
# filesystem mount -- confirm); it must exist on the machine running the code.
ROOT = "/gpfs/space/projects/stud_ml_22/ControlNet-different-backbones/"
# Alternative: resolve the data paths relative to the current working directory.
# ROOT = "./"
class Fill50kDataset(Dataset):
    """Fill50k dataset yielding (target image, prompt, hint image) samples.

    Annotations are read from ``<root>/data/fill50k/prompt.json``, a JSON-lines
    file in which every non-blank line is an object with the keys ``source``,
    ``target`` and ``prompt``. ``source`` and ``target`` are image paths
    relative to ``<root>/data/fill50k/``.
    """

    def __init__(self, root=None):
        """Load every annotation record into memory.

        Args:
            root: Base directory containing ``data/fill50k``. When ``None``
                (the default) the module-level ``ROOT`` is used, so existing
                ``Fill50kDataset()`` call sites behave as before.

        Raises:
            OSError: If the annotation file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        self.root = ROOT if root is None else root
        self.image_dir = os.path.join(self.root, 'data/fill50k/')
        self.data = []
        with open(os.path.join(self.root, 'data/fill50k/prompt.json'), 'rt') as f:
            for line in f:
                # Blank lines (e.g. a trailing newline at the end of the file)
                # carry no record and would make json.loads raise.
                if not line.strip():
                    continue
                self.data.append(json.loads(line))

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.data)

    @staticmethod
    def _read_image(path):
        """Read an image with OpenCV, failing loudly if it cannot be loaded."""
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # surface the offending path here rather than failing later on None.
        if image is None:
            raise FileNotFoundError(
                'Image is missing or unreadable: {}'.format(path))
        return image

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            A dict with:
              * ``jpg``  -- target image, RGB, float32, scaled to [-1, 1];
              * ``txt``  -- the prompt string from the annotation record;
              * ``hint`` -- source image, RGB, float32, scaled to [0, 1].
            Both images are passed through ``resize_image(..., 256)`` first.

        Raises:
            FileNotFoundError: If the source or target image cannot be read.
        """
        item = self.data[idx]
        source_filename = item['source']
        target_filename = item['target']
        prompt = item['prompt']

        source = self._read_image(os.path.join(self.image_dir, source_filename))
        target = self._read_image(os.path.join(self.image_dir, target_filename))

        source = resize_image(source, 256)
        target = resize_image(target, 256)

        # OpenCV reads images in BGR order; convert to RGB.
        source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
        target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)

        # Normalize source images to [0, 1].
        source = source.astype(np.float32) / 255.0

        # Normalize target images to [-1, 1].
        target = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source)
class ValDataset(Dataset):
    """Validation dataset built from one or more ``val_data.json`` files.

    Each annotation file lives at ``<root>/data/<name>/val_data.json`` and is
    a JSON-lines file whose non-blank lines are objects with the keys
    ``source``, ``target``, ``prompt`` and ``ds_label``. Unlike the annotation
    file location, ``source`` and ``target`` are passed to OpenCV as-is (they
    are not joined with the root directory).
    """

    def __init__(self, dataset_name, root=None):
        """Load the annotation records for the requested validation set.

        Args:
            dataset_name: ``'fill50k'`` loads only the fill50k validation
                file; any other value loads the ``things``, ``laion-art`` and
                ``CC3M`` validation files, in that order.
            root: Base directory containing ``data/<name>``. When ``None``
                (the default) the module-level ``ROOT`` is used, so existing
                ``ValDataset(name)`` call sites behave as before.

        Raises:
            OSError: If an annotation file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        self.root = ROOT if root is None else root
        self.data = []
        if dataset_name == 'fill50k':
            names = ['fill50k']
        else:
            names = ['things', 'laion-art', 'CC3M']
        for ds in names:
            with open(os.path.join(self.root, 'data/', ds, 'val_data.json'), 'rt') as f:
                for line in f:
                    # Blank lines (e.g. a trailing newline) carry no record
                    # and would make json.loads raise.
                    if not line.strip():
                        continue
                    self.data.append(json.loads(line))

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.data)

    @staticmethod
    def _read_image(path):
        """Read an image with OpenCV, failing loudly if it cannot be loaded."""
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # surface the offending path here rather than failing later on None.
        if image is None:
            raise FileNotFoundError(
                'Image is missing or unreadable: {}'.format(path))
        return image

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            A dict with:
              * ``jpg``      -- target image, RGB, float32, scaled to [-1, 1];
              * ``txt``      -- the prompt, cut to its first 77 characters;
              * ``hint``     -- source image, RGB, float32, scaled to [0, 1];
              * ``ds_label`` -- the record's ``ds_label`` value, unchanged.
            Both images are passed through ``resize_image_square(..., 256)``.

        Raises:
            FileNotFoundError: If the source or target image cannot be read.
        """
        item = self.data[idx]
        source_path = item['source']
        target_path = item['target']
        # NOTE(review): this slices characters, not tokens -- presumably meant
        # to mirror a 77-token text-encoder limit; confirm the intent.
        prompt = item['prompt'][:77]
        ds_label = item['ds_label']

        source = self._read_image(source_path)
        target = self._read_image(target_path)

        source = resize_image_square(source, 256)
        target = resize_image_square(target, 256)

        # OpenCV reads images in BGR order; convert to RGB.
        source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
        target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)

        # Normalize source images to [0, 1].
        source = source.astype(np.float32) / 255.0

        # Normalize target images to [-1, 1].
        target = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source, ds_label=ds_label)