-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_loader.py
51 lines (42 loc) · 1.82 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os

import torch
import torchvision.datasets as dsets
from torchvision import transforms
class Data_Loader():
    """Build a torch ``DataLoader`` for the LSUN or CelebA image datasets.

    Args:
        train: Stored as ``self.train``; not read by any method of this class.
        dataset: Which dataset ``loader()`` builds: ``'lsun'`` or ``'celeb'``.
        image_path: Root directory of the dataset on disk.
        image_size: Side length that images are resized to (square output).
        batch_size: Number of samples per batch.
        shuf: Whether the ``DataLoader`` shuffles samples.
        num_workers: Worker processes for the ``DataLoader`` (default 2,
            the value that was previously hard-coded).
    """

    def __init__(self, train, dataset, image_path, image_size, batch_size,
                 shuf=True, num_workers=2):
        self.dataset = dataset
        self.path = image_path
        self.imsize = image_size
        self.batch = batch_size
        self.shuf = shuf
        self.train = train
        self.num_workers = num_workers

    def transform(self, resize, totensor, normalize, centercrop, crop_size=160):
        """Compose the selected preprocessing steps into a single transform.

        Steps are applied in this fixed order: center crop, resize, tensor
        conversion, normalization.

        Args:
            resize: Resize to ``(self.imsize, self.imsize)``.
            totensor: Convert the image to a tensor.
            normalize: Normalize three channels with mean 0.5 and std 0.5,
                i.e. ``(x - 0.5) / 0.5`` per channel.
            centercrop: Center-crop to ``crop_size`` (before resizing).
            crop_size: Side length of the center crop (default 160).

        Returns:
            A ``transforms.Compose`` of the selected steps.
        """
        options = []
        if centercrop:
            options.append(transforms.CenterCrop(crop_size))
        if resize:
            options.append(transforms.Resize((self.imsize, self.imsize)))
        if totensor:
            options.append(transforms.ToTensor())
        if normalize:
            # Three-element mean/std tuples: assumes 3-channel (RGB) input.
            options.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        return transforms.Compose(options)

    def load_lsun(self, classes='church_outdoor_train'):
        """Return the LSUN dataset for one class/split name.

        Images are resized, converted to tensors and normalized (no crop).

        Args:
            classes: A single LSUN class/split name; it is wrapped in a list
                because ``dsets.LSUN`` takes a list of classes.
        """
        # Named ``transform`` (not ``transforms``) to avoid shadowing the
        # module-level torchvision import.
        transform = self.transform(resize=True, totensor=True,
                                   normalize=True, centercrop=False)
        return dsets.LSUN(self.path, classes=[classes], transform=transform)

    def load_celeb(self):
        """Return CelebA as an ``ImageFolder`` rooted at ``<image_path>/CelebA``.

        Images are center-cropped, resized, converted to tensors and
        normalized.
        """
        transform = self.transform(resize=True, totensor=True,
                                   normalize=True, centercrop=True)
        # os.path.join avoids producing the absolute path '/CelebA' when
        # image_path is empty, and handles trailing separators.
        return dsets.ImageFolder(os.path.join(self.path, 'CelebA'),
                                 transform=transform)

    def loader(self):
        """Return a ``DataLoader`` over the dataset selected by ``self.dataset``.

        The final incomplete batch is dropped (``drop_last=True``).

        Raises:
            ValueError: If ``self.dataset`` is neither ``'lsun'`` nor ``'celeb'``.
        """
        if self.dataset == 'lsun':
            dataset = self.load_lsun()
        elif self.dataset == 'celeb':
            dataset = self.load_celeb()
        else:
            # Without this branch an unknown name crashed later with an
            # UnboundLocalError on ``dataset``.
            raise ValueError(
                f"Unknown dataset {self.dataset!r}; expected 'lsun' or 'celeb'")
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.batch,
            shuffle=self.shuf,
            num_workers=self.num_workers,
            drop_last=True,
        )