forked from dessa-oss/DeepFake-Detection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_loader.py
128 lines (93 loc) · 5.02 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from pathlib import Path
import numpy as np
import dlib
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader
from utils import load_and_preprocess_image
def collate_fn(batch):
    """Merge dataset samples into one batch, dropping samples without an image.

    A sample is kept only when its ``'image'`` entry is not ``None``. The kept
    images and labels are each combined with ``torch.stack``; the filenames of
    the kept samples are returned as a plain list in the same order.

    Args:
        batch: Sequence of dicts with ``'image'``, ``'label'`` and
            ``'filename'`` keys.

    Returns:
        Dict with keys ``'image'`` (stacked tensor), ``'label'`` (stacked
        tensor) and ``'filename'`` (list).
    """
    # Filter once, then project each field out of the surviving samples.
    usable = [sample for sample in batch if sample['image'] is not None]
    images = [sample['image'] for sample in usable]
    labels = [sample['label'] for sample in usable]
    names = [sample['filename'] for sample in usable]
    return {
        'image': torch.stack(images),
        'label': torch.stack(labels),
        'filename': names,
    }
def get_transforms():
    """Build the image pipelines used for training and for validation.

    The training pipeline applies random geometric and colour augmentation,
    converts to a tensor, applies random erasing and finally normalises.
    The validation pipeline only converts to a tensor and normalises.

    Returns:
        Tuple ``(train_transforms, val_transforms)`` of ``transforms.Compose``
        objects.
    """
    # Per-channel statistics; presumably those of the pre-trained backbone's
    # training data (ImageNet) -- confirm against the model being used.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    augmentation_steps = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees=40, scale=(.9, 1.1), shear=0),
        transforms.RandomPerspective(distortion_scale=0.2),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.ToTensor(),
        # RandomErasing is placed after ToTensor and before Normalize.
        transforms.RandomErasing(scale=(0.02, 0.16), ratio=(0.3, 1.6)),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    evaluation_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]

    train_pipeline = transforms.Compose(augmentation_steps)
    val_pipeline = transforms.Compose(evaluation_steps)
    return train_pipeline, val_pipeline
class FFDataset(Dataset):
    """Dataset of frame images with an on-disk cache of preprocessed arrays.

    Each item is loaded from ``<filepath>/processed_<stem>.npy`` when that file
    exists (and ``recompute`` is false); otherwise it is produced by
    ``load_and_preprocess_image`` and written to that cache file.

    Args:
        filenames: Sequence of frame paths (``pathlib.Path`` or ``str``).
        filepath: Directory in which preprocessed arrays are cached; created
            (including missing parents) if it does not exist.
        transform: Callable applied to the PIL image before it is returned.
        output_image_size: Size forwarded to ``load_and_preprocess_image``.
        recompute: When true, ignore existing cache files and overwrite them.
    """

    def __init__(self, filenames, filepath, transform, output_image_size=224, recompute=False):
        self.filenames = filenames
        self.transform = transform
        self.image_size = output_image_size
        self.recompute = recompute
        self.cached_path = Path(filepath)
        # parents=True: a nested cache directory must not fail on first use
        # just because an intermediate directory is missing.
        self.cached_path.mkdir(parents=True, exist_ok=True)
        self.face_detector = dlib.get_frontal_face_detector()

    def __len__(self):
        """Return the number of frame paths in the dataset."""
        return len(self.filenames)

    @staticmethod
    def _label_for(path):
        """Return 1 when some component of *path* is exactly ``'fake'``, else 0."""
        # Path.parts splits on the platform's separator; splitting the string
        # on '/' would never match on platforms that use a different one.
        return 1 if 'fake' in Path(path).parts else 0

    def __getitem__(self, idx: int):
        """Return the sample at *idx*.

        Returns:
            Dict with keys ``'image'``, ``'label'`` and ``'filename'``. When the
            preprocessed array is empty, ``'image'`` and ``'label'`` are
            ``None``; otherwise ``'image'`` is the transformed image and
            ``'label'`` is a tensor holding 1 (fake) or 0.
        """
        path = Path(self.filenames[idx])
        filename = str(path)
        label = self._label_for(path)

        # NOTE(review): the cache key is only the file stem, so frames with the
        # same name in different directories would share one cache entry --
        # confirm that stems are unique across the dataset.
        preprocessed_filename = self.cached_path / f'processed_{path.stem}.npy'

        if preprocessed_filename.is_file() and not self.recompute:
            image = np.load(preprocessed_filename)
        else:
            image = load_and_preprocess_image(filename, self.image_size, self.face_detector)
            if image is None:
                # Presumably no usable face was found (confirm against
                # utils.load_and_preprocess_image). An empty array is cached
                # so the failed preprocessing is not repeated.
                image = []
            np.save(preprocessed_filename, image)

        if len(image) == 0:
            # Signal "no image" to the collate function, which drops such samples.
            return {'image': None, 'label': None, 'filename': filename}

        image = Image.fromarray(image)
        image = self.transform(image)
        label = torch.tensor(label)
        return {'image': image, 'label': label, 'filename': filename}
