From aa88e4c0a452ca763c934b8765a2343d4f7f698e Mon Sep 17 00:00:00 2001
From: EricFuma
Date: Sat, 7 Oct 2023 15:27:29 +0800
Subject: [PATCH] add AM streaming inference

---
 .pre-commit-config.yaml                 |   4 +-
 kantts/bin/infer_sambert.py             | 178 ++++++++++++++---
 kantts/models/sambert/fsmn.py           |  39 ++++
 kantts/models/sambert/kantts_sambert.py | 245 +++++++++++++++++++++---
 4 files changed, 407 insertions(+), 59 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 72d488d..cd298a5 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,8 +4,8 @@ repos:
   hooks:
   - id: black
     additional_dependencies: ['click==8.0.4']
-- repo: https://gitlab.com/pycqa/flake8
+- repo: https://github.com/PyCQA/flake8
   rev: 3.8.4
   hooks:
   - id: flake8
-    args: ['--max-line-length=120', '--extend-ignore=E203']
+    args: ['--max-line-length=120', '--extend-ignore=E203']
\ No newline at end of file
diff --git a/kantts/bin/infer_sambert.py b/kantts/bin/infer_sambert.py
index d0fa04d..ff6607d 100644
--- a/kantts/bin/infer_sambert.py
+++ b/kantts/bin/infer_sambert.py
@@ -23,8 +23,10 @@
 )
 
 
-def denorm_f0(mel, f0_threshold=30, uv_threshold=0.6, norm_type='mean_std', f0_feature=None):
-    if norm_type == 'mean_std':
+def denorm_f0(
+    mel, f0_threshold=30, uv_threshold=0.6, norm_type="mean_std", f0_feature=None
+):
+    if norm_type == "mean_std":
         f0_mvn = f0_feature
 
         f0 = mel[:, -2]
@@ -38,7 +40,7 @@ def denorm_f0(mel, f0_threshold=30, uv_threshold=0.6, norm_type='mean_std', f0_f
         mel[:, -2] = f0
         mel[:, -1] = uv
-    else: # global
+    else:  # global
         f0_global_max_min = f0_feature
 
         f0 = mel[:, -2]
@@ -55,9 +57,9 @@ def denorm_f0(mel, f0_threshold=30, uv_threshold=0.6, norm_type='mean_std', f0_f
     return mel
 
 
-def am_synthesis(symbol_seq, fsnet, ling_unit, device, se=None):
-    inputs_feat_lst = ling_unit.encode_symbol_sequence(symbol_seq) 
+def sync_preprocess(symbol_seq, ling_unit, device, se=None):
+    inputs_feat_lst = ling_unit.encode_symbol_sequence(symbol_seq)
     inputs_feat_index = 0
     if ling_unit.using_byte():
         inputs_byte_index = (
@@ -113,14 +115,20 @@
         torch.zeros(1).to(device).long() + inputs_emo.size(1) - 1
     )  # minus 1 for "~"
 
+    return inputs_ling, inputs_emo, inputs_spk, inputs_len
+
+
+# non-streaming inference
+def am_synthesis(symbol_seq, fsnet, ling_unit, device, se=None):
+    inputs_ling, inputs_emo, inputs_spk, inputs_len = sync_preprocess(
+        symbol_seq, ling_unit, device, se
+    )
     res = fsnet(
         inputs_ling[:, :-1, :],
         inputs_emo[:, :-1],
         inputs_spk,
         inputs_len,
     )
-
     x_band_width = res["x_band_width"]
     h_band_width = res["h_band_width"]
     # enc_slf_attn_lst = res["enc_slf_attn_lst"]
@@ -153,7 +161,79 @@ def am_synthesis(symbol_seq, fsnet, ling_unit, device, se=None):
     )
 
 
-def am_infer(sentence, ckpt, output_dir, se_file=None, config=None):
+# streaming inference
+def am_chunk_synthesis(
+    symbol_seq, fsnet, ling_unit, device, se=None, mel_chunk_size=48
+):
+    inputs_ling, inputs_emo, inputs_spk, inputs_len = sync_preprocess(
+        symbol_seq, ling_unit, device, se
+    )
+    complete_length = 0
+    for chunk_id, res in enumerate(
+        fsnet.chunk_forward(
+            inputs_ling[:, :-1, :],
+            inputs_emo[:, :-1],
+            inputs_spk,
+            inputs_len,
+            mel_chunk_size=mel_chunk_size,
+        )
+    ):
+        if chunk_id == 0:
+            x_band_width = res["x_band_width"]
+            h_band_width = res["h_band_width"]
+            LR_length_rounded = res["LR_length_rounded"]
+            log_duration_predictions = res["log_duration_predictions"]
+            pitch_predictions = res["pitch_predictions"]
+            energy_predictions = res["energy_predictions"]
+            valid_length = int(LR_length_rounded[0].item())
+            duration_predictions = (
+                (torch.exp(log_duration_predictions) - 1 + 0.5)
+                .long()
+                .squeeze()
+                .cpu()
+                .numpy()
+            )
+            pitch_predictions = pitch_predictions.squeeze().cpu().numpy()
+            energy_predictions = energy_predictions.squeeze().cpu().numpy()
+            logging.info(
+                "x_band_width:{}, h_band_width: {}".format(x_band_width, h_band_width)
+            )
+        else:
+            duration_predictions, pitch_predictions, energy_predictions = (
+                None,
+                None,
+                None,
+            )
+
+        dec_output_chunk = res["dec_output_chunk"]
+        postnet_output_chunk = res["postnet_output_chunk"]
+
+        if complete_length + dec_output_chunk.size(1) > valid_length:
+            useless_length = complete_length + dec_output_chunk.size(1) - valid_length
+            dec_output_chunk = dec_output_chunk[0, :-useless_length, :]
+            postnet_output_chunk = postnet_output_chunk[0, :-useless_length, :]
+        dec_output_chunk = dec_output_chunk.squeeze().cpu().numpy()
+        postnet_output_chunk = postnet_output_chunk.squeeze().cpu().numpy()
+
+        yield (
+            dec_output_chunk,
+            postnet_output_chunk,
+            duration_predictions,
+            pitch_predictions,
+            energy_predictions,
+        )
+        complete_length += dec_output_chunk.shape[0]
+
+
+def am_infer(
+    sentence,
+    ckpt,
+    output_dir,
+    se_file=None,
+    config=None,
+    inference_type="non-streaming",
+    mel_chunk_size=48,
+):
     if not torch.cuda.is_available():
         device = torch.device("cpu")
     else:
@@ -174,36 +254,38 @@ def am_infer(sentence, ckpt, output_dir, se_file=None, config=None):
     ling_unit_size = ling_unit.get_unit_size()
 
     config["Model"]["KanTtsSAMBERT"]["params"].update(ling_unit_size)
-    se_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("SE", False) 
+    se_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("SE", False)
     se = np.load(se_file) if se_enable else None
 
     # nsf
-    nsf_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("NSF", False) 
+    nsf_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("NSF", False)
     if nsf_enable:
-        nsf_norm_type = config["Model"]["KanTtsSAMBERT"]["params"].get("nsf_norm_type", "mean_std")
+        nsf_norm_type = config["Model"]["KanTtsSAMBERT"]["params"].get(
+            "nsf_norm_type", "mean_std"
+        )
         if nsf_norm_type == "mean_std":
             f0_mvn_file = os.path.join(
                 os.path.dirname(os.path.dirname(ckpt)), "mvn.npy"
             )
-            f0_feature = np.load(f0_mvn_file) 
-        else: # global
-            nsf_f0_global_minimum = config["Model"]["KanTtsSAMBERT"]["params"].get("nsf_f0_global_minimum", 30.0)
-            nsf_f0_global_maximum = config["Model"]["KanTtsSAMBERT"]["params"].get("nsf_f0_global_maximum", 730.0)
+            f0_feature = np.load(f0_mvn_file)
+        else:  # global
+            nsf_f0_global_minimum = config["Model"]["KanTtsSAMBERT"]["params"].get(
+                "nsf_f0_global_minimum", 30.0
+            )
+            nsf_f0_global_maximum = config["Model"]["KanTtsSAMBERT"]["params"].get(
+                "nsf_f0_global_maximum", 730.0
+            )
             f0_feature = [nsf_f0_global_maximum, nsf_f0_global_minimum]
-
     model, _, _ = model_builder(config, device)
-
     fsnet = model["KanTtsSAMBERT"]
-
     logging.info("Loading checkpoint: {}".format(ckpt))
-    state_dict = torch.load(ckpt)
+    state_dict = torch.load(ckpt, map_location=device)
     fsnet.load_state_dict(state_dict["model"], strict=False)
 
     results_dir = os.path.join(output_dir, "feat")
     os.makedirs(results_dir, exist_ok=True)
 
     fsnet.eval()
-
     with open(sentence, encoding="utf-8") as f:
         for line in f:
             line = line.strip().split("\t")
@@ -214,12 +296,46 @@ def am_infer(sentence, ckpt, output_dir, se_file=None, config=None):
             energy_path = "%s/%s_energy.txt" % (results_dir, line[0])
 
             with torch.no_grad():
-                mel, mel_post, dur, f0, energy = am_synthesis(
-                    line[1], fsnet, ling_unit, device, se=se
-                )
-
+                if inference_type == "non-streaming":
+                    mel, mel_post, dur, f0, energy = am_synthesis(
+                        line[1],
+                        fsnet,
+                        ling_unit,
+                        device,
+                        se=se,
+                    )
+                else:
+                    mel_post = None
+                    for chunk_id, (
+                        mel_chunk,
+                        mel_post_chunk,
+                        dur_chunk,
+                        f0_chunk,
+                        energy_chunk,
+                    ) in enumerate(
+                        am_chunk_synthesis(
+                            line[1],
+                            fsnet,
+                            ling_unit,
+                            device,
+                            se=se,
+                            mel_chunk_size=mel_chunk_size,
+                        )
+                    ):
+                        if chunk_id == 0:
+                            dur, f0, energy = dur_chunk, f0_chunk, energy_chunk
+                        if mel_post is None:
+                            mel_post = mel_post_chunk
+                        else:
+                            mel_post = np.concatenate(
+                                [mel_post, mel_post_chunk], axis=0
+                            )
+
+            # FIXME:
             if nsf_enable:
-                mel_post = denorm_f0(mel_post, norm_type=nsf_norm_type, f0_feature=f0_feature)
+                mel_post = denorm_f0(
+                    mel_post, norm_type=nsf_norm_type, f0_feature=f0_feature
+                )
 
             np.save(mel_path, mel_post)
             np.savetxt(dur_path, dur)
@@ -233,7 +349,17 @@ def am_infer(sentence, ckpt, output_dir, se_file=None, config=None):
     parser.add_argument("--output_dir", type=str, required=True)
     parser.add_argument("--ckpt", type=str, required=True)
    parser.add_argument("--se_file", type=str, required=False)
+    parser.add_argument(
+        "--inference_type", type=str, required=False, default="non-streaming"
+    )
+    parser.add_argument("--mel_chunk_size", type=int, required=False, default=24)
 
     args = parser.parse_args()
-
-    am_infer(args.sentence, args.ckpt, args.output_dir, args.se_file)
+    am_infer(
+        args.sentence,
+        args.ckpt,
+        args.output_dir,
+        args.se_file,
+        inference_type=args.inference_type,
+        mel_chunk_size=args.mel_chunk_size,
+    )
\ No newline at end of file
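
[Usage sketch, not part of the patch: one way a caller might consume the new
streaming entry point chunk by chunk, e.g. to feed an incremental vocoder.
`text`, `vocoder` and `play` are illustrative placeholders; `fsnet`,
`ling_unit`, `device` and `se` are set up as in am_infer above. Note that
duration, pitch and energy are yielded with the first chunk only and are None
afterwards.]

    import numpy as np

    mel_post = None
    for mel_chunk, mel_post_chunk, dur, f0, energy in am_chunk_synthesis(
        text, fsnet, ling_unit, device, se=se, mel_chunk_size=24
    ):
        wave_chunk = vocoder(mel_post_chunk)  # placeholder: any mel-to-wave module
        play(wave_chunk)  # placeholder: audio sink
        mel_post = (
            mel_post_chunk
            if mel_post is None
            else np.concatenate([mel_post, mel_post_chunk], axis=0)
        )

[From the command line the same path is selected with --inference_type
streaming (any value other than "non-streaming" takes the streaming branch),
plus an optional --mel_chunk_size, which defaults to 24 frames per chunk.]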
diff --git a/kantts/models/sambert/fsmn.py b/kantts/models/sambert/fsmn.py
index be72d89..84d6f71 100644
--- a/kantts/models/sambert/fsmn.py
+++ b/kantts/models/sambert/fsmn.py
@@ -1,6 +1,7 @@
 """
 FSMN Pytorch Version
 """
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -71,6 +72,27 @@ def forward(self, input, mask=None):
 
         return output
 
+    def chunk_forward(self, input, mask=None, left_cache=None, rp=None):
+        if mask is not None:
+            input = input.masked_fill(mask.unsqueeze(-1), 0)
+        # pad lp frames on the left unless a cache already provides them
+        if left_cache is None:
+            x = F.pad(input, (0, 0, self.lp, 0, 0, 0), mode="constant", value=0.0)
+        else:
+            x = torch.cat([left_cache, input], dim=1)
+        # update the left cache: keep the lp frames preceding the right context
+        rp = rp if rp is not None else 0
+        new_left_cache = x[:, -rp - self.lp : x.size(1) - rp]
+        x = F.pad(x, (0, 0, 0, self.rp, 0, 0), mode="constant", value=0.0)
+        output = (
+            self.conv_dw(x.contiguous().transpose(1, 2)).contiguous().transpose(1, 2)
+        )
+        output += input[:, : output.size(1)]
+        output = self.dropout(output)
+        if mask is not None:
+            output = output.masked_fill(mask.unsqueeze(-1), 0)
+        return output, new_left_cache
+
 
 class FsmnEncoderV2(nn.Module):
     def __init__(
@@ -122,3 +144,20 @@ def forward(self, input, mask=None):
             x = memory
 
         return x
+
+    def chunk_forward(self, input, mask=None, left_caches=None, right_pad_size=None):
+        if left_caches is None:
+            left_caches = [None] * len(self.memory_block_lst)
+        x = F.dropout(input, self.dropout, self.training)
+        new_left_caches = []
+        for ffn, memory_block, left_cache in zip(
+            self.ffn_lst, self.memory_block_lst, left_caches
+        ):
+            context = ffn(x)
+            memory, left_cache = memory_block.chunk_forward(
+                context, mask, left_cache, right_pad_size
+            )
+            new_left_caches.append(left_cache)
+            memory = F.dropout(memory, self.dropout, self.training)
+            if memory.size(-1) == x.size(-1):
+                memory += x
+            x = memory
+        return x, new_left_caches
\ No newline at end of file
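
[Sanity-check sketch, not part of the patch: in eval mode, chunked FSMN
inference should reproduce the offline forward, since every memory block keeps
the last lp frames of its input as a left cache and each window carries
right_pad_size frames of lookahead. `encoder` (an FsmnEncoderV2), `x` of shape
(batch, time, dim) and the chunk size C are assumptions of this sketch; the
windowing mirrors the caller in kantts_sambert.py below.]

    import torch

    encoder.eval()
    with torch.no_grad():
        full = encoder(x)  # offline reference
        n_blocks = len(encoder.memory_block_lst)
        R = encoder.memory_block_lst[0].rp * n_blocks  # right receptive field
        C = 24  # frames emitted per chunk
        caches = [None] * n_blocks
        outs, done = [], 0
        while done < x.size(1):
            window = x[:, done : done + C + R]  # C new frames + R lookahead
            out, caches = encoder.chunk_forward(
                window, None, caches, max(0, window.size(1) - C)
            )
            outs.append(out[:, :C])  # keep only the fully-contexted frames
            done += C
        streamed = torch.cat(outs, dim=1)
        assert torch.allclose(full, streamed, atol=1e-5)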
diff --git a/kantts/models/sambert/kantts_sambert.py b/kantts/models/sambert/kantts_sambert.py
index 91ce5b9..484995a 100644
--- a/kantts/models/sambert/kantts_sambert.py
+++ b/kantts/models/sambert/kantts_sambert.py
@@ -611,6 +611,36 @@ def forward(
 
         return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list
 
+    def chunk_forward(
+        self,
+        memory,
+        x_band_width,
+        h_band_width,
+        mask=None,
+        return_attns=False,
+    ):
+        # streaming decoding; used only at inference time
+        batch_size = memory.size(0)
+        go_frame = torch.zeros((batch_size, 1, self.d_mel)).to(memory.device)
+        self.mel_dec.reset_state()
+        input = go_frame
+        for step in range(memory.size(1)):
+            (
+                dec_output_step,
+                dec_pnca_attn_x_step,
+                dec_pnca_attn_h_step,
+            ) = self.mel_dec.infer(
+                step,
+                input,
+                memory,
+                x_band_width,
+                h_band_width,
+                mask=mask,
+                return_attns=return_attns,
+            )
+            input = dec_output_step[:, :, -self.d_mel :]
+            yield dec_output_step, dec_pnca_attn_x_step, dec_pnca_attn_h_step
+
 
 class PostNet(nn.Module):
     def __init__(self, config):
@@ -716,7 +746,9 @@ def __init__(self, config):
         self.text_encoder = TextFftEncoder(config)
         self.se_enable = config.get("SE", False)
         if not self.se_enable:
-            self.spk_tokenizer = nn.Embedding(config["speaker"], config["speaker_units"])
+            self.spk_tokenizer = nn.Embedding(
+                config["speaker"], config["speaker_units"]
+            )
         self.emo_tokenizer = nn.Embedding(config["emotion"], config["emotion_units"])
         self.variance_adaptor = VarianceAdaptor(config)
         self.mel_decoder = MelPNCADecoder(config)
@@ -859,7 +891,7 @@ def insert_fp(
         )
         return text_hid, inputs_emotion, inputs_speaker, inter_lengths
 
-    def forward(
+    def pre_forward(
         self,
         inputs_ling,
         inputs_emotion,
@@ -874,7 +906,6 @@ def forward(
         fp_label=None,
     ):
         batch_size = inputs_ling.size(0)
-
         is_training = mel_targets is not None
 
         input_masks = get_mask_from_lengths(input_lengths, max_len=inputs_ling.size(1))
@@ -925,7 +956,9 @@ def forward(
                 duration_targets[i, input_lengths[i]] = padding
 
         emo_hid = self.emo_tokenizer(inputs_emotion)
-        spk_hid = inputs_speaker if self.se_enable else self.spk_tokenizer(inputs_speaker)
+        spk_hid = (
+            inputs_speaker if self.se_enable else self.spk_tokenizer(inputs_speaker)
+        )
 
         inter_masks = get_mask_from_lengths(inter_lengths, max_len=text_hid.size(1))
 
@@ -991,6 +1024,59 @@ def forward(
             + 0.5
         )
         h_band_width = x_band_width
+        res = {
+            "x_band_width": x_band_width,
+            "h_band_width": h_band_width,
+            "enc_slf_attn_lst": enc_sla_attn_lst,
+            "LR_length_rounded": LR_length_rounded,
+            "log_duration_predictions": log_duration_predictions,
+            "pitch_predictions": pitch_predictions,
+            "energy_predictions": energy_predictions,
+            "duration_targets": duration_targets,
+            "pitch_targets": pitch_targets,
+            "energy_targets": energy_targets,
+            "fp_predictions": FP_p,
+            "valid_inter_lengths": inter_lengths,
+            "LR_text_outputs": LR_text_outputs,
+            "LR_emo_outputs": LR_emo_outputs,
+            "LR_spk_outputs": LR_spk_outputs,
+        }
+        if self.MAS and is_training:
+            res["attn_soft"] = attn_soft
+            res["attn_hard"] = attn_hard
+            res["attn_logprob"] = attn_logprob
+        return memory, lfr_masks, output_masks, res
+
+    def forward(
+        self,
+        inputs_ling,
+        inputs_emotion,
+        inputs_speaker,
+        input_lengths,
+        output_lengths=None,
+        mel_targets=None,
+        duration_targets=None,
+        pitch_targets=None,
+        energy_targets=None,
+        attn_priors=None,
+        fp_label=None,
+    ):
+        batch_size = inputs_ling.size(0)
+        memory, lfr_masks, output_masks, res = self.pre_forward(
+            inputs_ling,
+            inputs_emotion,
+            inputs_speaker,
+            input_lengths,
+            output_lengths=output_lengths,
+            mel_targets=mel_targets,
+            duration_targets=duration_targets,
+            pitch_targets=pitch_targets,
+            energy_targets=energy_targets,
+            attn_priors=attn_priors,
+            fp_label=fp_label,
+        )
+        x_band_width = res["x_band_width"]
+        h_band_width = res["h_band_width"]
 
         dec_outputs, pnca_x_attn_lst, pnca_h_attn_lst = self.mel_decoder(
             memory,
@@ -1013,35 +1099,132 @@ def forward(
         if output_masks is not None:
             postnet_outputs = postnet_outputs.masked_fill(output_masks.unsqueeze(-1), 0)
 
-        res = {
-            "x_band_width": x_band_width,
-            "h_band_width": h_band_width,
-            "enc_slf_attn_lst": enc_sla_attn_lst,
-            "pnca_x_attn_lst": pnca_x_attn_lst,
-            "pnca_h_attn_lst": pnca_h_attn_lst,
-            "dec_outputs": dec_outputs,
-            "postnet_outputs": postnet_outputs,
-            "LR_length_rounded": LR_length_rounded,
-            "log_duration_predictions": log_duration_predictions,
-            "pitch_predictions": pitch_predictions,
-            "energy_predictions": energy_predictions,
-            "duration_targets": duration_targets,
-            "pitch_targets": pitch_targets,
-            "energy_targets": energy_targets,
-            "fp_predictions": FP_p,
-            "valid_inter_lengths": inter_lengths,
-        }
+        res["pnca_x_attn_lst"] = pnca_x_attn_lst
+        res["pnca_h_attn_lst"] = pnca_h_attn_lst
+        res["dec_outputs"] = dec_outputs
+        res["postnet_outputs"] = postnet_outputs
+        return res
 
-        res["LR_text_outputs"] = LR_text_outputs
-        res["LR_emo_outputs"] = LR_emo_outputs
-        res["LR_spk_outputs"] = LR_spk_outputs
+    # used only for inference
+    def chunk_forward(
+        self,
+        inputs_ling,
+        inputs_emotion,
+        inputs_speaker,
+        input_lengths,
+        output_lengths=None,
+        attn_priors=None,
+        fp_label=None,
+        mel_chunk_size=48,
+    ):
+        batch_size = inputs_ling.size(0)
+        memory, lfr_masks, output_masks, res = self.pre_forward(
+            inputs_ling,
+            inputs_emotion,
+            inputs_speaker,
+            input_lengths,
+            output_lengths=output_lengths,
+            attn_priors=attn_priors,
+            fp_label=fp_label,
+        )
+        x_band_width = res["x_band_width"]
+        h_band_width = res["h_band_width"]
+
+        # mel_decoder
+        complete_length = 0
+        dec_outputs = torch.empty(
+            batch_size,
+            0,
+            self.mel_decoder.d_mel,
+            dtype=memory.dtype,
+            device=memory.device,
+        )
+        total_length = memory.size(1) * 3  # the decoder emits 3 mel frames per step
 
-        if self.MAS and is_training:
-            res["attn_soft"] = attn_soft
-            res["attn_hard"] = attn_hard
-            res["attn_logprob"] = attn_logprob
+        # initialize the PostNet LSTM state and the per-block FSMN left caches
+        h0 = torch.zeros(
+            [batch_size, 1, self.mel_postnet.lstm_units], device=memory.device
+        )
+        c0 = torch.zeros(
+            [batch_size, 1, self.mel_postnet.lstm_units], device=memory.device
+        )
+        left_memory_caches = [
+            None for _ in range(len(self.mel_postnet.fsmn.memory_block_lst))
+        ]
 
-        return res
+        # size of the right-side receptive field: 12
+        receptive_field_size = self.mel_postnet.fsmn.memory_block_lst[0].rp * len(
+            self.mel_postnet.fsmn.memory_block_lst
+        )
+        for (
+            dec_output_step,
+            dec_pnca_attn_x_step,
+            dec_pnca_attn_h_step,
+        ) in self.mel_decoder.chunk_forward(
+            memory, x_band_width, h_band_width, mask=lfr_masks, return_attns=True
+        ):
+            dec_output_step = dec_output_step.contiguous().view(
+                batch_size, -1, self.mel_decoder.d_mel
+            )
+            if output_masks is not None:
+                dec_output_step = dec_output_step.masked_fill(
+                    output_masks.unsqueeze(-1)[
+                        :,
+                        dec_outputs.size(1) : dec_outputs.size(1)
+                        + dec_output_step.size(1),
+                        :,
+                    ],
+                    0,  # NOQA
+                )
+            dec_outputs = torch.concat([dec_outputs, dec_output_step], dim=1)
+            # mel postnet
+            target_length = complete_length + mel_chunk_size + receptive_field_size
+            if (
+                dec_outputs.size(1) >= target_length
+                or dec_outputs.size(1) == total_length
+            ):
+
+                # run the PostNet over the buffered window, carrying the caches
+                dec_output_chunk = dec_outputs[:, complete_length:target_length]
+                (
+                    postnet_fsmn_output,
+                    left_memory_caches,
+                ) = self.mel_postnet.fsmn.chunk_forward(
+                    dec_output_chunk,
+                    output_masks[:, complete_length:target_length]
+                    if output_masks is not None
+                    else None,
+                    left_memory_caches,
+                    max(0, dec_output_chunk.size(1) - mel_chunk_size),
+                )
+                postnet_fsmn_output = postnet_fsmn_output[:, :mel_chunk_size]
+                postnet_lstm_output, (h0, c0) = self.mel_postnet.lstm(
+                    postnet_fsmn_output, (h0, c0)
+                )
+                mel_residual_output = self.mel_postnet.fc(postnet_lstm_output)
+                mel_residual_output = (
+                    mel_residual_output
+                    + dec_outputs[:, complete_length : complete_length + mel_chunk_size]
+                )
+                if output_masks is not None:
+                    postnet_output = mel_residual_output.masked_fill(
+                        output_masks[
+                            :, complete_length : complete_length + mel_chunk_size
+                        ].unsqueeze(-1),
+                        0,
+                    )
+                else:
+                    postnet_output = mel_residual_output
+                res["postnet_output_chunk"] = postnet_output
+                res["dec_output_chunk"] = dec_outputs[
+                    :, complete_length : complete_length + mel_chunk_size
+                ]
+                """ TODO
+                res["pnca_x_attn_lst"] = pnca_x_attn_lst
+                res["pnca_h_attn_lst"] = pnca_h_attn_lst
+                """
+                yield res
+                complete_length += postnet_output.size(1)
+                # Only the first chunk carries the full prediction set; subsequent
+                # chunks return only `postnet_output_chunk` and `dec_output_chunk`
+                # to avoid transmitting redundant information.
+                res = dict()
 
 
 class KanTtsTextsyBERT(nn.Module):
@@ -1065,4 +1248,4 @@ def forward(self, inputs_ling, input_lengths):
         res["logits"] = logits
         res["enc_slf_attn_lst"] = enc_sla_attn_lst
 
-        return res
+        return res
\ No newline at end of file
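
[Illustrative sketch, not part of the patch: KanTtsSAMBERT.chunk_forward is a
generator, and only its first yielded dict carries the full prediction set
(band widths, LR_length_rounded, duration/pitch/energy predictions); later
dicts hold only "dec_output_chunk" and "postnet_output_chunk". Inputs are
assumed to be prepared as by sync_preprocess above; `consume` is a
placeholder. mel_chunk_size trades latency for overhead: smaller chunks
deliver the first mel frames sooner but run the PostNet more often.]

    fsnet.eval()
    with torch.no_grad():
        for i, res in enumerate(
            fsnet.chunk_forward(
                inputs_ling[:, :-1, :],
                inputs_emo[:, :-1],
                inputs_spk,
                inputs_len,
                mel_chunk_size=48,
            )
        ):
            if i == 0:
                # frame-level predictions arrive with the very first chunk
                log_durations = res["log_duration_predictions"]
                pitch = res["pitch_predictions"]
            mel_chunk = res["postnet_output_chunk"]  # (B, <=mel_chunk_size, n_mels)
            consume(mel_chunk)  # placeholder for a vocoder / buffer / socket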