Skip to content

增加flash attention选项,防止影响训练 #730

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions GPT_SoVITS/AR/models/t2s_lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from AR.modules.optim import ScaledAdam

class Text2SemanticLightningModule(LightningModule):
def __init__(self, config, output_dir, is_train=True):
def __init__(self, config, output_dir, is_train=True, flash_attn_enabled:bool = False):
super().__init__()
self.config = config
self.top_k = 3
self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
self.model = Text2SemanticDecoder(config=config, top_k=self.top_k,flash_attn_enabled=flash_attn_enabled)
pretrained_s1 = config.get("pretrained_s1")
if pretrained_s1 and is_train:
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
Expand Down
233 changes: 201 additions & 32 deletions GPT_SoVITS/AR/models/t2s_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
from typing import List

import torch
from tqdm import tqdm

Expand Down Expand Up @@ -174,7 +176,7 @@ def decode_next_token(


class Text2SemanticDecoder(nn.Module):
def __init__(self, config, norm_first=False, top_k=3):
def __init__(self, config, norm_first=False, top_k=3, flash_attn_enabled:bool=False):
super(Text2SemanticDecoder, self).__init__()
self.model_dim = config["model"]["hidden_dim"]
self.embedding_dim = config["model"]["embedding_dim"]
Expand Down Expand Up @@ -226,37 +228,42 @@ def __init__(self, config, norm_first=False, top_k=3):
multidim_average="global",
ignore_index=self.EOS,
)

blocks = []

for i in range(self.num_layers):
layer = self.h.layers[i]
t2smlp = T2SMLP(
layer.linear1.weight,
layer.linear1.bias,
layer.linear2.weight,
layer.linear2.bias
)

block = T2SBlock(
self.num_head,
self.model_dim,
t2smlp,
layer.self_attn.in_proj_weight,
layer.self_attn.in_proj_bias,
layer.self_attn.out_proj.weight,
layer.self_attn.out_proj.bias,
layer.norm1.weight,
layer.norm1.bias,
layer.norm1.eps,
layer.norm2.weight,
layer.norm2.bias,
layer.norm2.eps
)

blocks.append(block)

self.t2s_transformer = T2STransformer(self.num_layers, blocks)
if not flash_attn_enabled:
print("Not Using Flash Attention")
self.infer_panel = self.infer_panel_batch_only
else:
print("Using Flash Attention")
blocks = []

for i in range(self.num_layers):
layer = self.h.layers[i]
t2smlp = T2SMLP(
layer.linear1.weight,
layer.linear1.bias,
layer.linear2.weight,
layer.linear2.bias
)

block = T2SBlock(
self.num_head,
self.model_dim,
t2smlp,
layer.self_attn.in_proj_weight,
layer.self_attn.in_proj_bias,
layer.self_attn.out_proj.weight,
layer.self_attn.out_proj.bias,
layer.norm1.weight,
layer.norm1.bias,
layer.norm1.eps,
layer.norm2.weight,
layer.norm2.bias,
layer.norm2.eps
)

blocks.append(block)

self.t2s_transformer = T2STransformer(self.num_layers, blocks)

