-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
134 lines (114 loc) · 5.32 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
'''
Datasets
'''
import os
import random
import torch
import torchaudio
import numpy as np
import pandas as pd
import soundfile as sf
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
class CNCeleb(Dataset):
    """CN-Celeb training dataset of fixed-length audio chunks.

    The (utt_path, speaker_int_label) pairs are loaded from an existing CSV
    list (``train_list``) if present; otherwise ``train_path`` is scanned for
    flac files via ``findAllUtt`` and the discovered list is written back to
    ``train_list`` for reuse.

    Args:
        train_list: path of the CSV data list (read if it exists, else created).
        train_path: root directory of the training audio.
        num_frames: number of frames per training chunk; each returned item is
            ``num_frames * 160 + 240`` samples long.
        **kwargs: ignored; accepted for config-dict convenience.
    """

    def __init__(self, train_list, train_path, num_frames, **kwargs):
        self.train_path = train_path
        self.num_frames = num_frames
        if os.path.exists(train_list):
            print('load {}'.format(train_list))
            df = pd.read_csv(train_list)
            speaker_int_labels = []
            utt_paths = []
            for (utt_path, label) in zip(df["utt_path"].values, df["speaker_int_label"].values):
                # Keep flac entries only (endswith instead of the fragile
                # utt_path[-4:] slice comparison of the original).
                if utt_path.endswith('flac'):
                    utt_paths.append(utt_path)
                    speaker_int_labels.append(label)
        else:
            utt_tuples, speakers = findAllUtt(train_path, extension='flac', speaker_level=1)
            utt_tuples = np.array(utt_tuples, dtype=str)
            utt_paths = utt_tuples.T[0]
            speaker_int_labels = utt_tuples.T[1].astype(int)
            # Map each integer label back to its speaker name for the CSV.
            speaker_str_labels = [speakers[i] for i in speaker_int_labels]
            csv_dict = {"speaker_str_label": speaker_str_labels,
                        "utt_path": utt_paths,
                        "speaker_int_label": speaker_int_labels
                        }
            df = pd.DataFrame(data=csv_dict)
            try:
                df.to_csv(train_list)
                print(f'Saved data list file at {train_list}')
            except OSError as err:
                print(f'Ran in an error while saving {train_list}: {err}')
        # Load data & labels
        self.data_list = utt_paths
        self.data_label = speaker_int_labels
        self.n_class = len(np.unique(self.data_label))
        print("find {} speakers".format(self.n_class))
        print("find {} utterance".format(len(self.data_list)))

    def __getitem__(self, index):
        """Return (FloatTensor chunk of num_frames*160+240 samples, int label)."""
        audio, sr = sf.read(self.data_list[index])
        length = self.num_frames * 160 + 240
        if audio.shape[0] <= length:
            # Wrap-pad short utterances up to the target length.
            shortage = length - audio.shape[0]
            audio = np.pad(audio, (0, shortage), 'wrap')
        # Random crop; start is 0 when the clip was just padded to `length`.
        start_frame = np.int64(random.random() * (audio.shape[0] - length))
        audio = audio[start_frame:start_frame + length]
        # Fix: the original wrapped the chunk with np.stack([audio], axis=0)
        # and immediately indexed audio[0] again — both steps were redundant.
        return torch.FloatTensor(audio), self.data_label[index]

    def __len__(self):
        """Number of utterances in the list."""
        return len(self.data_list)
def findAllUtt(dirName, extension='flac', speaker_level=1):
    """Recursively collect utterance files under ``dirName``.

    The speaker name is the first path component below ``dirName`` and is
    mapped to a dense integer label in discovery order.

    Args:
        dirName: root directory to scan (symlinks are followed).
        extension: filename suffix that identifies an utterance file.
        speaker_level: unused in the current implementation; kept for
            backward compatibility with existing callers.

    Returns:
        (utt_tuples, outSpeakers): ``utt_tuples`` is a list of
        (utt_path, speaker_int_label) pairs; ``outSpeakers[i]`` is the
        speaker string label for integer label ``i``.
    """
    if not dirName.endswith(os.sep):
        dirName += os.sep
    prefixSize = len(dirName)
    speaker_dict = {}   # speaker_str_label -> speaker_int_label
    utt_tuples = []     # (utt_path, speaker_int_label)
    print("finding {}, Waiting...".format(extension))
    for root, dirs, filenames in tqdm(os.walk(dirName, followlinks=True)):
        filtered_files = [f for f in filenames if f.endswith(extension)]
        if not filtered_files:
            continue
        speaker_str_label = root[prefixSize:].split(os.sep)[0]
        # setdefault assigns the next dense label on first sight, replacing
        # the original's redundant `not in speaker_dict.keys()` lookup plus
        # manual insert.
        speaker_int_label = speaker_dict.setdefault(speaker_str_label, len(speaker_dict))
        for filename in filtered_files:
            utt_tuples.append((os.path.join(root, filename), speaker_int_label))
    # Invert the mapping: integer label -> speaker string.
    outSpeakers = [None] * len(speaker_dict)
    for key, index in speaker_dict.items():
        outSpeakers[index] = key
    print("find {} speakers".format(len(outSpeakers)))
    print("find {} utterance".format(len(utt_tuples)))
    # return [(utt_path:speaker_int_label), ...], [id00012, id00031, ...]
    return utt_tuples, outSpeakers
def create_cnceleb_trails(cnceleb_root, trails_path, extension='flac'):
    """Write the CN-Celeb evaluation trial file.

    Reads ``eval/lists/enroll.lst`` and ``eval/lists/trials.lst`` under
    ``cnceleb_root`` and writes one line per trial to ``trails_path`` in the
    form ``label enroll_path test_path``, with audio extensions rewritten to
    ``extension``.
    """
    enroll_lst_path = os.path.join(cnceleb_root, "eval/lists/enroll.lst")
    raw_trl_path = os.path.join(cnceleb_root, "eval/lists/trials.lst")
    # Enrollment speaker id -> enrollment audio path (extension swapped).
    spk2wav_mapping = {}
    for entry in tqdm(np.loadtxt(enroll_lst_path, str),
                      desc='speaker mapping', mininterval=2, ncols=50):
        stem = os.path.splitext(entry[1])[0]
        spk2wav_mapping[entry[0]] = stem + '.{}'.format(extension)
    # Load all trials before opening the output file, so a missing trials
    # list does not leave a truncated/empty output behind.
    trial_rows = np.loadtxt(raw_trl_path, str)
    with open(trails_path, "w") as f:
        for entry in tqdm(trial_rows, desc='handle trials', mininterval=2, ncols=50):
            enroll_path = os.path.join(cnceleb_root, "eval", spk2wav_mapping[entry[0]])
            test_stem = os.path.splitext(os.path.join(cnceleb_root, "eval", entry[1]))[0]
            test_path = test_stem + '.{}'.format(extension)
            f.write("{} {} {}\n".format(entry[2], enroll_path, test_path))
if __name__ == "__main__":
    # Smoke test: build the CN-Celeb dataset and inspect a single batch.
    cn1_root = '/home2/database/sre/CN-Celeb-2022/task1/cn_1'
    cn2_dev = '/home2/database/sre/CN-Celeb-2022/task1/cn_2/data'
    train_list_path = 'data/cn2_train_list.csv'
    dataset = CNCeleb(train_list_path, cn1_root, 200)
    loader = DataLoader(dataset, batch_size=5, shuffle=True)
    data, label = next(iter(loader))
    print('data:', data.shape, data)
    print('label', label.shape, label)