-
Notifications
You must be signed in to change notification settings - Fork 8
/
biquads.py
107 lines (81 loc) · 3.09 KB
/
biquads.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
import argparse
import pathlib
import torchaudio
from tqdm import tqdm
from models.utils import get_logits2biquads
from harm_and_noise import loader
# Converter built once at import time and applied to the paired biquad logits
# inside get_biquads(). NOTE(review): the meaning of the "coef" option is
# defined in models.utils and is not visible here - confirm there.
logits2biquads = get_logits2biquads("coef")
@torch.no_grad()
def get_biquads(m, x):
    """Run the encoder of ``m`` on ``x`` and extract its filter parameters.

    The input is passed through ``m.feature_trsfm`` and ``m.encoder.backbone``
    to obtain ``logits``.  The last axis of ``logits`` is then sliced per
    entry of ``m.encoder.args_keys``, using the summed widths in
    ``m.encoder.split_sizes`` to locate each slice.

    Args:
        m: Model exposing ``feature_trsfm`` and ``encoder`` (with
            ``backbone``, ``split_sizes``, ``args_keys`` and ``trsfms``).
        x: Input accepted by ``m.feature_trsfm``
            (presumably an audio tensor - TODO confirm against the model).

    Returns:
        ``(voicing, harm_log_gain, harm_biquads, noise_log_gain,
        noise_biquads)``, followed by a sixth element - the first item of the
        transformed ``"harm_oscillator_params"`` - when the encoder produces
        a non-empty result for that key.

    Raises:
        ValueError: If the encoder allots no logits to one of the filter
            parameter keys, or a key is missing from ``args_keys``.
    """
    h = m.feature_trsfm(x)
    enc = m.encoder
    logits = enc.backbone(h)
    # Total width along the last logits axis claimed by each key.
    coarse_split_size = list(map(sum, enc.split_sizes))

    def fn(k: str, apply_trsfm=False):
        """Return the logits slice for key ``k``, or None if it is empty."""
        idx = enc.args_keys.index(k)
        start = sum(coarse_split_size[:idx])
        end = start + coarse_split_size[idx]
        if end - start == 0:
            return None
        if apply_trsfm:
            # Split into the fine-grained pieces and hand them to the
            # transform registered for this key.
            return enc.trsfms[idx](
                *torch.split(logits[..., start:end], enc.split_sizes[idx], dim=-1)
            )
        return logits[..., start:end].squeeze(-1)

    def bq_fn(k: str):
        """Split the slice for ``k`` into a log gain and biquad parameters."""
        logits_slice = fn(k)
        if logits_slice is None:
            # Previously surfaced as "'NoneType' object is not subscriptable".
            raise ValueError(f"encoder provides no logits for {k!r}")
        # First entry is the gain; the remainder is grouped in pairs, one
        # pair per biquad section.
        log_gain = logits_slice[..., 0]
        biquad_logits = logits_slice[..., 1:].reshape(*logits.shape[:-1], -1, 2)
        biquads = logits2biquads(biquad_logits)
        return log_gain, biquads

    harm_log_gain, harm_biquads = bq_fn("harm_filter_params")
    noise_log_gain, noise_biquads = bq_fn("noise_filter_params")
    harm_osc_params = fn("harm_oscillator_params", apply_trsfm=True)
    # NOTE(review): index 1 of the raw logits is taken as the voicing logit;
    # confirm this matches the encoder's output layout.
    voicing = logits[..., 1].sigmoid()

    outputs = (voicing, harm_log_gain, harm_biquads, noise_log_gain, noise_biquads)
    # fn() yields None when the oscillator has no logits; calling len() on
    # that raised TypeError, so check for None before measuring it.
    if harm_osc_params is not None and len(harm_osc_params):
        return (*outputs, harm_osc_params[0])
    return outputs
def main():
    """Extract biquad parameters for every WAV file under a directory.

    Command-line arguments:
        config, ckpt: Passed to ``loader`` to build the model.
        audio_dir: Directory searched recursively for ``*.wav`` files.
        outfile: Path the collected dictionary is written to with
            ``torch.save``.
        --duration: Chunk length in seconds (default 6.0); each file is
            split along its second axis into chunks of this length.

    The saved dictionary maps ``"<stem>_<chunk index>.<name>"`` to the
    corresponding value returned by ``get_biquads``.

    Raises:
        ValueError: If a file's sample rate is not 24000.
    """
    sample_rate = 24000  # the only rate this script accepts

    parser = argparse.ArgumentParser()
    parser.add_argument("config", type=str)
    parser.add_argument("ckpt", type=str)
    parser.add_argument("audio_dir", type=str)
    parser.add_argument("outfile", type=str)
    parser.add_argument("--duration", type=float, default=6.0)
    args = parser.parse_args()

    chunk_size = int(sample_rate * args.duration)
    if chunk_size <= 0:
        # A non-positive chunk size cannot be used to split the audio.
        parser.error("--duration must be long enough to hold at least one sample")

    model = loader(args.config, args.ckpt)
    model.eval()

    audio_dir = pathlib.Path(args.audio_dir)
    output_dict = {}
    # The file list is materialised so tqdm knows the total.
    for audio_path in tqdm(list(audio_dir.rglob("*.wav"))):
        x, sr = torchaudio.load(audio_path)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if sr != sample_rate:
            raise ValueError(
                f"{audio_path}: expected sample rate {sample_rate}, got {sr}"
            )
        # NOTE: keys use only the file stem, so files with the same name in
        # different sub-directories overwrite each other's entries.
        for i, x_chunk in enumerate(torch.split(x, chunk_size, dim=1)):
            (
                voicing,
                harm_log_gain,
                harm_biquads,
                noise_log_gain,
                noise_biquads,
                *extras,
            ) = get_biquads(model, x_chunk)
            prefix = f"{audio_path.stem}_{i}"
            update_dict = {
                f"{prefix}.harm_log_gain": harm_log_gain,
                f"{prefix}.harm_biquads": harm_biquads,
                f"{prefix}.noise_log_gain": noise_log_gain,
                f"{prefix}.noise_biquads": noise_biquads,
                f"{prefix}.voicing": voicing,
            }
            if extras:
                # Optional sixth value returned by get_biquads.
                update_dict[f"{prefix}.table_select_weight"] = extras[0]
            output_dict.update(update_dict)

    torch.save(output_dict, args.outfile)
# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()