Add VITS batched inference and fix several bugs #732

Merged 1 commit on Mar 10, 2024
108 changes: 71 additions & 37 deletions GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -1,3 +1,4 @@
import math
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
@@ -366,6 +367,7 @@ def to_batch(self, data:list, prompt_data:dict=None, batch_size:int=5, threshold
for batch_idx, index_list in enumerate(batch_index_list):
item_list = [data[idx] for idx in index_list]
phones_list = []
phones_len_list = []
# bert_features_list = []
all_phones_list = []
all_phones_len_list = []
@@ -375,24 +377,26 @@ def to_batch(self, data:list, prompt_data:dict=None, batch_size:int=5, threshold
phones_max_len = 0
for item in item_list:
if prompt_data is not None:
all_bert_features = torch.cat([prompt_data["bert_features"].clone(), item["bert_features"]], 1)
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
phones = torch.LongTensor(item["phones"])
# norm_text = prompt_data["norm_text"]+item["norm_text"]
else:
all_bert_features = item["bert_features"]
phones = torch.LongTensor(item["phones"])
all_phones = phones.clone()
all_phones = phones
# norm_text = item["norm_text"]

bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
phones_max_len = max(phones_max_len, phones.shape[-1])

phones_list.append(phones)
phones_len_list.append(phones.shape[-1])
all_phones_list.append(all_phones)
all_phones_len_list.append(all_phones.shape[-1])
all_bert_features_list.append(all_bert_features)
norm_text_batch.append(item["norm_text"])

phones_batch = phones_list
max_len = max(bert_max_len, phones_max_len)
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
@@ -406,6 +410,7 @@ def to_batch(self, data:list, prompt_data:dict=None, batch_size:int=5, threshold

batch = {
"phones": phones_batch,
"phones_len": torch.LongTensor(phones_len_list),
"all_phones": all_phones_batch,
"all_phones_len": torch.LongTensor(all_phones_len_list),
"all_bert_features": all_bert_features_batch,
@@ -491,7 +496,11 @@ def run(self, inputs:dict):

if split_bucket:
print(i18n("分桶处理模式已开启"))


# if vits_batched_inference:
# print(i18n("VITS批量推理模式已开启"))
# else:
# print(i18n("VITS单句推理模式已开启"))

no_prompt_text = False
if prompt_text in [None, ""]:
@@ -529,7 +538,6 @@ def run(self, inputs:dict):

###### text preprocessing ########
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
audio = []
t1 = ttime()
data, batch_index_list = self.to_batch(data,
prompt_data=self.prompt_cache if not no_prompt_text else None,
@@ -538,24 +546,23 @@ def run(self, inputs:dict):
split_bucket=split_bucket
)
t2 = ttime()
zero_wav = torch.zeros(
int(self.configs.sampling_rate * 0.3),
dtype=torch.float16 if self.configs.is_half else torch.float32,
device=self.configs.device
)


###### inference ######
t_34 = 0.0
t_45 = 0.0
audio = []
for item in data:
t3 = ttime()
batch_phones = item["phones"]
batch_phones_len = item["phones_len"]
all_phoneme_ids = item["all_phones"]
all_phoneme_lens = item["all_phones_len"]
all_bert_features = item["all_bert_features"]
norm_text = item["norm_text"]

# batch_phones = batch_phones.to(self.configs.device)
batch_phones_len = batch_phones_len.to(self.configs.device)
all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
all_phoneme_lens = all_phoneme_lens.to(self.configs.device)
all_bert_features = all_bert_features.to(self.configs.device)
@@ -566,7 +573,7 @@ def run(self, inputs:dict):
if no_prompt_text :
prompt = None
else:
prompt = self.prompt_cache["prompt_semantic"].clone().repeat(all_phoneme_ids.shape[0], 1).to(self.configs.device)
prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)
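# expand() returns a broadcast view of the cached prompt across the batch,
# avoiding the per-item copies that the previous clone().repeat() allocated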

with torch.no_grad():
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
@@ -583,41 +590,52 @@ def run(self, inputs:dict):
t4 = ttime()
t_34 += t4 - t3

refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].clone().to(self.configs.device)
refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].to(self.configs.device)
if self.configs.is_half:
refer_audio_spepc = refer_audio_spepc.half()

## decoding the whole batch at once produces corrupted audio


batch_audio_fragment = []

# ## VITS parallel inference, method 1
# pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
# pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
# pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
# batch_phones = batch_phones.to(self.configs.device)
# batch_audio_fragment =(self.vits_model.decode(
# pred_semantic, batch_phones, refer_audio_spepc
# ).detach()[:, 0, :])
# max_audio=torch.abs(batch_audio_fragment).max() # naive guard against 16-bit clipping
# if max_audio>1: batch_audio_fragment/=max_audio
# batch_audio_fragment = batch_audio_fragment.cpu().numpy()
# batch_audio_fragment = (self.vits_model.batched_decode(
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc
# ))

## switched to serial processing
batch_audio_fragment = []
for i, idx in enumerate(idx_list):
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
_pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0) # mq needs one extra unsqueeze
audio_fragment =(self.vits_model.decode(
_pred_semantic, phones, refer_audio_spepc
).detach()[0, 0, :])
max_audio=torch.abs(audio_fragment).max() # naive guard against 16-bit clipping
if max_audio>1: audio_fragment/=max_audio
audio_fragment = torch.cat([audio_fragment, zero_wav], dim=0)
batch_audio_fragment.append(
audio_fragment.cpu().numpy()
) ### try reconstructing without the prompt part
# ## VITS parallel inference, method 2
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
upsample_rate = math.prod(self.vits_model.upsample_rates)
audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
_batch_audio_fragment = (self.vits_model.decode(
all_pred_semantic, _batch_phones,refer_audio_spepc
).detach()[0, 0, :])
audio_frag_end_idx.insert(0, 0)
batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
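# Worked example of the slicing math above (illustrative numbers): with
# upsample_rates = (10, 8, 2, 2), math.prod gives 320, so each semantic token
# expands to 2 * 320 = 640 audio samples after decoding. Two fragments of
# 100 and 50 tokens then span 64000 and 32000 samples, audio_frag_end_idx
# becomes [0, 64000, 96000], and the single concatenated decode is sliced
# back into per-sentence fragments at those boundaries.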


# ## VITS serial inference
# for i, idx in enumerate(idx_list):
# phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
# _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0) # mq needs one extra unsqueeze
# audio_fragment =(self.vits_model.decode(
# _pred_semantic, phones, refer_audio_spepc
# ).detach()[0, 0, :])
# batch_audio_fragment.append(
# audio_fragment
# ) ### try reconstructing without the prompt part

t5 = ttime()
t_45 += t5 - t4
if return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
yield self.audio_postprocess(batch_audio_fragment,
yield self.audio_postprocess([batch_audio_fragment],
self.configs.sampling_rate,
batch_index_list,
speed_factor,
@@ -626,7 +644,8 @@ def run(self, inputs:dict):
audio.append(batch_audio_fragment)

if self.stop_flag:
yield self.configs.sampling_rate, (zero_wav.cpu().numpy()).astype(np.int16)
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
dtype=np.int16)
return

if not return_fragment:
@@ -640,15 +659,30 @@ def run(self, inputs:dict):


def audio_postprocess(self,
audio:np.ndarray,
audio:List[torch.Tensor],
sr:int,
batch_index_list:list=None,
speed_factor:float=1.0,
split_bucket:bool=True)->tuple[int, np.ndarray]:
zero_wav = torch.zeros(
int(self.configs.sampling_rate * 0.3),
dtype=torch.float16 if self.configs.is_half else torch.float32,
device=self.configs.device
)

for i, batch in enumerate(audio):
for j, audio_fragment in enumerate(batch):
max_audio=torch.abs(audio_fragment).max() # naive guard against 16-bit clipping
if max_audio>1: audio_fragment/=max_audio
audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
audio[i][j] = audio_fragment.cpu().numpy()


if split_bucket:
audio = self.recovery_order(audio, batch_index_list)
else:
audio = [item for batch in audio for item in batch]
# audio = [item for batch in audio for item in batch]
audio = sum(audio, [])


audio = np.concatenate(audio, 0)
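
For reference, a standalone sketch of the per-fragment post-processing that now lives in `audio_postprocess` (peak-normalize, then append 0.3 s of silence); the 32 kHz sampling rate is illustrative:

import numpy as np
import torch

def postprocess_fragment(frag: torch.Tensor, sr: int = 32000) -> np.ndarray:
    # naive guard against 16-bit clipping: rescale when the peak exceeds 1.0
    peak = frag.abs().max()
    if peak > 1:
        frag = frag / peak
    silence = torch.zeros(int(sr * 0.3), dtype=frag.dtype)  # 0.3 s inter-sentence pause
    return torch.cat([frag, silence], dim=0).cpu().numpy()

print(postprocess_fragment(torch.tensor([0.5, -2.0])).shape)  # (9602,)
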
20 changes: 17 additions & 3 deletions GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
@@ -10,7 +10,7 @@
from text.cleaner import clean_text
from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import splits, get_method as get_seg_method
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method

# from tools.i18n.i18n import I18nAuto

@@ -39,6 +39,10 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
return result






class TextPreprocessor:
def __init__(self, bert_model:AutoModelForMaskedLM,
tokenizer:AutoTokenizer, device:torch.device):
@@ -74,12 +78,18 @@ def pre_seg_text(self, text:str, lang:str, text_split_method:str):
_texts = text.split("\n")
_texts = merge_short_text_in_array(_texts, 5)
texts = []



for text in _texts:
# skip blank lines in the target text to avoid errors
if (len(text.strip()) == 0):
continue
if (text[-1] not in splits): text += "。" if lang != "en" else "."
texts.append(text)

# split overlong sentences to avoid BERT errors
texts.extend(split_big_text(text))
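# With this change, any sentence still longer than 510 characters is re-cut at its
# punctuation by split_big_text (defined in text_segmentation_method.py), so BERT
# never receives an input longer than it can encode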


return texts

@@ -176,4 +186,8 @@ def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str):
dtype=torch.float32,
).to(self.device)

return feature
return feature




28 changes: 27 additions & 1 deletion GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py
@@ -24,6 +24,32 @@ def decorator(func):

splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }

def split_big_text(text, max_len=510):
# full-width and half-width punctuation marks
punctuation = "".join(splits)

# split the text on punctuation, keeping the delimiters as separate elements
segments = re.split('([' + punctuation + '])', text)

# initialize the result list and the current segment
result = []
current_segment = ''

for segment in segments:
# if appending the new piece would push the current segment past max_len, flush it (when non-empty) and start a new one
if len(current_segment + segment) > max_len:
if current_segment: result.append(current_segment)
current_segment = segment
else:
current_segment += segment

# append the final segment, if any
if current_segment:
result.append(current_segment)

return result
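
# Usage sketch (max_len=6 chosen only for illustration):
# split_big_text("你好,我是小明。", max_len=6) -> ["你好,", "我是小明。"]
# Cuts happen only at the punctuation characters in `splits`, so with the
# default max_len=510 every piece fits BERT's 512-token input window
# (510 characters plus the [CLS] and [SEP] special tokens).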



def split(todo_text):
todo_text = todo_text.replace("……", "。").replace("——", ",")
@@ -121,6 +147,6 @@ def cut5(inp):


if __name__ == '__main__':
method = get_method("cut1")
method = get_method("cut5")
print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))

20 changes: 17 additions & 3 deletions GPT_SoVITS/inference_webui.py
@@ -29,9 +29,13 @@
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True")) and not torch.backends.mps.is_available()
gpt_path = os.environ.get("gpt_path", None)
sovits_path = os.environ.get("sovits_path", None)
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
bert_path = os.environ.get("bert_path", None)

import gradio as gr
from TTS_infer_pack.TTS import TTS, TTS_Config
from TTS_infer_pack.text_segmentation_method import cut1, cut2, cut3, cut4, cut5
from TTS_infer_pack.text_segmentation_method import get_method
from tools.i18n.i18n import I18nAuto

@@ -65,6 +69,15 @@
tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
tts_config.device = device
tts_config.is_half = is_half
if gpt_path is not None:
tts_config.t2s_weights_path = gpt_path
if sovits_path is not None:
tts_config.vits_weights_path = sovits_path
if cnhubert_base_path is not None:
tts_config.cnhuhbert_base_path = cnhubert_base_path
if bert_path is not None:
tts_config.bert_base_path = bert_path

tts_pipline = TTS(tts_config)
gpt_path = tts_config.t2s_weights_path
sovits_path = tts_config.vits_weights_path
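
A hypothetical launch snippet for the new environment-variable overrides (the paths are illustrative; any variable left unset falls back to the value in tts_infer.yaml):

import os
# must be set before this module is imported
os.environ["gpt_path"] = "GPT_SoVITS/pretrained_models/my_gpt.ckpt"
os.environ["sovits_path"] = "GPT_SoVITS/pretrained_models/my_sovits.pth"
os.environ["bert_path"] = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
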
@@ -169,7 +182,7 @@ def get_weights_names():
with gr.Row():

with gr.Column():
batch_size = gr.Slider(minimum=1,maximum=20,step=1,label=i18n("batch_size"),value=1,interactive=True)
batch_size = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("batch_size"),value=20,interactive=True)
speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True)
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
@@ -181,7 +194,8 @@ def get_weights_names():
value=i18n("凑四句一切"),
interactive=True,
)
split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True)
with gr.Row():
split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True)
# with gr.Column():
output = gr.Audio(label=i18n("输出的语音"))
with gr.Row():