Add AM streaming inference #84

Open · wants to merge 1 commit into base: main
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -4,8 +4,8 @@ repos:
hooks:
- id: black
additional_dependencies: ['click==8.0.4']
-  - repo: https://gitlab.com/pycqa/flake8
+  - repo: https://github.com/PyCQA/flake8
rev: 3.8.4
hooks:
- id: flake8
args: ['--max-line-length=120', '--extend-ignore=E203']
178 changes: 152 additions & 26 deletions kantts/bin/infer_sambert.py
@@ -23,8 +23,10 @@
)


-def denorm_f0(mel, f0_threshold=30, uv_threshold=0.6, norm_type='mean_std', f0_feature=None):
-    if norm_type == 'mean_std':
+def denorm_f0(
+    mel, f0_threshold=30, uv_threshold=0.6, norm_type="mean_std", f0_feature=None
+):
+    if norm_type == "mean_std":
f0_mvn = f0_feature

f0 = mel[:, -2]
@@ -38,7 +40,7 @@ def denorm_f0(mel, f0_threshold=30, uv_threshold=0.6, norm_type='mean_std', f0_feature=None):

mel[:, -2] = f0
mel[:, -1] = uv
-    else: # global
+    else:  # global
f0_global_max_min = f0_feature

f0 = mel[:, -2]
@@ -55,9 +57,9 @@ def denorm_f0(mel, f0_threshold=30, uv_threshold=0.6, norm_type='mean_std', f0_feature=None):

return mel
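For intuition: denorm_f0 maps the model's normalized F0 channel (mel[:, -2], with the voiced/unvoiced flag in mel[:, -1]) back to Hz. A minimal sketch of the two branches, assuming f0_feature holds (mean, std) statistics for "mean_std" (e.g. from mvn.npy) and [f0_max, f0_min] for "global"; the real code above additionally thresholds F0/UV to rebuild the voicing mask:

    import numpy as np

    def denorm_f0_sketch(f0_norm, norm_type="mean_std", f0_feature=None):
        if norm_type == "mean_std":
            # assumed layout: f0_feature[0] is the mean, f0_feature[1] the std
            f0_mean, f0_std = f0_feature[0], f0_feature[1]
            return f0_norm * f0_std + f0_mean
        else:  # "global": f0_feature is [f0_max, f0_min]
            f0_max, f0_min = f0_feature
            return f0_norm * (f0_max - f0_min) + f0_min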

-def am_synthesis(symbol_seq, fsnet, ling_unit, device, se=None):
-    inputs_feat_lst = ling_unit.encode_symbol_sequence(symbol_seq)

+def sync_preprocess(symbol_seq, ling_unit, device, se=None):
+    inputs_feat_lst = ling_unit.encode_symbol_sequence(symbol_seq)
inputs_feat_index = 0
if ling_unit.using_byte():
inputs_byte_index = (
@@ -113,14 +115,20 @@ def am_synthesis(symbol_seq, fsnet, ling_unit, device, se=None):
torch.zeros(1).to(device).long() + inputs_emo.size(1) - 1
) # minus 1 for "~"

return inputs_ling, inputs_emo, inputs_spk, inputs_len


# non-streaming inference
def am_synthesis(symbol_seq, fsnet, ling_unit, device, se=None):
inputs_ling, inputs_emo, inputs_spk, inputs_len = sync_preprocess(
symbol_seq, ling_unit, device, se
)
res = fsnet(
inputs_ling[:, :-1, :],
inputs_emo[:, :-1],
inputs_spk,
inputs_len,
)

x_band_width = res["x_band_width"]
h_band_width = res["h_band_width"]
# enc_slf_attn_lst = res["enc_slf_attn_lst"]
@@ -153,7 +161,79 @@ def am_synthesis(symbol_seq, fsnet, ling_unit, device, se=None):
)


-def am_infer(sentence, ckpt, output_dir, se_file=None, config=None):
# streaming inference
def am_chunk_synthesis(
symbol_seq, fsnet, ling_unit, device, se=None, mel_chunk_size=48
):
inputs_ling, inputs_emo, inputs_spk, inputs_len = sync_preprocess(
symbol_seq, ling_unit, device, se
)
complete_length = 0
for chunk_id, res in enumerate(
fsnet.chunk_forward(
inputs_ling[:, :-1, :],
inputs_emo[:, :-1],
inputs_spk,
inputs_len,
mel_chunk_size=mel_chunk_size,
)
):
if chunk_id == 0:
x_band_width = res["x_band_width"]
h_band_width = res["h_band_width"]
LR_length_rounded = res["LR_length_rounded"]
log_duration_predictions = res["log_duration_predictions"]
pitch_predictions = res["pitch_predictions"]
energy_predictions = res["energy_predictions"]
valid_length = int(LR_length_rounded[0].item())
duration_predictions = (
(torch.exp(log_duration_predictions) - 1 + 0.5)
.long()
.squeeze()
.cpu()
.numpy()
)
pitch_predictions = pitch_predictions.squeeze().cpu().numpy()
energy_predictions = energy_predictions.squeeze().cpu().numpy()
logging.info(
"x_band_width:{}, h_band_width: {}".format(x_band_width, h_band_width)
)
else:
duration_predictions, pitch_predictions, energy_predictions = (
None,
None,
None,
)

dec_output_chunk = res["dec_output_chunk"]
postnet_output_chunk = res["postnet_output_chunk"]

if complete_length + dec_output_chunk.size(1) > valid_length:
useless_length = complete_length + dec_output_chunk.size(1) - valid_length
dec_output_chunk = dec_output_chunk[0, :-useless_length, :]
postnet_output_chunk = postnet_output_chunk[0, :-useless_length, :]
dec_output_chunk = dec_output_chunk.squeeze().cpu().numpy()
postnet_output_chunk = postnet_output_chunk.squeeze().cpu().numpy()

yield (
dec_output_chunk,
postnet_output_chunk,
duration_predictions,
pitch_predictions,
energy_predictions,
)
complete_length += dec_output_chunk.shape[0]
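Because am_chunk_synthesis is a generator, a caller can start vocoding as soon as the first mel chunk arrives instead of waiting for the whole utterance. A minimal consumer sketch (vocode_chunk is a hypothetical downstream call, not part of this PR):

    for chunk_id, (mel_chunk, mel_post_chunk, dur, f0, energy) in enumerate(
        am_chunk_synthesis(symbol_seq, fsnet, ling_unit, device, se=se)
    ):
        if chunk_id == 0:
            # duration/pitch/energy predictions only arrive with chunk 0
            durations, pitches, energies = dur, f0, energy
        vocode_chunk(mel_post_chunk)  # hypothetical: push this chunk to a vocoder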


def am_infer(
sentence,
ckpt,
output_dir,
se_file=None,
config=None,
inference_type="non-streaming",
mel_chunk_size=48,
):
if not torch.cuda.is_available():
device = torch.device("cpu")
else:
@@ -174,36 +254,38 @@ def am_infer(sentence, ckpt, output_dir, se_file=None, config=None):
ling_unit_size = ling_unit.get_unit_size()
config["Model"]["KanTtsSAMBERT"]["params"].update(ling_unit_size)

    se_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("SE", False)
se = np.load(se_file) if se_enable else None

# nsf
    nsf_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("NSF", False)
if nsf_enable:
-        nsf_norm_type = config["Model"]["KanTtsSAMBERT"]["params"].get("nsf_norm_type", "mean_std")
+        nsf_norm_type = config["Model"]["KanTtsSAMBERT"]["params"].get(
+            "nsf_norm_type", "mean_std"
+        )
if nsf_norm_type == "mean_std":
f0_mvn_file = os.path.join(
os.path.dirname(os.path.dirname(ckpt)), "mvn.npy"
)
-            f0_feature = np.load(f0_mvn_file)
-        else: # global
-            nsf_f0_global_minimum = config["Model"]["KanTtsSAMBERT"]["params"].get("nsf_f0_global_minimum", 30.0)
-            nsf_f0_global_maximum = config["Model"]["KanTtsSAMBERT"]["params"].get("nsf_f0_global_maximum", 730.0)
+            f0_feature = np.load(f0_mvn_file)
+        else:  # global
+            nsf_f0_global_minimum = config["Model"]["KanTtsSAMBERT"]["params"].get(
+                "nsf_f0_global_minimum", 30.0
+            )
+            nsf_f0_global_maximum = config["Model"]["KanTtsSAMBERT"]["params"].get(
+                "nsf_f0_global_maximum", 730.0
+            )
f0_feature = [nsf_f0_global_maximum, nsf_f0_global_minimum]

model, _, _ = model_builder(config, device)

fsnet = model["KanTtsSAMBERT"]

logging.info("Loading checkpoint: {}".format(ckpt))
-    state_dict = torch.load(ckpt)
+    state_dict = torch.load(ckpt, map_location=device)

fsnet.load_state_dict(state_dict["model"], strict=False)

results_dir = os.path.join(output_dir, "feat")
os.makedirs(results_dir, exist_ok=True)
fsnet.eval()

with open(sentence, encoding="utf-8") as f:
for line in f:
line = line.strip().split("\t")
@@ -214,12 +296,46 @@ def am_infer(sentence, ckpt, output_dir, se_file=None, config=None):
energy_path = "%s/%s_energy.txt" % (results_dir, line[0])

with torch.no_grad():
-                mel, mel_post, dur, f0, energy = am_synthesis(
-                    line[1], fsnet, ling_unit, device, se=se
-                )

if inference_type == "non-streaming":
mel, mel_post, dur, f0, energy = am_synthesis(
line[1],
fsnet,
ling_unit,
device,
se=se,
)
else:
mel_post = None
for chunk_id, (
mel_chunk,
mel_post_chunk,
dur_chunk,
f0_chunk,
energy_chunk,
) in enumerate(
am_chunk_synthesis(
line[1],
fsnet,
ling_unit,
device,
se=se,
mel_chunk_size=mel_chunk_size,
)
):
if chunk_id == 0:
dur, f0, energy = dur_chunk, f0_chunk, energy_chunk
if mel_post is None:
mel_post = mel_post_chunk
else:
mel_post = np.concatenate(
[mel_post, mel_post_chunk], axis=0
)

# FIXME:
if nsf_enable:
-                mel_post = denorm_f0(mel_post, norm_type=nsf_norm_type, f0_feature=f0_feature)
+                mel_post = denorm_f0(
+                    mel_post, norm_type=nsf_norm_type, f0_feature=f0_feature
+                )

np.save(mel_path, mel_post)
np.savetxt(dur_path, dur)
@@ -233,7 +349,17 @@ def am_infer(sentence, ckpt, output_dir, se_file=None, config=None):
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--ckpt", type=str, required=True)
parser.add_argument("--se_file", type=str, required=False)
parser.add_argument(
"--inference_type", type=str, required=False, default="non-streaming"
)
parser.add_argument("--mel_chunk_size", type=int, required=False, default=24)

args = parser.parse_args()

-    am_infer(args.sentence, args.ckpt, args.output_dir, args.se_file)
+    am_infer(
+        args.sentence,
+        args.ckpt,
+        args.output_dir,
+        args.se_file,
+        inference_type=args.inference_type,
+        mel_chunk_size=args.mel_chunk_size,
+    )
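With the new arguments wired through argparse, streaming inference can be selected from the command line (--inference_type streaming --mel_chunk_size 24) or by calling am_infer directly; a sketch with illustrative paths:

    am_infer(
        "test/sentences.txt",   # one "utt_id<TAB>symbol_sequence" line per utterance
        "ckpt/checkpoint.pth",
        "output",
        inference_type="streaming",
        mel_chunk_size=24,
    )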
39 changes: 39 additions & 0 deletions kantts/models/sambert/fsmn.py
@@ -1,6 +1,7 @@
"""
FSMN Pytorch Version
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

@@ -71,6 +72,27 @@ def forward(self, input, mask=None):

return output

def chunk_forward(self, input, mask=None, left_cache=None, rp=None):
    if mask is not None:
        input = input.masked_fill(mask.unsqueeze(-1), 0)
    # left padding: zeros on the first chunk, cached frames afterwards
    if left_cache is None:
        x = F.pad(input, (0, 0, self.lp, 0, 0, 0), mode="constant", value=0.0)
    else:
        x = torch.cat([left_cache, input], dim=1)
    # update the left-context cache handed to the next chunk
    if rp is None:
        rp = self.rp  # fall back to the block's own lookahead (avoids an unbound cache below)
    new_left_cache = x[:, -rp - self.lp : x.size(1) - rp]  # lp frames, ending rp before the tail
    x = F.pad(x, (0, 0, 0, self.rp, 0, 0), mode="constant", value=0.0)
    output = (
        self.conv_dw(x.contiguous().transpose(1, 2)).contiguous().transpose(1, 2)
    )
    output += input[:, : output.size(1)]
    output = self.dropout(output)
    if mask is not None:
        output = output.masked_fill(mask.unsqueeze(-1), 0)
    return output, new_left_cache
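To see why carrying lp frames suffices, here is the caching idea in isolation with a plain depthwise Conv1d and no lookahead (rp = 0); for rp > 0 the code above additionally delays the cache window by rp frames and zero-pads each chunk's tail, so positions near a chunk boundary see zeros where the full-sequence pass would see true future frames:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    lp, dim = 3, 8
    conv = torch.nn.Conv1d(dim, dim, lp + 1, groups=dim, bias=False)

    def run(x):  # (B, T + lp, D) pre-padded input -> (B, T, D)
        return conv(x.transpose(1, 2)).transpose(1, 2)

    x = torch.randn(1, 20, dim)

    # full-sequence pass: left-pad once
    full = run(F.pad(x, (0, 0, lp, 0)))

    # chunked pass: the last lp frames become the next chunk's left context
    outs, cache = [], None
    for chunk in x.split(5, dim=1):
        y = F.pad(chunk, (0, 0, lp, 0)) if cache is None else torch.cat([cache, chunk], dim=1)
        cache = y[:, -lp:]
        outs.append(run(y))

    assert torch.allclose(full, torch.cat(outs, dim=1), atol=1e-6)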


class FsmnEncoderV2(nn.Module):
def __init__(
@@ -122,3 +144,20 @@ def forward(self, input, mask=None):
x = memory

return x

def chunk_forward(self, input, mask=None, left_caches=None, right_pad_size=None):
x = F.dropout(input, self.dropout, self.training)
new_left_caches = []
for ffn, memory_block, left_cache in zip(
self.ffn_lst, self.memory_block_lst, left_caches
):
context = ffn(x)
memory, left_cache = memory_block.chunk_forward(
context, mask, left_cache, right_pad_size
)
new_left_caches.append(left_cache)
memory = F.dropout(memory, self.dropout, self.training)
if memory.size(-1) == x.size(-1):
memory += x
x = memory
return x, new_left_caches
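A caller keeps one left cache per FSMN layer and threads the list through successive chunks; a minimal driving loop (encoder, the chunk tensors, and rp matching the model's lookahead are assumed to be built elsewhere):

    left_caches = [None] * len(encoder.memory_block_lst)
    outputs = []
    for chunk in chunks:  # (B, T_chunk, D) tensors
        out, left_caches = encoder.chunk_forward(
            chunk, mask=None, left_caches=left_caches, right_pad_size=rp
        )
        outputs.append(out)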