@@ -1,4 +1,7 @@
-# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
+# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
+# reference: https://github.com/lifeiteng/vall-e
+from typing import List
+
 import torch
 from tqdm import tqdm

@@ -34,6 +37,139 @@
 }


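+# T2SMLP holds the raw feed-forward weights of one transformer layer so the
+# scripted classes below can call F.linear directly. (F is assumed to be
+# torch.nn.functional from the file's unchanged imports.)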
+@torch.jit.script
+class T2SMLP:
+    def __init__(self, w1, b1, w2, b2):
+        self.w1 = w1
+        self.b1 = b1
+        self.w2 = w2
+        self.b2 = b2
+
+    def forward(self, x):
+        x = F.relu(F.linear(x, self.w1, self.b1))
+        x = F.linear(x, self.w2, self.b2)
+        return x
+
+
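+# T2SBlock is a TorchScript-friendly rewrite of one decoder layer: it holds
+# the attention and layer-norm weights as plain tensors and exposes two paths,
+# process_prompt (full attention over the prefix) and decode_next_token
+# (single-step attention against the K/V cache).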
+@torch.jit.script
+class T2SBlock:
+    def __init__(
+        self,
+        num_heads,
+        hidden_dim: int,
+        mlp: T2SMLP,
+        qkv_w,
+        qkv_b,
+        out_w,
+        out_b,
+        norm_w1,
+        norm_b1,
+        norm_eps1,
+        norm_w2,
+        norm_b2,
+        norm_eps2,
+    ):
+        self.num_heads = num_heads
+        self.mlp = mlp
+        self.hidden_dim: int = hidden_dim
+        self.qkv_w = qkv_w
+        self.qkv_b = qkv_b
+        self.out_w = out_w
+        self.out_b = out_b
+        self.norm_w1 = norm_w1
+        self.norm_b1 = norm_b1
+        self.norm_eps1 = norm_eps1
+        self.norm_w2 = norm_w2
+        self.norm_b2 = norm_b2
+        self.norm_eps2 = norm_eps2
+
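+    # Prefill path: full self-attention over the text+audio prefix; the raw
+    # K/V projections are returned so decoding can reuse them as the cache.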
+    def process_prompt(self, x, attn_mask: torch.Tensor):
+        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
+
+        batch_size = q.shape[0]
+        q_len = q.shape[1]
+        kv_len = k.shape[1]
+
+        k_cache = k
+        v_cache = v
+
+        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
+        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+
+        attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
+
+        attn = attn.permute(2, 0, 1, 3).reshape(batch_size, -1, self.hidden_dim)
+        attn = F.linear(attn, self.out_w, self.out_b)
+
+        x = F.layer_norm(
+            x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+        )
+        x = F.layer_norm(
+            x + self.mlp.forward(x),
+            [self.hidden_dim],
+            self.norm_w2,
+            self.norm_b2,
+            self.norm_eps2,
+        )
+        return x, k_cache, v_cache
+
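+    # Single decode step: append the new token's K/V to the cache and attend
+    # from the one new query position against the full cache (no mask needed,
+    # since all cached positions are in the past).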
+    def decode_next_token(self, x, k_cache, v_cache):
+        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
+
+        k_cache = torch.cat([k_cache, k], dim=1)
+        v_cache = torch.cat([v_cache, v], dim=1)
+        kv_len = k_cache.shape[1]
+
+        batch_size = q.shape[0]
+        q_len = q.shape[1]
+
+        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
+        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+
+        attn = F.scaled_dot_product_attention(q, k, v)
+
+        attn = attn.permute(2, 0, 1, 3).reshape(batch_size, -1, self.hidden_dim)
+        attn = F.linear(attn, self.out_w, self.out_b)
+
+        x = F.layer_norm(
+            x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+        )
+        x = F.layer_norm(
+            x + self.mlp.forward(x),
+            [self.hidden_dim],
+            self.norm_w2,
+            self.norm_b2,
+            self.norm_eps2,
+        )
+        return x, k_cache, v_cache
+
+
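+# T2STransformer stacks the scripted blocks and threads one K/V cache entry
+# per layer through prefill and decoding.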
+@torch.jit.script
+class T2STransformer:
+    def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
+        self.num_blocks: int = num_blocks
+        self.blocks = blocks
+
+    def process_prompt(self, x, attn_mask: torch.Tensor):
+        k_cache: List[torch.Tensor] = []
+        v_cache: List[torch.Tensor] = []
+        for i in range(self.num_blocks):
+            x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask)
+            k_cache.append(k_cache_)
+            v_cache.append(v_cache_)
+        return x, k_cache, v_cache
+
+    def decode_next_token(
+        self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
+    ):
+        for i in range(self.num_blocks):
+            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
+                x, k_cache[i], v_cache[i]
+            )
+        return x, k_cache, v_cache
+
+
 class Text2SemanticDecoder(nn.Module):
     def __init__(self, config, norm_first=False, top_k=3):
         super(Text2SemanticDecoder, self).__init__()
@@ -88,6 +224,37 @@ def __init__(self, config, norm_first=False, top_k=3):
             ignore_index=self.EOS,
         )

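+        # Wrap each nn.TransformerEncoderLayer's weights in the scripted
+        # T2SMLP / T2SBlock containers so inference can run through
+        # T2STransformer instead of the original module stack.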
+        blocks = []
+
+        for i in range(self.num_layers):
+            layer = self.h.layers[i]
+            t2smlp = T2SMLP(
+                layer.linear1.weight,
+                layer.linear1.bias,
+                layer.linear2.weight,
+                layer.linear2.bias,
+            )
+            # (layer.self_attn.in_proj_weight, layer.self_attn.in_proj_bias)
+            block = T2SBlock(
+                self.num_head,
+                self.model_dim,
+                t2smlp,
+                layer.self_attn.in_proj_weight,
+                layer.self_attn.in_proj_bias,
+                layer.self_attn.out_proj.weight,
+                layer.self_attn.out_proj.bias,
+                layer.norm1.weight,
+                layer.norm1.bias,
+                layer.norm1.eps,
+                layer.norm2.weight,
+                layer.norm2.bias,
+                layer.norm2.eps,
+            )
+
+            blocks.append(block)
+
+        self.t2s_transformer = T2STransformer(self.num_layers, blocks)
+
     def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -328,7 +495,7 @@ def infer_panel(
         prompts,  #### reference audio tokens
         bert_feature,
         top_k: int = -100,
-        top_p: float = 100,
+        top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
     ):
@@ -343,25 +510,16 @@ def infer_panel(
         x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
         stop = False
         # print(1111111,self.num_layers)
-        cache = {
-            "all_stage": self.num_layers,
-            "k": [None] * self.num_layers,  ### hand-written to match the config
-            "v": [None] * self.num_layers,
-            # "xy_pos": None,  ## the y_pos positional encoding differs at every step, so it cannot be cached and xy_pos has to be rebuilt each time; mostly a coding-style issue, and the recompute is cheap anyway
-            "y_emb": None,  ## only the newest samples need an embedding; just concatenate it with the history
-            # "logits": None,  ### the original already computes only the tail and concatenates, nothing to change
-            # "xy_dec": None,  ### not needed; only the last position is used for the logits
-            "first_infer": 1,
-            "stage": 0,
-        }
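+
+        # Per-layer K/V caches: filled by process_prompt on the first
+        # iteration, then extended in place by decode_next_token.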
+        k_cache = None
+        v_cache = None
         ################### first step ##########################
         if y is not None:
             y_emb = self.ar_audio_embedding(y)
             y_len = y_emb.shape[1]
             prefix_len = y.shape[1]
             y_pos = self.ar_audio_position(y_emb)
             xy_pos = torch.concat([x, y_pos], dim=1)
-            cache["y_emb"] = y_emb
             ref_free = False
         else:
             y_emb = None
@@ -387,61 +545,42 @@ def infer_panel(
         )

         for idx in tqdm(range(1500)):
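+            # First iteration: run the whole prefix through process_prompt to
+            # build the per-layer K/V caches; every later iteration feeds only
+            # the newly sampled token through decode_next_token.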
+            if xy_attn_mask is not None:
+                xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask)
+            else:
+                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)

-            xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
             logits = self.ar_predict_layer(
                 xy_dec[:, -1]
-            )  ## no change needed: with the cache, the default is a single frame anyway, so taking the last frame gives the same result
-            # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
-            if (idx == 0):  ### the first step must not emit EOS, or nothing gets generated
-                logits = logits[:, :-1]  ### drop the probability of the 1024 stop token
+            )
+
+            if idx == 0:
+                xy_attn_mask = None
+                logits = logits[:, :-1]
             samples = sample(
                 logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
             )[0].unsqueeze(0)
-            # the newly generated semantic_ids and the previous y form the new y
-            # print(samples.shape)  # [1, 1]; the first 1 is the batch size
+
             y = torch.concat([y, samples], dim=1)

             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 print("use early stop num:", early_stop_num)
                 stop = True

             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
-                # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
                 stop = True
             if stop:
-                # if prompts.shape[1] == y.shape[1]:
-                #     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
-                #     print("bad zero prediction")
                 if y.shape[1] == 0:
                     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                     print("bad zero prediction")
                 # print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
                 break

             ####################### update next step ###################################
-            cache["first_infer"] = 0
-            if cache["y_emb"] is not None:
-                y_emb = torch.cat(
-                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim=1
-                )
-                cache["y_emb"] = y_emb
-                y_pos = self.ar_audio_position(y_emb)
-                xy_pos = y_pos[:, -1:]
-            else:
-                y_emb = self.ar_audio_embedding(y[:, -1:])
-                cache["y_emb"] = y_emb
-                y_pos = self.ar_audio_position(y_emb)
-                xy_pos = y_pos
-            y_len = y_pos.shape[1]
-
-            ### rightmost column (wrong)
-            # xy_attn_mask = torch.ones((1, x_len + y_len), dtype=torch.bool, device=xy_pos.device)
-            # xy_attn_mask[:, -1] = False
-            ### bottom row (correct)
-            xy_attn_mask = torch.zeros(
-                (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
-            )
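+            # Embed only the newest token and apply its positional encoding
+            # directly at absolute position prompts.shape[1] + idx, instead of
+            # re-embedding the whole history as the cache dict version did.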
+            y_emb = self.ar_audio_embedding(y[:, -1:])
+            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
+                :, prompts.shape[1] + idx
+            ]
+
         if ref_free:
             return y[:, :-1], 0
         return y[:, :-1], idx - 1