def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
x = self.ar_text_embedding(x)
Expand Down Expand Up @@ -640,6 +647,168 @@ def infer_panel(
if idx_list[i] is None:
idx_list[i] = 1500-1 ###如果没有生成到EOS,就用最大长度代替

if ref_free:
return y_list, [0]*x.shape[0]
return y_list, idx_list

def infer_panel_batch_only(
    self,
    x,  # text tokens for every item in the batch
    x_lens,
    prompts,  # semantic tokens of the reference audio, or None
    bert_feature,
    top_k: int = -100,
    top_p: int = 100,
    early_stop_num: int = -1,
    temperature: float = 1.0,
):
    """Decode semantic tokens for a whole batch through ``self.h`` with a KV cache.

    The text tokens are embedded, combined with the projected BERT features
    and position-encoded. The decoder is then run for at most 1500 steps. A
    batch row is retired as soon as EOS shows up for it (either as the sampled
    token or as the greedy arg-max of the logits); retired rows are removed
    from ``y`` and from the cached tensors so later steps only process the
    rows that are still generating.

    Args:
        x: Text token ids; ``x.shape[0]`` is used as the batch size.
        x_lens: Not used by this method; kept for signature compatibility.
        prompts: Reference-audio semantic tokens used as the decoding prefix,
            or ``None`` to decode without a reference.
        bert_feature: BERT features; dims 1 and 2 are swapped before being
            passed to ``self.bert_proj``.
        top_k: Forwarded to ``sample``.
        top_p: Forwarded to ``sample``.
        early_stop_num: When not ``-1``, decoding stops once more than this
            many tokens have been appended after the prompt.
        temperature: Forwarded to ``sample``.

    Returns:
        A ``(y_list, idx_list)`` tuple of per-row results in original batch
        order. ``y_list[i]`` is the token sequence of row ``i`` (prompt plus
        generated tokens, without the token sampled in its last step).
        ``idx_list[i]`` is ``step - 1`` for the step at which EOS was seen,
        or ``1499`` for rows that never produced EOS. When ``prompts`` is
        ``None`` the second element is a list of zeros instead.
    """
    x = self.ar_text_embedding(x)
    x = x + self.bert_proj(bert_feature.transpose(1, 2))
    x = self.ar_text_position(x)

    # AR decoder state
    y = prompts

    x_len = x.shape[1]
    x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
    stop = False
    cache = {
        "all_stage": self.num_layers,
        "k": [None] * self.num_layers,  # one slot per decoder layer
        "v": [None] * self.num_layers,
        # Only the newest sample needs embedding each step; it is appended
        # to the history stored here.
        "y_emb": None,
        "first_infer": 1,
        "stage": 0,
    }

    ################### first step ##########################
    if y is not None:
        y_emb = self.ar_audio_embedding(y)
        y_len = y_emb.shape[1]
        prefix_len = y.shape[1]
        y_pos = self.ar_audio_position(y_emb)
        xy_pos = torch.concat([x, y_pos], dim=1)
        cache["y_emb"] = y_emb
        ref_free = False
    else:
        y_emb = None
        y_len = 0
        prefix_len = 0
        y_pos = None
        xy_pos = x
        y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
        ref_free = True

    # Text rows: all-False over the text columns, all-True over the audio
    # columns -> shape (x_len, x_len + y_len).
    x_attn_mask_pad = F.pad(
        x_attn_mask,
        (0, y_len),
        value=True,
    )
    # Audio rows: False over the text columns, strict upper triangle True
    # over the audio columns -> shape (y_len, x_len + y_len).
    y_attn_mask = F.pad(
        torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
        (x_len, 0),
        value=False,
    )
    xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
        x.device
    )

    y_list = [None] * y.shape[0]
    # Maps a row of the (shrinking) working batch to its original batch index.
    batch_idx_map = list(range(y.shape[0]))
    idx_list = [None] * y.shape[0]
    for idx in tqdm(range(1500)):
        xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
        # Only the last frame is needed for the next-token logits.
        logits = self.ar_predict_layer(xy_dec[:, -1])
        if idx == 0:
            # Drop the last logit column on the very first step so EOS
            # cannot be produced before anything has been generated.
            logits = logits[:, :-1]
        samples = sample(
            logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
        )[0]
        # The new semantic ids extend the running sequence.
        y = torch.concat([y, samples], dim=1)

        # Retire rows that have finished generating. A row is finished when
        # EOS is either the sampled token or the greedy choice; both signals
        # must feed the mask, otherwise a greedy-EOS row is never removed.
        reserved_idx_of_batch_for_y = None
        finished = torch.logical_or(
            samples[:, 0] == self.EOS,
            torch.argmax(logits, dim=-1) == self.EOS,
        )
        if finished.any():
            removed_idx_of_batch_for_y = torch.where(finished)[0].tolist()
            reserved_idx_of_batch_for_y = torch.where(~finished)[0]
            for i in removed_idx_of_batch_for_y:
                batch_index = batch_idx_map[i]
                # NOTE(review): whether `idx - 1` or `idx` is the right count
                # depends on how the caller slices with it -- confirm.
                idx_list[batch_index] = idx - 1
                y_list[batch_index] = y[i, :-1]

            batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]

        # Keep only the rows that are still generating.
        if reserved_idx_of_batch_for_y is not None:
            y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
            if cache["y_emb"] is not None:
                cache["y_emb"] = torch.index_select(cache["y_emb"], dim=0, index=reserved_idx_of_batch_for_y)
            if cache["k"] is not None:
                for i in range(self.num_layers):
                    # Per the original author's note the cached k/v are
                    # transposed, so the batch dimension is 1.
                    cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=reserved_idx_of_batch_for_y)
                    cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=reserved_idx_of_batch_for_y)

        if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
            print("use early stop num:", early_stop_num)
            stop = True

        if None not in idx_list:
            # Every row has produced EOS.
            stop = True
        if stop:
            # NOTE(review): y always has at least one column here because a
            # sample was just appended, so this branch looks unreachable.
            if y.shape[1] == 0:
                y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                print("bad zero prediction")
            print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
            break

        ####################### update next step ###################################
        cache["first_infer"] = 0
        if cache["y_emb"] is not None:
            y_emb = torch.cat(
                [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim=1
            )
            cache["y_emb"] = y_emb
            y_pos = self.ar_audio_position(y_emb)
            xy_pos = y_pos[:, -1:]
        else:
            y_emb = self.ar_audio_embedding(y[:, -1:])
            cache["y_emb"] = y_emb
            y_pos = self.ar_audio_position(y_emb)
            xy_pos = y_pos
        y_len = y_pos.shape[1]

        # With the cache only the newest position is fed in, so the mask is
        # a single all-False row (the bottom row of the full mask).
        xy_attn_mask = torch.zeros(
            (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
        )

    # Rows still active here never produced EOS (early stop or step limit).
    # Without this they would be returned as None in y_list. The last sampled
    # token is trimmed exactly as on the EOS path so every returned sequence
    # follows one convention.
    for row, batch_index in enumerate(batch_idx_map):
        if y_list[batch_index] is None:
            y_list[batch_index] = y[row, :-1]

    if None in idx_list:
        for i in range(x.shape[0]):
            if idx_list[i] is None:
                # No EOS was generated: fall back to the maximum length.
                idx_list[i] = 1500 - 1

    if ref_free:
        return y_list, [0] * x.shape[0]
    return y_list, idx_list
21 changes: 14 additions & 7 deletions GPT_SoVITS/TTS_infer_pack/TTS.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from tools.i18n.i18n import I18nAuto
from my_utils import load_audio
from module.mel_processing import spectrogram_torch
from .text_segmentation_method import splits
from .TextPreprocessor import TextPreprocessor
from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
i18n = I18nAuto()

# configs/tts_infer.yaml
Expand All @@ -30,6 +30,7 @@
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
flash_attn_enabled: true

custom:
device: cuda
Expand All @@ -38,7 +39,7 @@
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth

flash_attn_enabled: true


"""
Expand All @@ -63,7 +64,8 @@ def __init__(self, configs: Union[dict, str]):
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
"flash_attn_enabled": True
}
self.configs:dict = configs.get("custom", self.default_configs)

Expand All @@ -73,6 +75,7 @@ def __init__(self, configs: Union[dict, str]):
self.vits_weights_path = self.configs.get("vits_weights_path")
self.bert_base_path = self.configs.get("bert_base_path")
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path")
self.flash_attn_enabled = self.configs.get("flash_attn_enabled")


self.max_sec = None
Expand Down Expand Up @@ -103,15 +106,17 @@ def save_configs(self, configs_path:str=None)->None:
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
"flash_attn_enabled": True
},
"custom": {
"device": str(self.device),
"is_half": self.is_half,
"t2s_weights_path": self.t2s_weights_path,
"vits_weights_path": self.vits_weights_path,
"bert_base_path": self.bert_base_path,
"cnhuhbert_base_path": self.cnhuhbert_base_path
"cnhuhbert_base_path": self.cnhuhbert_base_path,
"flash_attn_enabled": self.flash_attn_enabled
}
}
if configs_path is None:
Expand All @@ -128,6 +133,7 @@ def __str__(self):
string += "t2s_weights_path: {}\n".format(self.t2s_weights_path)
string += "vits_weights_path: {}\n".format(self.vits_weights_path)
string += "cnhuhbert_base_path: {}\n".format(self.cnhuhbert_base_path)
string += "flash_attn_enabled: {}\n".format(self.flash_attn_enabled)
string += "----------------------------------------\n"
return string

Expand Down Expand Up @@ -231,7 +237,8 @@ def init_t2s_weights(self, weights_path: str):
dict_s1 = torch.load(weights_path, map_location=self.configs.device)
config = dict_s1["config"]
self.configs.max_sec = config["data"]["max_sec"]
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False,
flash_attn_enabled=self.configs.flash_attn_enabled)
t2s_model.load_state_dict(dict_s1["weight"])
if self.configs.is_half:
t2s_model = t2s_model.half()
Expand Down
5 changes: 4 additions & 1 deletion GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@

import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)

import re
import torch
Expand All @@ -7,7 +10,7 @@
from text.cleaner import clean_text
from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from .text_segmentation_method import splits, get_method as get_seg_method
from TTS_infer_pack.text_segmentation_method import splits, get_method as get_seg_method

# from tools.i18n.i18n import I18nAuto

Expand Down
2 changes: 2 additions & 0 deletions GPT_SoVITS/configs/tts_infer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ custom:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cuda
flash_attn_enabled: true
is_half: true
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
default:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
flash_attn_enabled: true
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
Loading