test.py
import os
import argparse
import warnings

import mir_eval
import numpy as np
import pretty_midi as pm
import torch

from utils import logger
from btc_model import *
from utils.mir_eval_modules import audio_file_to_features, idx2chord, idx2voca_chord, get_audio_paths

warnings.filterwarnings('ignore')
logger.logging_verbosity(1)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Command-line arguments (model hyperparameters live in run_config.yaml)
parser = argparse.ArgumentParser()
parser.add_argument('--voca', default=True, type=lambda x: (str(x).lower() == 'true'))
parser.add_argument('--audio_dir', type=str, default='./test')
parser.add_argument('--save_dir', type=str, default='./test')
args = parser.parse_args()
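# Example invocation using the defaults above (adjust paths to your audio):
#   python test.py --voca True --audio_dir ./test --save_dir ./test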
config = HParams.load("run_config.yaml")
if args.voca is True:
    config.feature['large_voca'] = True
    config.model['num_chords'] = 170
    model_file = './test/btc_model_large_voca.pt'
    idx_to_chord = idx2voca_chord()
    logger.info("label type: large voca")
else:
    model_file = './test/btc_model.pt'
    idx_to_chord = idx2chord
    logger.info("label type: Major and minor")
model = BTC_model(config=config.model).to(device)
# Load model checkpoint (weights plus the feature mean/std used for normalization)
if os.path.isfile(model_file):
    checkpoint = torch.load(model_file, map_location=device)
    mean = checkpoint['mean']
    std = checkpoint['std']
    model.load_state_dict(checkpoint['model'])
    logger.info("restore model")
else:
    raise FileNotFoundError("model file not found: %s" % model_file)
# Collect audio files (wav and mp3)
audio_paths = get_audio_paths(args.audio_dir)
# Chord recognition and .lab file writing
for i, audio_path in enumerate(audio_paths):
    logger.info("======== %d of %d in progress ========" % (i + 1, len(audio_paths)))
    # Load audio and compute features
    feature, feature_per_second, song_length_second = audio_file_to_features(audio_path, config)
    logger.info("audio file loaded and feature computation success : %s" % audio_path)

    # Normalize features with the training statistics from the checkpoint
    feature = feature.T
    feature = (feature - mean) / std
    time_unit = feature_per_second

    # Pad the feature sequence to a multiple of the model timestep
    n_timestep = config.model['timestep']
    num_pad = n_timestep - (feature.shape[0] % n_timestep)
    feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
    num_instance = feature.shape[0] // n_timestep

    start_time = 0.0
    lines = []
    with torch.no_grad():
        model.eval()
        feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)
        for t in range(num_instance):
            self_attn_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
            prediction, _ = model.output_layer(self_attn_output)
            prediction = prediction.squeeze()
            # Merge consecutive frames with the same predicted chord into one labeled interval
            for i in range(n_timestep):
                if t == 0 and i == 0:
                    prev_chord = prediction[i].item()
                    continue
                if prediction[i].item() != prev_chord:
                    lines.append(
                        '%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
                    start_time = time_unit * (n_timestep * t + i)
                    prev_chord = prediction[i].item()
                # Stop at the first padded frame of the last chunk, flushing the open interval
                if t == num_instance - 1 and i + num_pad == n_timestep:
                    if start_time != time_unit * (n_timestep * t + i):
                        lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
                    break
    # Write the .lab file
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    save_path = os.path.join(args.save_dir, os.path.split(audio_path)[-1].replace('.mp3', '').replace('.wav', '') + '.lab')
    with open(save_path, 'w') as f:
        for line in lines:
            f.write(line)

    logger.info("label file saved : %s" % save_path)
    # Convert the .lab file to a MIDI file
    starts, ends, pitches = list(), list(), list()
    intervals, chords = mir_eval.io.load_labeled_intervals(save_path)

    # For each pitch class, merge consecutive intervals in which that note is active
    for p in range(12):
        for i, (interval, chord) in enumerate(zip(intervals, chords)):
            root_num, relative_bitmap, _ = mir_eval.chord.encode(chord)
            tmp_label = mir_eval.chord.rotate_bitmap_to_root(relative_bitmap, root_num)[p]
            if i == 0:
                start_time = interval[0]
                label = tmp_label
                continue
            if tmp_label != label:
                if label == 1.0:
                    starts.append(start_time)
                    ends.append(interval[0])
                    pitches.append(p + 48)  # pitch class p in the octave starting at C3 (MIDI 48)
                start_time = interval[0]
                label = tmp_label
            if i == (len(intervals) - 1):
                if label == 1.0:
                    starts.append(start_time)
                    ends.append(interval[1])
                    pitches.append(p + 48)
    midi = pm.PrettyMIDI()
    instrument = pm.Instrument(program=0)  # General MIDI program 0: Acoustic Grand Piano

    for start, end, pitch in zip(starts, ends, pitches):
        pm_note = pm.Note(velocity=120, pitch=pitch, start=start, end=end)
        instrument.notes.append(pm_note)

    midi.instruments.append(instrument)
    midi.write(save_path.replace('.lab', '.midi'))
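    # Optional sanity check (a sketch, not part of the original script): reload the
    # written MIDI with pretty_midi and inspect the note count and total duration.
    #   out_midi = pm.PrettyMIDI(save_path.replace('.lab', '.midi'))
    #   print(len(out_midi.instruments[0].notes), out_midi.get_end_time())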