def create_dataloaders(params):
    """Create the training, validation and display data loaders.

    Args:
        params: Mapping providing ``'train_data'`` (name used in the training
            directory ``/datasets/<train_data>_deepfake``), ``'batch_size'``
            and ``'sample_ratio'``.

    Returns:
        Tuple ``(train_dl, val_base_dl, val_augment_dl, display_dl_iter)``.
        The first three are data loaders; the last is an iterator over a
        loader with batch size 32 that covers both validation directories.
    """
    train_tf, val_tf = get_transforms()

    train_root = f'/datasets/{params["train_data"]}_deepfake'
    train_dl = _create_dataloader(
        train_root,
        mode='train',
        batch_size=params['batch_size'],
        transformations=train_tf,
        sample_ratio=params['sample_ratio'],
    )

    def make_val_loader(paths, batch_size):
        # All validation-style loaders share mode, transforms and sample ratio.
        return _create_dataloader(
            paths,
            mode='val',
            batch_size=batch_size,
            transformations=val_tf,
            sample_ratio=params['sample_ratio'],
        )

    # The loaders are built in a fixed order: train, base, augment, display.
    val_base_dl = make_val_loader('/datasets/base_deepfake/val', params['batch_size'])
    val_augment_dl = make_val_loader('/datasets/augment_deepfake/val', params['batch_size'])

    display_roots = ['/datasets/base_deepfake/val', '/datasets/augment_deepfake/val']
    display_dl_iter = iter(make_val_loader(display_roots, 32))

    return train_dl, val_base_dl, val_augment_dl, display_dl_iter
def _create_dataloader(file_paths, mode, batch_size, transformations, sample_ratio, num_workers=80):
    """Build a shuffled ``DataLoader`` over the PNG frames under the given roots.

    For every root directory, frames are collected from ``real/frames/`` and
    ``fake/frames/``. The combined list is shuffled with ``np.random.shuffle``;
    in ``'train'`` mode only the first ``int(sample_ratio * n)`` files are kept.

    Args:
        file_paths: A root directory, or a list of root directories.
        mode: ``'train'`` enables sub-sampling by ``sample_ratio``; the value
            is also printed alongside the dataset size.
        batch_size: Batch size passed to the ``DataLoader``.
        transformations: Transform passed to ``FFDataset``.
        sample_ratio: Fraction of the shuffled files kept in ``'train'`` mode.
        num_workers: Number of ``DataLoader`` worker processes.

    Returns:
        The constructed ``DataLoader``.

    Raises:
        AssertionError: If no frame files were found under any root.
    """
    if not isinstance(file_paths, list):
        file_paths = [file_paths]

    filenames = []
    for file_path in file_paths:
        data_path = Path(file_path)
        real_frame_filenames = _find_filenames(data_path / 'real/frames/', '*.png')
        fake_frame_filenames = _find_filenames(data_path / 'fake/frames/', '*.png')
        filenames += real_frame_filenames
        filenames += fake_frame_filenames

    # Raised explicitly (rather than via `assert`) so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if not filenames:
        raise AssertionError(f'filenames are empty {filenames}')

    np.random.shuffle(filenames)
    if mode == 'train':
        filenames = filenames[:int(sample_ratio * len(filenames))]

    ds = FFDataset(filenames, filepath='/datasets/precomputed/', transform=transformations, recompute=False)
    dl = DataLoader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=True, collate_fn=collate_fn)
    print(f"{mode} data: {len(ds)}")
    return dl
def _find_filenames(file_dir_path, file_pattern):
    """Return a list of the paths in *file_dir_path* matching *file_pattern*.

    Args:
        file_dir_path: ``pathlib.Path`` of the directory to search.
        file_pattern: Glob pattern handed to ``Path.glob``.

    Returns:
        List of matching paths, in the order ``Path.glob`` yields them.
    """
    matches = file_dir_path.glob(file_pattern)
    return list(matches)