From 3535cfe3b0e39db91c71d6269d96f5495d4d300a Mon Sep 17 00:00:00 2001
From: chasonjiang <1440499136@qq.com>
Date: Sun, 10 Mar 2024 21:37:28 +0800
Subject: [PATCH] Add batched VITS inference (GPT_SoVITS/TTS_infer_pack/TTS.py);
 fix some bugs in GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py,
 GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py,
 GPT_SoVITS/inference_webui.py and GPT_SoVITS/module/models.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 GPT_SoVITS/TTS_infer_pack/TTS.py              | 108 ++++++++++++------
 GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py |  20 +++-
 .../text_segmentation_method.py               |  28 ++++-
 GPT_SoVITS/inference_webui.py                 |  20 +++-
 GPT_SoVITS/module/models.py                   |  50 ++++++++
 5 files changed, 182 insertions(+), 44 deletions(-)

diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 7cfaf46b3..c11103469 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -1,3 +1,4 @@
+import math
 import os, sys
 now_dir = os.getcwd()
 sys.path.append(now_dir)
@@ -366,6 +367,7 @@ def to_batch(self, data:list, prompt_data:dict=None, batch_size:int=5, threshold
         for batch_idx, index_list in enumerate(batch_index_list):
             item_list = [data[idx] for idx in index_list]
             phones_list = []
+            phones_len_list = []
             # bert_features_list = []
             all_phones_list = []
             all_phones_len_list = []
@@ -375,24 +377,26 @@ def to_batch(self, data:list, prompt_data:dict=None, batch_size:int=5, threshold
             phones_max_len = 0
             for item in item_list:
                 if prompt_data is not None:
-                    all_bert_features = torch.cat([prompt_data["bert_features"].clone(), item["bert_features"]], 1)
+                    all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)
                     all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
                     phones = torch.LongTensor(item["phones"])
                     # norm_text = prompt_data["norm_text"]+item["norm_text"]
                 else:
                     all_bert_features = item["bert_features"]
                     phones = torch.LongTensor(item["phones"])
-                    all_phones = phones.clone()
+                    all_phones = phones
                     # norm_text = item["norm_text"]

                 bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
                 phones_max_len = max(phones_max_len, phones.shape[-1])

                 phones_list.append(phones)
+                phones_len_list.append(phones.shape[-1])
                 all_phones_list.append(all_phones)
                 all_phones_len_list.append(all_phones.shape[-1])
                 all_bert_features_list.append(all_bert_features)
                 norm_text_batch.append(item["norm_text"])

+            phones_batch = phones_list
             max_len = max(bert_max_len, phones_max_len)
             # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
@@ -406,6 +410,7 @@ def to_batch(self, data:list, prompt_data:dict=None, batch_size:int=5, threshold

             batch = {
                 "phones": phones_batch,
+                "phones_len": torch.LongTensor(phones_len_list),
                 "all_phones": all_phones_batch,
                 "all_phones_len": torch.LongTensor(all_phones_len_list),
                 "all_bert_features": all_bert_features_batch,
@@ -491,7 +496,11 @@ def run(self, inputs:dict):

         if split_bucket:
             print(i18n("分桶处理模式已开启"))
-
+
+        # if vits_batched_inference:
+        #     print(i18n("VITS批量推理模式已开启"))
+        # else:
+        #     print(i18n("VITS单句推理模式已开启"))

         no_prompt_text = False
         if prompt_text in [None, ""]:
@@ -529,7 +538,6 @@ def run(self, inputs:dict):

         ###### text preprocessing ########
         data = 
self.text_preprocessor.preprocess(text, text_lang, text_split_method) - audio = [] t1 = ttime() data, batch_index_list = self.to_batch(data, prompt_data=self.prompt_cache if not no_prompt_text else None, @@ -538,24 +546,23 @@ def run(self, inputs:dict): split_bucket=split_bucket ) t2 = ttime() - zero_wav = torch.zeros( - int(self.configs.sampling_rate * 0.3), - dtype=torch.float16 if self.configs.is_half else torch.float32, - device=self.configs.device - ) ###### inference ###### t_34 = 0.0 t_45 = 0.0 + audio = [] for item in data: t3 = ttime() batch_phones = item["phones"] + batch_phones_len = item["phones_len"] all_phoneme_ids = item["all_phones"] all_phoneme_lens = item["all_phones_len"] all_bert_features = item["all_bert_features"] norm_text = item["norm_text"] + # batch_phones = batch_phones.to(self.configs.device) + batch_phones_len = batch_phones_len.to(self.configs.device) all_phoneme_ids = all_phoneme_ids.to(self.configs.device) all_phoneme_lens = all_phoneme_lens.to(self.configs.device) all_bert_features = all_bert_features.to(self.configs.device) @@ -566,7 +573,7 @@ def run(self, inputs:dict): if no_prompt_text : prompt = None else: - prompt = self.prompt_cache["prompt_semantic"].clone().repeat(all_phoneme_ids.shape[0], 1).to(self.configs.device) + prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device) with torch.no_grad(): pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( @@ -583,41 +590,52 @@ def run(self, inputs:dict): t4 = ttime() t_34 += t4 - t3 - refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].clone().to(self.configs.device) + refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].to(self.configs.device) if self.configs.is_half: refer_audio_spepc = refer_audio_spepc.half() - - ## 直接对batch进行decode 生成的音频会有问题 + + + batch_audio_fragment = [] + + # ## vits并行推理 method 1 # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device) # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0) # batch_phones = batch_phones.to(self.configs.device) - # batch_audio_fragment =(self.vits_model.decode( - # pred_semantic, batch_phones, refer_audio_spepc - # ).detach()[:, 0, :]) - # max_audio=torch.abs(batch_audio_fragment).max()#简单防止16bit爆音 - # if max_audio>1: batch_audio_fragment/=max_audio - # batch_audio_fragment = batch_audio_fragment.cpu().numpy() + # batch_audio_fragment = (self.vits_model.batched_decode( + # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc + # )) - ## 改成串行处理 - batch_audio_fragment = [] - for i, idx in enumerate(idx_list): - phones = batch_phones[i].unsqueeze(0).to(self.configs.device) - _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次 - audio_fragment =(self.vits_model.decode( - _pred_semantic, phones, refer_audio_spepc - ).detach()[0, 0, :]) - max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音 - if max_audio>1: audio_fragment/=max_audio - audio_fragment = torch.cat([audio_fragment, zero_wav], dim=0) - batch_audio_fragment.append( - audio_fragment.cpu().numpy() - ) ###试试重建不带上prompt部分 + # ## vits并行推理 method 2 + pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + upsample_rate = math.prod(self.vits_model.upsample_rates) + audio_frag_idx = 
[pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))] + audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))] + all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) + _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) + _batch_audio_fragment = (self.vits_model.decode( + all_pred_semantic, _batch_phones,refer_audio_spepc + ).detach()[0, 0, :]) + audio_frag_end_idx.insert(0, 0) + batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))] + + + # ## vits串行推理 + # for i, idx in enumerate(idx_list): + # phones = batch_phones[i].unsqueeze(0).to(self.configs.device) + # _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次 + # audio_fragment =(self.vits_model.decode( + # _pred_semantic, phones, refer_audio_spepc + # ).detach()[0, 0, :]) + # batch_audio_fragment.append( + # audio_fragment + # ) ###试试重建不带上prompt部分 t5 = ttime() t_45 += t5 - t4 if return_fragment: print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4)) - yield self.audio_postprocess(batch_audio_fragment, + yield self.audio_postprocess([batch_audio_fragment], self.configs.sampling_rate, batch_index_list, speed_factor, @@ -626,7 +644,8 @@ def run(self, inputs:dict): audio.append(batch_audio_fragment) if self.stop_flag: - yield self.configs.sampling_rate, (zero_wav.cpu().numpy()).astype(np.int16) + yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3), + dtype=np.int16) return if not return_fragment: @@ -640,15 +659,30 @@ def run(self, inputs:dict): def audio_postprocess(self, - audio:np.ndarray, + audio:List[torch.Tensor], sr:int, batch_index_list:list=None, speed_factor:float=1.0, split_bucket:bool=True)->tuple[int, np.ndarray]: + zero_wav = torch.zeros( + int(self.configs.sampling_rate * 0.3), + dtype=torch.float16 if self.configs.is_half else torch.float32, + device=self.configs.device + ) + + for i, batch in enumerate(audio): + for j, audio_fragment in enumerate(batch): + max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音 + if max_audio>1: audio_fragment/=max_audio + audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) + audio[i][j] = audio_fragment.cpu().numpy() + + if split_bucket: audio = self.recovery_order(audio, batch_index_list) else: - audio = [item for batch in audio for item in batch] + # audio = [item for batch in audio for item in batch] + audio = sum(audio, []) audio = np.concatenate(audio, 0) diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index 9fcb8e1ea..2669bf411 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -10,7 +10,7 @@ from text.cleaner import clean_text from text import cleaned_text_to_sequence from transformers import AutoModelForMaskedLM, AutoTokenizer -from TTS_infer_pack.text_segmentation_method import splits, get_method as get_seg_method +from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method # from tools.i18n.i18n import I18nAuto @@ -39,6 +39,10 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list: return result + + + + class TextPreprocessor: def __init__(self, bert_model:AutoModelForMaskedLM, tokenizer:AutoTokenizer, device:torch.device): @@ -74,12 +78,18 @@ def pre_seg_text(self, 
text:str, lang:str, text_split_method:str): _texts = text.split("\n") _texts = merge_short_text_in_array(_texts, 5) texts = [] + + + for text in _texts: # 解决输入目标文本的空行导致报错的问题 if (len(text.strip()) == 0): continue if (text[-1] not in splits): text += "。" if lang != "en" else "." - texts.append(text) + + # 解决句子过长导致Bert报错的问题 + texts.extend(split_big_text(text)) + return texts @@ -176,4 +186,8 @@ def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str): dtype=torch.float32, ).to(self.device) - return feature \ No newline at end of file + return feature + + + + diff --git a/GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py b/GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py index 7bc6b0090..2a182b293 100644 --- a/GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py +++ b/GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py @@ -24,6 +24,32 @@ def decorator(func): splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", } +def split_big_text(text, max_len=510): + # 定义全角和半角标点符号 + punctuation = "".join(splits) + + # 切割文本 + segments = re.split('([' + punctuation + '])', text) + + # 初始化结果列表和当前片段 + result = [] + current_segment = '' + + for segment in segments: + # 如果当前片段加上新的片段长度超过max_len,就将当前片段加入结果列表,并重置当前片段 + if len(current_segment + segment) > max_len: + result.append(current_segment) + current_segment = segment + else: + current_segment += segment + + # 将最后一个片段加入结果列表 + if current_segment: + result.append(current_segment) + + return result + + def split(todo_text): todo_text = todo_text.replace("……", "。").replace("——", ",") @@ -121,6 +147,6 @@ def cut5(inp): if __name__ == '__main__': - method = get_method("cut1") + method = get_method("cut5") print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。")) \ No newline at end of file diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 3f58dbcea..2d223f906 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -29,9 +29,13 @@ if "_CUDA_VISIBLE_DEVICES" in os.environ: os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] is_half = eval(os.environ.get("is_half", "True")) and not torch.backends.mps.is_available() +gpt_path = os.environ.get("gpt_path", None) +sovits_path = os.environ.get("sovits_path", None) +cnhubert_base_path = os.environ.get("cnhubert_base_path", None) +bert_path = os.environ.get("bert_path", None) + import gradio as gr from TTS_infer_pack.TTS import TTS, TTS_Config -from TTS_infer_pack.text_segmentation_method import cut1, cut2, cut3, cut4, cut5 from TTS_infer_pack.text_segmentation_method import get_method from tools.i18n.i18n import I18nAuto @@ -65,6 +69,15 @@ tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml") tts_config.device = device tts_config.is_half = is_half +if gpt_path is not None: + tts_config.t2s_weights_path = gpt_path +if sovits_path is not None: + tts_config.vits_weights_path = sovits_path +if cnhubert_base_path is not None: + tts_config.cnhuhbert_base_path = cnhubert_base_path +if bert_path is not None: + tts_config.bert_base_path = bert_path + tts_pipline = TTS(tts_config) gpt_path = tts_config.t2s_weights_path sovits_path = tts_config.vits_weights_path @@ -169,7 +182,7 @@ def get_weights_names(): with gr.Row(): with gr.Column(): - batch_size = gr.Slider(minimum=1,maximum=20,step=1,label=i18n("batch_size"),value=1,interactive=True) + batch_size = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("batch_size"),value=20,interactive=True) speed_factor = 
gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True) top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True) top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True) @@ -181,7 +194,8 @@ def get_weights_names(): value=i18n("凑四句一切"), interactive=True, ) - split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True) + with gr.Row(): + split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True) # with gr.Column(): output = gr.Audio(label=i18n("输出的语音")) with gr.Row(): diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index a4d223522..75bc61778 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -1,5 +1,6 @@ import copy import math +from typing import List import torch from torch import nn from torch.nn import functional as F @@ -985,6 +986,55 @@ def decode(self, codes, text, refer, noise_scale=0.5): o = self.dec((z * y_mask)[:, :, :], g=ge) return o + + + @torch.no_grad() + def batched_decode(self, codes, y_lengths, text, text_lengths, refer, noise_scale=0.5): + ge = None + if refer is not None: + refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) + refer_mask = torch.unsqueeze( + commons.sequence_mask(refer_lengths, refer.size(2)), 1 + ).to(refer.dtype) + ge = self.ref_enc(refer * refer_mask, refer_mask) + + # y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, codes.size(2)), 1).to( + # codes.dtype + # ) + y_lengths = (y_lengths * 2).long().to(codes.device) + text_lengths = text_lengths.long().to(text.device) + # y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) + # text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) + + # 假设padding之后再decode没有问题, 影响未知,但听起来好像没问题? + quantized = self.quantizer.decode(codes) + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate( + quantized, size=int(quantized.shape[-1] * 2), mode="nearest" + ) + + x, m_p, logs_p, y_mask = self.enc_p( + quantized, y_lengths, text, text_lengths, ge + ) + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + + z = self.flow(z_p, y_mask, g=ge, reverse=True) + z_masked = (z * y_mask)[:, :, :] + + # 串行。把padding部分去掉再decode + o_list:List[torch.Tensor] = [] + for i in range(z_masked.shape[0]): + z_slice = z_masked[i, :, :y_lengths[i]].unsqueeze(0) + o = self.dec(z_slice, g=ge)[0, 0, :].detach() + o_list.append(o) + + # 并行(会有问题)。先decode,再把padding的部分去掉 + # o = self.dec(z_masked, g=ge) + # upsample_rate = int(math.prod(self.upsample_rates)) + # o_lengths = y_lengths*upsample_rate + # o_list = [o[i, 0, :idx].detach() for i, idx in enumerate(o_lengths)] + + return o_list def extract_latent(self, x): ssl = self.ssl_proj(x)
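
The fragment slicing used by the new parallel decode path ("vits并行推理 method 2" in TTS.py above) relies on every semantic token expanding to a fixed number of output samples. A minimal sketch of that arithmetic, using hypothetical upsample factors and per-sentence token counts (in the real code they come from self.vits_model.upsample_rates and from pred_semantic_list / idx_list):

    import math
    import torch

    # Hypothetical values, for illustration only.
    upsample_rates = [10, 8, 2, 2, 2]
    upsample_rate = math.prod(upsample_rates)      # samples produced per latent frame
    token_counts = [120, 95, 143]                  # semantic-token length of each sentence

    # Each sentence yields tokens * 2 * upsample_rate samples; the factor 2 mirrors
    # the 25 Hz -> 50 Hz nearest-neighbour interpolation inside decode().
    frag_sizes = [n * 2 * upsample_rate for n in token_counts]
    frag_ends = [sum(frag_sizes[:i + 1]) for i in range(len(frag_sizes))]
    frag_ends.insert(0, 0)

    decoded = torch.randn(frag_ends[-1])           # stand-in for the single batched decode output
    fragments = [decoded[frag_ends[i - 1]:frag_ends[i]] for i in range(1, len(frag_ends))]
    assert [f.shape[0] for f in fragments] == frag_sizes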
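For reference, the per-fragment post-processing that this patch moves from the inference loop into audio_postprocess() amounts to a peak guard followed by 0.3 s of appended silence. A standalone sketch, with a placeholder sampling rate and a random stand-in fragment (the real values come from self.configs and the decoder output):

    import torch

    sampling_rate = 32000                          # placeholder; read from the TTS config in practice
    fragment = torch.randn(sampling_rate) * 2      # stand-in for one decoded audio fragment

    peak = fragment.abs().max()
    if peak > 1:                                   # simple guard against 16-bit clipping
        fragment = fragment / peak
    fragment = torch.cat([fragment, torch.zeros(int(sampling_rate * 0.3))])  # 0.3 s gap between sentences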