Commit 7da8f45

Update GPT-SoVITS TorchScript

1 parent df41212 commit 7da8f45

4 files changed, +189 -50 lines changed

Dockerfile
+1 -1

@@ -1,4 +1,4 @@
-FROM artrajz/pytorch:2.2.0-cpu-py3.10.11-ubuntu22.04
+FROM artrajz/pytorch:2.2.1-cpu-py3.10.11-ubuntu22.04
 
 RUN mkdir -p /app
 WORKDIR /app

Dockerfile_GPU
+1 -1

@@ -1,4 +1,4 @@
-FROM artrajz/pytorch:2.2.0-cu118-py3.10.11-ubuntu22.04
+FROM artrajz/pytorch:2.2.1-cu118-py3.10.11-ubuntu22.04
 
 RUN mkdir -p /app
 WORKDIR /app

gpt_sovits/AR/models/t2s_model.py
+186 -47

@@ -1,4 +1,7 @@
-# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
+# reference: https://github.com/lifeiteng/vall-e
+from typing import List
+
 import torch
 from tqdm import tqdm
 
@@ -34,6 +37,139 @@
 }
 
 
+@torch.jit.script
+class T2SMLP:
+    def __init__(self, w1, b1, w2, b2):
+        self.w1 = w1
+        self.b1 = b1
+        self.w2 = w2
+        self.b2 = b2
+
+    def forward(self, x):
+        x = F.relu(F.linear(x, self.w1, self.b1))
+        x = F.linear(x, self.w2, self.b2)
+        return x
+
+
+@torch.jit.script
+class T2SBlock:
+    def __init__(
+            self,
+            num_heads,
+            hidden_dim: int,
+            mlp: T2SMLP,
+            qkv_w,
+            qkv_b,
+            out_w,
+            out_b,
+            norm_w1,
+            norm_b1,
+            norm_eps1,
+            norm_w2,
+            norm_b2,
+            norm_eps2,
+    ):
+        self.num_heads = num_heads
+        self.mlp = mlp
+        self.hidden_dim: int = hidden_dim
+        self.qkv_w = qkv_w
+        self.qkv_b = qkv_b
+        self.out_w = out_w
+        self.out_b = out_b
+        self.norm_w1 = norm_w1
+        self.norm_b1 = norm_b1
+        self.norm_eps1 = norm_eps1
+        self.norm_w2 = norm_w2
+        self.norm_b2 = norm_b2
+        self.norm_eps2 = norm_eps2
+
+    def process_prompt(self, x, attn_mask: torch.Tensor):
+        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
+
+        batch_size = q.shape[0]
+        q_len = q.shape[1]
+        kv_len = k.shape[1]
+
+        k_cache = k
+        v_cache = v
+
+        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
+        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+
+        attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
+
+        attn = attn.permute(2, 0, 1, 3).reshape(batch_size, -1, self.hidden_dim)
+        attn = F.linear(attn, self.out_w, self.out_b)
+
+        x = F.layer_norm(
+            x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+        )
+        x = F.layer_norm(
+            x + self.mlp.forward(x),
+            [self.hidden_dim],
+            self.norm_w2,
+            self.norm_b2,
+            self.norm_eps2,
+        )
+        return x, k_cache, v_cache
+
+    def decode_next_token(self, x, k_cache, v_cache):
+        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
+
+        k_cache = torch.cat([k_cache, k], dim=1)
+        v_cache = torch.cat([v_cache, v], dim=1)
+        kv_len = k_cache.shape[1]
+
+        batch_size = q.shape[0]
+        q_len = q.shape[1]
+
+        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
+        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+
+        attn = F.scaled_dot_product_attention(q, k, v)
+
+        attn = attn.permute(2, 0, 1, 3).reshape(batch_size, -1, self.hidden_dim)
+        attn = F.linear(attn, self.out_w, self.out_b)
+
+        x = F.layer_norm(
+            x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+        )
+        x = F.layer_norm(
+            x + self.mlp.forward(x),
+            [self.hidden_dim],
+            self.norm_w2,
+            self.norm_b2,
+            self.norm_eps2,
+        )
+        return x, k_cache, v_cache
+
+
+@torch.jit.script
+class T2STransformer:
+    def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
+        self.num_blocks: int = num_blocks
+        self.blocks = blocks
+
+    def process_prompt(
+            self, x, attn_mask: torch.Tensor):
+        k_cache: List[torch.Tensor] = []
+        v_cache: List[torch.Tensor] = []
+        for i in range(self.num_blocks):
+            x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask)
+            k_cache.append(k_cache_)
+            v_cache.append(v_cache_)
+        return x, k_cache, v_cache
+
+    def decode_next_token(
+            self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
+    ):
+        for i in range(self.num_blocks):
+            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
+        return x, k_cache, v_cache
+
+
 class Text2SemanticDecoder(nn.Module):
     def __init__(self, config, norm_first=False, top_k=3):
         super(Text2SemanticDecoder, self).__init__()
@@ -88,6 +224,37 @@ def __init__(self, config, norm_first=False, top_k=3):
             ignore_index=self.EOS,
         )
 
+        blocks = []
+
+        for i in range(self.num_layers):
+            layer = self.h.layers[i]
+            t2smlp = T2SMLP(
+                layer.linear1.weight,
+                layer.linear1.bias,
+                layer.linear2.weight,
+                layer.linear2.bias
+            )
+            # (layer.self_attn.in_proj_weight, layer.self_attn.in_proj_bias)
+            block = T2SBlock(
+                self.num_head,
+                self.model_dim,
+                t2smlp,
+                layer.self_attn.in_proj_weight,
+                layer.self_attn.in_proj_bias,
+                layer.self_attn.out_proj.weight,
+                layer.self_attn.out_proj.bias,
+                layer.norm1.weight,
+                layer.norm1.bias,
+                layer.norm1.eps,
+                layer.norm2.weight,
+                layer.norm2.bias,
+                layer.norm2.eps
+            )
+
+            blocks.append(block)
+
+        self.t2s_transformer = T2STransformer(self.num_layers, blocks)
+
     def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -328,7 +495,7 @@ def infer_panel(
         prompts,  #### reference audio tokens
         bert_feature,
         top_k: int = -100,
-        top_p: float = 100,
+        top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
     ):
@@ -343,25 +510,16 @@ def infer_panel(
         x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
         stop = False
         # print(1111111,self.num_layers)
-        cache = {
-            "all_stage": self.num_layers,
-            "k": [None] * self.num_layers,  ### hand-written to match the config
-            "v": [None] * self.num_layers,
-            # "xy_pos": None,  ## the y_pos positional encoding changes every step, so it cannot be cached and xy_pos has to be rebuilt each time; mostly a coding-style issue, the history could be made uniform, but the compute cost is negligible, so it is left as-is
-            "y_emb": None,  ## only the newest samples need an embedding; it is then concatenated onto the history
-            # "logits": None,  ### the original already computes only the tail and concatenates, nothing to change
-            # "xy_dec": None,  ### not needed; only the last frame is used for the logits
-            "first_infer": 1,
-            "stage": 0,
-        }
+
+        k_cache = None
+        v_cache = None
         ################### first step ##########################
         if y is not None:
             y_emb = self.ar_audio_embedding(y)
             y_len = y_emb.shape[1]
             prefix_len = y.shape[1]
             y_pos = self.ar_audio_position(y_emb)
             xy_pos = torch.concat([x, y_pos], dim=1)
-            cache["y_emb"] = y_emb
             ref_free = False
         else:
             y_emb = None
@@ -387,61 +545,42 @@ def infer_panel(
         )
 
         for idx in tqdm(range(1500)):
+            if xy_attn_mask is not None:
+                xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask)
+            else:
+                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
 
-            xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
             logits = self.ar_predict_layer(
                 xy_dec[:, -1]
-            )  ## no change needed: with the cache there is only one frame by default, same as taking the last frame
-            # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
-            if (idx == 0):  ### the first step must not emit EOS, otherwise nothing is left
-                logits = logits[:, :-1]  ### drop the probability of the 1024 stop token
+            )
+
+            if idx == 0:
+                xy_attn_mask = None
+                logits = logits[:, :-1]
             samples = sample(
                 logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
             )[0].unsqueeze(0)
-            # the semantic_ids generated this step and the previous y form the new y
-            # print(samples.shape)  # [1,1], the first 1 is the batch size
+
             y = torch.concat([y, samples], dim=1)
 
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 print("use early stop num:", early_stop_num)
                 stop = True
 
             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
-                # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
                 stop = True
             if stop:
-                # if prompts.shape[1] == y.shape[1]:
-                #     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
-                #     print("bad zero prediction")
                 if y.shape[1] == 0:
                     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                     print("bad zero prediction")
                 # print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
                 break
 
             ####################### update next step ###################################
-            cache["first_infer"] = 0
-            if cache["y_emb"] is not None:
-                y_emb = torch.cat(
-                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim=1
-                )
-                cache["y_emb"] = y_emb
-                y_pos = self.ar_audio_position(y_emb)
-                xy_pos = y_pos[:, -1:]
-            else:
-                y_emb = self.ar_audio_embedding(y[:, -1:])
-                cache["y_emb"] = y_emb
-                y_pos = self.ar_audio_position(y_emb)
-                xy_pos = y_pos
-            y_len = y_pos.shape[1]
-
-            ### rightmost column (this is wrong)
-            # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
-            # xy_attn_mask[:,-1]=False
-            ### bottom row (this is correct)
-            xy_attn_mask = torch.zeros(
-                (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
-            )
+            y_emb = self.ar_audio_embedding(y[:, -1:])
+            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
+                :, prompts.shape[1] + idx]
+
         if ref_free:
             return y[:, :-1], 0
         return y[:, :-1], idx - 1
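
For orientation, the new scripted classes implement a plain key/value cache: process_prompt attends over the whole text-plus-reference prompt once and returns per-layer caches, while decode_next_token projects only the single newest token, appends its key and value to the caches, and attends against the full history. Below is a minimal standalone sketch of that pattern, with simplified shapes and stand-in projections rather than the model's real configuration:

    # Standalone sketch, not code from this commit; dims and "projections" are illustrative.
    import torch
    import torch.nn.functional as F

    batch, heads, dim = 1, 2, 16
    head_dim = dim // heads

    def attend(q, k, v):
        # reshape (batch, seq, dim) -> (batch, heads, seq, head_dim) for SDPA
        def split(t):
            return t.view(batch, t.shape[1], heads, head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(split(q), split(k), split(v))
        return out.transpose(1, 2).reshape(batch, -1, dim)

    # process_prompt: run the whole prompt once and keep its keys/values as the cache
    prompt = torch.randn(batch, 10, dim)
    k_cache, v_cache = prompt.clone(), prompt.clone()  # stand-ins for the qkv projections
    out = attend(prompt, k_cache, v_cache)             # (batch, 10, dim)

    # decode_next_token: each later step feeds one new token and grows the cache
    for _ in range(3):
        new_tok = torch.randn(batch, 1, dim)
        k_cache = torch.cat([k_cache, new_tok], dim=1)
        v_cache = torch.cat([v_cache, new_tok], dim=1)
        out = attend(new_tok, k_cache, v_cache)        # (batch, 1, dim)

This is why infer_panel can drop the old per-step dict cache: after the first iteration only the newest position's embedding is rebuilt and passed in as xy_pos.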

requirements.txt
+1 -1

@@ -35,7 +35,7 @@ sentencepiece==0.1.99
 jaconv
 
 # Machine Learning and Deep Learning
-torch>=2.2.0
+torch>=2.2.1
 onnx==1.12.0
 audonnx==0.7.0
 vector_quantize_pytorch==1.12.12
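
Since the new T2SBlock relies on F.scaled_dot_product_attention working under torch.jit.script, a quick environment check after upgrading might look like the following; this is a suggested sanity check, not part of the commit:

    import torch
    import torch.nn.functional as F

    print(torch.__version__)  # expect >= 2.2.1 per the new requirement

    @torch.jit.script
    def sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        return F.scaled_dot_product_attention(q, k, v)

    x = torch.randn(1, 2, 4, 8)
    print(sdpa(x, x, x).shape)  # torch.Size([1, 2, 4, 8])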
