forked from junyanz/pytorch-CycleGAN-and-pix2pix
-
Notifications
You must be signed in to change notification settings - Fork 0
/
colorization_dataset.py
68 lines (56 loc) · 2.65 KB
/
colorization_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from skimage import color # require skimage
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
class ColorizationDataset(BaseDataset):
"""This dataset class can load a set of natural images in RGB, and convert RGB format into (L, ab) pairs in Lab color space.
This dataset is required by pix2pix-based colorization model ('--model colorization')
"""
@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new dataset-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
By default, the number of channels for input image is 1 (L) and
the number of channels for output image is 2 (ab). The direction is from A to B
"""
parser.set_defaults(input_nc=1, output_nc=2, direction='AtoB')
return parser
def __init__(self, opt):
"""Initialize this dataset class.
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
BaseDataset.__init__(self, opt)
self.dir = os.path.join(opt.dataroot, opt.phase)
self.AB_paths = sorted(make_dataset(self.dir, opt.max_dataset_size))
assert(opt.input_nc == 1 and opt.output_nc == 2 and opt.direction == 'AtoB')
self.transform = get_transform(self.opt, convert=False)
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index - - a random integer for data indexing
Returns a dictionary that contains A, B, A_paths and B_paths
A (tensor) - - the L channel of an image
B (tensor) - - the ab channels of the same image
A_paths (str) - - image paths
B_paths (str) - - image paths (same as A_paths)
"""
path = self.AB_paths[index]
im = Image.open(path).convert('RGB')
im = self.transform(im)
im = np.array(im)
lab = color.rgb2lab(im).astype(np.float32)
lab_t = transforms.ToTensor()(lab)
A = lab_t[[0], ...] / 50.0 - 1.0
B = lab_t[[1, 2], ...] / 110.0
return {'A': A, 'B': B, 'A_paths': path, 'B_paths': path}
def __len__(self):
"""Return the total number of images in the dataset."""
return len(self.AB_paths)