From 174c4bbab3a18965b356fb0f6662edb19264ce38 Mon Sep 17 00:00:00 2001
From: chasonjiang <1440499136@qq.com>
Date: Sun, 10 Mar 2024 14:07:58 +0800
Subject: [PATCH] Add a flash attention option in:
 GPT_SoVITS/AR/models/t2s_lightning_module.py,
 GPT_SoVITS/AR/models/t2s_model.py,
 GPT_SoVITS/TTS_infer_pack/TTS.py,
 GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py,
 GPT_SoVITS/configs/tts_infer.yaml, and
 GPT_SoVITS/inference_webui.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 GPT_SoVITS/AR/models/t2s_lightning_module.py  |   4 +-
 GPT_SoVITS/AR/models/t2s_model.py             | 233 +++++++++++++++---
 GPT_SoVITS/TTS_infer_pack/TTS.py              |  21 +-
 GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py |   5 +-
 GPT_SoVITS/configs/tts_infer.yaml             |   2 +
 GPT_SoVITS/inference_webui.py                 |   4 +-
 6 files changed, 225 insertions(+), 44 deletions(-)

diff --git a/GPT_SoVITS/AR/models/t2s_lightning_module.py b/GPT_SoVITS/AR/models/t2s_lightning_module.py
index 2dd3f3928..1b602629a 100644
--- a/GPT_SoVITS/AR/models/t2s_lightning_module.py
+++ b/GPT_SoVITS/AR/models/t2s_lightning_module.py
@@ -13,11 +13,11 @@ from AR.modules.optim import ScaledAdam
 
 
 class Text2SemanticLightningModule(LightningModule):
-    def __init__(self, config, output_dir, is_train=True):
+    def __init__(self, config, output_dir, is_train=True, flash_attn_enabled:bool = False):
         super().__init__()
         self.config = config
         self.top_k = 3
-        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
+        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k,flash_attn_enabled=flash_attn_enabled)
         pretrained_s1 = config.get("pretrained_s1")
         if pretrained_s1 and is_train:
             # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index ed46b2b1d..a3170b9fc 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -1,7 +1,9 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
 # reference: https://github.com/lifeiteng/vall-e
+import os, sys
+now_dir = os.getcwd()
+sys.path.append(now_dir)
 from typing import List
-
 import torch
 from tqdm import tqdm
 
@@ -174,7 +176,7 @@ def decode_next_token(
 
 
 class Text2SemanticDecoder(nn.Module):
-    def __init__(self, config, norm_first=False, top_k=3):
+    def __init__(self, config, norm_first=False, top_k=3, flash_attn_enabled:bool=False):
         super(Text2SemanticDecoder, self).__init__()
         self.model_dim = config["model"]["hidden_dim"]
         self.embedding_dim = config["model"]["embedding_dim"]
@@ -226,37 +228,42 @@ def __init__(self, config, norm_first=False, top_k=3):
             multidim_average="global",
             ignore_index=self.EOS,
         )
-
-        blocks = []
-
-        for i in range(self.num_layers):
-            layer = self.h.layers[i]
-            t2smlp = T2SMLP(
-                layer.linear1.weight,
-                layer.linear1.bias,
-                layer.linear2.weight,
-                layer.linear2.bias
-            )
-
-            block = T2SBlock(
-                self.num_head,
-                self.model_dim,
-                t2smlp,
-                layer.self_attn.in_proj_weight,
-                layer.self_attn.in_proj_bias,
-                layer.self_attn.out_proj.weight,
-                layer.self_attn.out_proj.bias,
-                layer.norm1.weight,
-                layer.norm1.bias,
-                layer.norm1.eps,
-                layer.norm2.weight,
-                layer.norm2.bias,
-                layer.norm2.eps
-            )
-
-            blocks.append(block)
-        self.t2s_transformer = T2STransformer(self.num_layers, blocks)
+        if not flash_attn_enabled:
+            print("Not Using Flash Attention")
+            self.infer_panel = self.infer_panel_batch_only
+        else:
+            print("Using Flash Attention")
+            blocks = []
+
+            for i in range(self.num_layers):
+                layer = self.h.layers[i]
+                t2smlp = T2SMLP(
+                    layer.linear1.weight,
+                    layer.linear1.bias,
+                    layer.linear2.weight,
+                    layer.linear2.bias
+                )
+
+                block = T2SBlock(
+                    self.num_head,
+                    self.model_dim,
+                    t2smlp,
+                    layer.self_attn.in_proj_weight,
+                    layer.self_attn.in_proj_bias,
+                    layer.self_attn.out_proj.weight,
+                    layer.self_attn.out_proj.bias,
+                    layer.norm1.weight,
+                    layer.norm1.bias,
+                    layer.norm1.eps,
+                    layer.norm2.weight,
+                    layer.norm2.bias,
+                    layer.norm2.eps
+                )
+
+                blocks.append(block)
+
+            self.t2s_transformer = T2STransformer(self.num_layers, blocks)
 
     def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
         x = self.ar_text_embedding(x)
@@ -640,6 +647,168 @@ def infer_panel(
                 if idx_list[i] is None:
                     idx_list[i] = 1500-1  ### if EOS was never generated, substitute the maximum length
 
+        if ref_free:
+            return y_list, [0]*x.shape[0]
+        return y_list, idx_list
+
+    def infer_panel_batch_only(
+        self,
+        x,  ##### all of the text tokens
+        x_lens,
+        prompts,  #### reference-audio tokens
+        bert_feature,
+        top_k: int = -100,
+        top_p: int = 100,
+        early_stop_num: int = -1,
+        temperature: float = 1.0,
+    ):
+        x = self.ar_text_embedding(x)
+        x = x + self.bert_proj(bert_feature.transpose(1, 2))
+        x = self.ar_text_position(x)
+
+        # AR Decoder
+        y = prompts
+
+        x_len = x.shape[1]
+        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
+        stop = False
+        # print(1111111,self.num_layers)
+        cache = {
+            "all_stage": self.num_layers,
+            "k": [None] * self.num_layers,  ### written by hand to match the configuration
+            "v": [None] * self.num_layers,
+            # "xy_pos":None,  ## the y positional encoding changes every step, so it cannot be cached and xy_pos is rebuilt each time; this is mostly a coding-style artifact (the history could be kept consistent), but the cost is negligible, so it is left alone
+            "y_emb": None,  ## only the newest samples need an embedding, which is then concatenated onto the history
+            # "logits":None,  ### the original already computes only the final position and concatenates, so nothing to do
+            # "xy_dec":None,  ### not needed; only the last position is used for the logits
+            "first_infer": 1,
+            "stage": 0,
+        }
+        ################### first step ##########################
+        if y is not None:
+            y_emb = self.ar_audio_embedding(y)
+            y_len = y_emb.shape[1]
+            prefix_len = y.shape[1]
+            y_pos = self.ar_audio_position(y_emb)
+            xy_pos = torch.concat([x, y_pos], dim=1)
+            cache["y_emb"] = y_emb
+            ref_free = False
+        else:
+            y_emb = None
+            y_len = 0
+            prefix_len = 0
+            y_pos = None
+            xy_pos = x
+            y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
+            ref_free = True
+
+        x_attn_mask_pad = F.pad(
+            x_attn_mask,
+            (0, y_len),  ### extend the all-zero xx block into all-zero xx plus all-one xy, shape (x, x+y)
+            value=True,
+        )
+        y_attn_mask = F.pad(  ### extend the upper-right ones of yy with zeros for xy on the left, shape (y, x+y)
+            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
+            (x_len, 0),
+            value=False,
+        )
+        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
+            x.device
+        )
+
+        y_list = [None]*y.shape[0]
+        batch_idx_map = list(range(y.shape[0]))
+        idx_list = [None]*y.shape[0]
+        for idx in tqdm(range(1500)):
+
+            xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
+            logits = self.ar_predict_layer(
+                xy_dec[:, -1]
+            )  ## no change needed: with the cache there is only one frame by default, the same as taking the last frame
+            # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
+            if(idx==0):  ### the first step must not emit EOS, or nothing would be generated
+                logits = logits[:, :-1]  ### strip the probability of the EOS symbol (token 1024)
+            samples = sample(
+                logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
+            )[0]
+            # the semantic_ids generated this step are concatenated with the previous y to form the new y
+            # print(samples.shape)  # [1,1]; the first 1 is the batch size
+            y = torch.concat([y, samples], dim=1)
+
+            # remove the sequences that have finished generating
+            reserved_idx_of_batch_for_y = None
+            if (self.EOS in torch.argmax(logits, dim=-1)) or \
+                (self.EOS in samples[:, 0]):  ### stop once EOS has been generated
+                l = samples[:, 0]==self.EOS
+                removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
+                reserved_idx_of_batch_for_y = torch.where(l==False)[0]
+                # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
+                for i in removed_idx_of_batch_for_y:
+                    batch_index = batch_idx_map[i]
+                    idx_list[batch_index] = idx - 1
+                    y_list[batch_index] = y[i, :-1]
+
+                batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
+
+            # keep only the sequences that are still generating
+            if reserved_idx_of_batch_for_y is not None:
+                # index = torch.LongTensor(batch_idx_map).to(y.device)
+                y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
+                if cache["y_emb"] is not None:
+                    cache["y_emb"] = torch.index_select(cache["y_emb"], dim=0, index=reserved_idx_of_batch_for_y)
+                if cache["k"] is not None:
+                    for i in range(self.num_layers):
+                        # the k/v tensors are transposed, so their batch dim is 1
+                        cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=reserved_idx_of_batch_for_y)
+                        cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=reserved_idx_of_batch_for_y)
+
+
+            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
+                print("use early stop num:", early_stop_num)
+                stop = True
+
+            if not (None in idx_list):
+                # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
+                stop = True
+            if stop:
+                # if prompts.shape[1] == y.shape[1]:
+                #     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
+                #     print("bad zero prediction")
+                if y.shape[1]==0:
+                    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
+                    print("bad zero prediction")
+                print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
+                break
+
+            ####################### update next step ###################################
+            cache["first_infer"] = 0
+            if cache["y_emb"] is not None:
+                y_emb = torch.cat(
+                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1
+                )
+                cache["y_emb"] = y_emb
+                y_pos = self.ar_audio_position(y_emb)
+                xy_pos = y_pos[:, -1:]
+            else:
+                y_emb = self.ar_audio_embedding(y[:, -1:])
+                cache["y_emb"] = y_emb
+                y_pos = self.ar_audio_position(y_emb)
+                xy_pos = y_pos
+            y_len = y_pos.shape[1]
+
+            ### rightmost column (this version is wrong)
+            # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
+            # xy_attn_mask[:,-1]=False
+            ### bottom row (this version is correct)
+            xy_attn_mask = torch.zeros(
+                (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
+            )
+
+        if (None in idx_list):
+            for i in range(x.shape[0]):
+                if idx_list[i] is None:
+                    idx_list[i] = 1500-1  ### if EOS was never generated, substitute the maximum length
+
         if ref_free:
             return y_list, [0]*x.shape[0]
         return y_list, idx_list
\ No newline at end of file
diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index ba29a03f8..7cfaf46b3 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -17,8 +17,8 @@ from tools.i18n.i18n import I18nAuto
 from my_utils import load_audio
 from module.mel_processing import spectrogram_torch
-from .text_segmentation_method import splits
-from .TextPreprocessor import TextPreprocessor
+from TTS_infer_pack.text_segmentation_method import splits
+from TTS_infer_pack.TextPreprocessor import TextPreprocessor
 
 i18n = I18nAuto()
 
 # configs/tts_infer.yaml
@@ -30,6 +30,7 @@ cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
 t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
 vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
+  flash_attn_enabled: true
 
 custom:
   device: cuda
@@ -38,7 +39,7 @@ cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
 t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
 vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
-
+  flash_attn_enabled: true
 """
 
@@ -63,7 +64,8 @@ def __init__(self, configs: Union[dict, str]):
             "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
             "vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
             "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
-            "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
+            "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
+            "flash_attn_enabled": True
         }
 
         self.configs:dict = configs.get("custom", self.default_configs)
@@ -73,6 +75,7 @@ def __init__(self, configs: Union[dict, str]):
         self.vits_weights_path = self.configs.get("vits_weights_path")
         self.bert_base_path = self.configs.get("bert_base_path")
         self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path")
+        self.flash_attn_enabled = self.configs.get("flash_attn_enabled")
 
         self.max_sec = None
@@ -103,7 +106,8 @@ def save_configs(self, configs_path:str=None)->None:
                 "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
                 "vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
                 "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
-                "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
+                "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
+                "flash_attn_enabled": True
             },
             "custom": {
                 "device": str(self.device),
@@ -111,7 +115,8 @@ def save_configs(self, configs_path:str=None)->None:
                 "t2s_weights_path": self.t2s_weights_path,
                 "vits_weights_path": self.vits_weights_path,
                 "bert_base_path": self.bert_base_path,
-                "cnhuhbert_base_path": self.cnhuhbert_base_path
+                "cnhuhbert_base_path": self.cnhuhbert_base_path,
+                "flash_attn_enabled": self.flash_attn_enabled
             }
         }
         if configs_path is None:
@@ -128,6 +133,7 @@ def __str__(self):
         string += "t2s_weights_path: {}\n".format(self.t2s_weights_path)
         string += "vits_weights_path: {}\n".format(self.vits_weights_path)
         string += "cnhuhbert_base_path: {}\n".format(self.cnhuhbert_base_path)
+        string += "flash_attn_enabled: {}\n".format(self.flash_attn_enabled)
         string += "----------------------------------------\n"
         return string
 
@@ -231,7 +237,8 @@ def init_t2s_weights(self, weights_path: str):
         dict_s1 = torch.load(weights_path, map_location=self.configs.device)
         config = dict_s1["config"]
         self.configs.max_sec = config["data"]["max_sec"]
-        t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
+        t2s_model = Text2SemanticLightningModule(config, "****", is_train=False,
+                                                 flash_attn_enabled=self.configs.flash_attn_enabled)
         t2s_model.load_state_dict(dict_s1["weight"])
         if self.configs.is_half:
             t2s_model = t2s_model.half()
diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
index 1504a534f..9fcb8e1ea 100644
--- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
+++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
@@ -1,4 +1,7 @@
 
+import os, sys
+now_dir = os.getcwd()
+sys.path.append(now_dir)
 import re
 import torch
 
@@ -7,7 +10,7 @@
 from text.cleaner import clean_text
 from text import cleaned_text_to_sequence
 from transformers import AutoModelForMaskedLM, AutoTokenizer
-from .text_segmentation_method import splits, get_method as get_seg_method
+from TTS_infer_pack.text_segmentation_method import splits, get_method as get_seg_method
 
 # from tools.i18n.i18n import I18nAuto
 
diff --git a/GPT_SoVITS/configs/tts_infer.yaml b/GPT_SoVITS/configs/tts_infer.yaml
index 5f56a4ecf..c772f2950 100644
--- a/GPT_SoVITS/configs/tts_infer.yaml
+++ b/GPT_SoVITS/configs/tts_infer.yaml
@@ -2,6 +2,7 @@ custom:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
   device: cuda
+  flash_attn_enabled: true
   is_half: true
   t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
   vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
@@ -9,6 +10,7 @@ default:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
   device: cpu
+  flash_attn_enabled: true
   is_half: false
   t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
   vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index a1932207c..3f58dbcea 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -20,7 +20,6 @@
 logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
 import pdb
 import torch
-# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
 
 infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
@@ -33,8 +32,9 @@
 import gradio as gr
 from TTS_infer_pack.TTS import TTS, TTS_Config
 from TTS_infer_pack.text_segmentation_method import cut1, cut2, cut3, cut4, cut5
-from tools.i18n.i18n import I18nAuto
 from TTS_infer_pack.text_segmentation_method import get_method
+from tools.i18n.i18n import I18nAuto
+
 i18n = I18nAuto()
 
 os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # make sure this can also be set when the inference UI is launched directly
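
A few notes on the patch follow, each with a small illustrative sketch. When flash_attn_enabled is true, Text2SemanticDecoder builds a T2STransformer out of T2SBlock objects that borrow each nn.TransformerEncoderLayer's packed in_proj weights, out_proj, and LayerNorm parameters; when it is false, infer_panel is rebound to the mask-based infer_panel_batch_only fallback added above. T2SBlock itself is defined elsewhere in t2s_model.py and is not shown in this diff, so the sketch below only illustrates the kind of attention step that path relies on: torch.nn.functional.scaled_dot_product_attention, which dispatches to a flash or memory-efficient kernel when the backend supports one. The helper attn_step and all shapes here are assumptions for illustration, not code from the repository.

    import torch
    import torch.nn.functional as F

    def attn_step(x, in_proj_weight, in_proj_bias,
                  out_proj_weight, out_proj_bias, num_head):
        # x: (batch, seq, dim); the weights are taken from an
        # nn.MultiheadAttention layer, mirroring how the patch reads
        # layer.self_attn.in_proj_weight and friends.
        b, s, d = x.shape
        q, k, v = F.linear(x, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
        # split heads: (batch, num_head, seq, head_dim)
        q, k, v = (t.view(b, s, num_head, d // num_head).transpose(1, 2)
                   for t in (q, k, v))
        # dispatches to a flash/memory-efficient kernel when available
        o = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        o = o.transpose(1, 2).reshape(b, s, d)
        return F.linear(o, out_proj_weight, out_proj_bias)

    mha = torch.nn.MultiheadAttention(512, 8, batch_first=True)
    x = torch.randn(2, 10, 512)
    out = attn_step(x, mha.in_proj_weight, mha.in_proj_bias,
                    mha.out_proj.weight, mha.out_proj.bias, num_head=8)
    print(out.shape)  # torch.Size([2, 10, 512])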
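
The first step of infer_panel_batch_only builds one combined attention mask over the concatenated text and audio tokens: text positions may attend to all text but never to audio, while audio positions attend to all text and causally to past audio. The construction can be checked in isolation; the sketch below mirrors the F.pad/torch.triu calls from the patch, with made-up lengths (True marks a position that must not be attended to).

    import torch
    import torch.nn.functional as F

    x_len, y_len = 3, 4  # hypothetical text / audio-prompt lengths

    # text rows: all-False text block, padded with True on the right,
    # so text cannot peek at audio -> shape (x_len, x_len + y_len)
    x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
    x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)

    # audio rows: strictly upper-triangular True (causal), padded with
    # False on the left so audio sees all text -> (y_len, x_len + y_len)
    y_attn_mask = F.pad(
        torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
        (x_len, 0),
        value=False,
    )

    xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
    print(xy_attn_mask.int())  # 1 = masked (attention not allowed)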
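
The trickiest part of the fallback decoder is the bookkeeping that drops finished sequences mid-decode: rows whose newest sample is EOS are copied out to y_list/idx_list via batch_idx_map, and the survivors are gathered with torch.index_select from y (dim 0) and from the transposed per-layer k/v caches (dim 1). Below is a toy reproduction of that logic with fabricated tensors and the cache update elided; only the variable names mirror the patch.

    import torch

    EOS = 1024                      # matches self.EOS in the patch
    step = 10                       # current decoding step index
    samples = torch.tensor([[12], [EOS], [7]])   # newest token per row
    y = torch.tensor([[3, 5, 12], [9, 2, EOS], [1, 4, 7]])
    batch_idx_map = [0, 1, 2]       # live row -> original batch index
    y_list, idx_list = [None] * 3, [None] * 3

    finished = samples[:, 0] == EOS
    removed = torch.where(finished)[0].tolist()
    reserved = torch.where(~finished)[0]

    for i in removed:
        orig = batch_idx_map[i]
        idx_list[orig] = step - 1
        y_list[orig] = y[i, :-1]    # keep the row, minus the EOS token

    batch_idx_map = [batch_idx_map[i] for i in reserved.tolist()]
    y = torch.index_select(y, dim=0, index=reserved)
    # the patch applies the same index_select to cache["y_emb"] (dim=0)
    # and, since k/v are stored transposed, to cache["k"][i] and
    # cache["v"][i] with dim=1
    print(batch_idx_map)  # [0, 2]
    print(y)              # rows 0 and 2 survive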
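
Finally, the switch is plumbed through TTS_Config, so it can come from GPT_SoVITS/configs/tts_infer.yaml or be set from code. A minimal usage sketch, based only on the constructor, __str__, and save_configs behavior visible in this diff; the paths are the repository defaults, and running it assumes the working directory that the patch's sys.path additions expect.

    from TTS_infer_pack.TTS import TTS_Config

    cfg = TTS_Config({"custom": {
        "device": "cuda",
        "is_half": True,
        "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
        "vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
        "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
        "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
        "flash_attn_enabled": True,   # the new switch
    }})
    print(cfg)                          # __str__ now reports flash_attn_enabled
    cfg.save_configs("tts_infer.yaml")  # round-trips the flag into the yaml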