Skip to content

Commit 598ac16

Browse files
authored
Merge pull request #457 from WatchTower-Liu/DPO_optim
添加DPO协同训练,提升输出内容的稳定性,增加部分生成参数的webui控制
2 parents e3f3ad2 + 070ac9b commit 598ac16

File tree

4 files changed

+176
-5
lines changed

4 files changed

+176
-5
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ runtime
77
output
88
logs
99
reference
10-
SoVITS_weights
10+
SoVITS_weights
11+
GPT_weights

GPT_SoVITS/AR/models/t2s_model.py

+98-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
sample,
99
logits_to_probs,
1010
multinomial_sample_one_no_sync,
11+
dpo_loss,
12+
make_reject_y,
13+
get_batch_logps
1114
)
1215
from AR.modules.embedding import SinePositionalEmbedding
1316
from AR.modules.embedding import TokenEmbedding
@@ -85,11 +88,104 @@ def __init__(self, config, norm_first=False, top_k=3):
8588
ignore_index=self.EOS,
8689
)
8790

91+
    def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
        """Assemble one teacher-forced decoder input from text and semantic tokens.

        Args:
            x: phoneme id tensor (batch, text_len).
            x_lens: valid lengths of each text sequence (batch,).
            y: semantic token id tensor (batch, audio_len).
            y_lens: valid lengths of each semantic sequence (batch,).
            bert_feature: BERT features for the text; transposed to
                (batch, text_len, feat) before projection.

        Returns:
            (xy_pos, xy_attn_mask, targets):
              xy_pos - concatenated positional text+audio embeddings,
              xy_attn_mask - additive float attention mask (-inf on blocked slots),
              targets - shifted semantic targets produced by pad_y_eos.
        """
        # Embed phonemes, add projected BERT features, apply positional encoding.
        x = self.ar_text_embedding(x)
        x = x + self.bert_proj(bert_feature.transpose(1, 2))
        x = self.ar_text_position(x)
        x_mask = make_pad_mask(x_lens)

        y_mask = make_pad_mask(y_lens)
        y_mask_int = y_mask.type(torch.int64)
        # Zero out the padded tail of the semantic codes before EOS padding.
        codes = y.type(torch.int64) * (1 - y_mask_int)

        # Training
        # AR Decoder
        y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
        x_len = x_lens.max()
        y_len = y_lens.max()
        y_emb = self.ar_audio_embedding(y)
        y_pos = self.ar_audio_position(y_emb)

        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)

        ar_xy_padding_mask = xy_padding_mask

        # Text rows: attend to all text positions, never to audio positions.
        x_attn_mask = F.pad(
            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
            (0, y_len),
            value=True,
        )
        # Audio rows: causal (upper-triangular blocked) over audio, full view of text.
        y_attn_mask = F.pad(
            torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            ),
            (x_len, 0),
            value=False,
        )

        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
        bsz, src_len = x.shape[0], x_len + y_len
        # Broadcast the per-sample padding mask across attention heads so it can
        # be OR-ed with the (src_len, src_len) structural mask.
        _xy_padding_mask = (
            ar_xy_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, self.num_head, -1, -1)
            .reshape(bsz * self.num_head, 1, src_len)
        )
        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
        # Convert the boolean mask into an additive float mask (-inf where blocked).
        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
        xy_attn_mask = new_attn_mask
        # x and the complete y are fed to the model in a single forward pass.
        xy_pos = torch.concat([x, y_pos], dim=1)

        return xy_pos, xy_attn_mask, targets
143+
88144
    def forward(self, x, x_lens, y, y_lens, bert_feature):
        """
        x: phoneme_ids
        y: semantic_ids

        Teacher-forced training step combining the standard cross-entropy loss
        with a DPO term: a corrupted ("rejected") copy of y is synthesized and
        the model is pushed to prefer the real sequence over the corrupted one.
        Returns (loss, acc).
        """

        # Build the rejected counterpart of y (random repeat/drop corruption).
        reject_y, reject_y_lens = make_reject_y(y, y_lens)

        xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)

        xy_dec, _ = self.h(
            (xy_pos, None),
            mask=xy_attn_mask,
        )
        x_len = x_lens.max()
        # Predict logits only over the audio portion of the joint sequence.
        logits = self.ar_predict_layer(xy_dec[:, x_len:])

        ###### DPO #############
        # Second pass with the corrupted targets to obtain "rejected" logits.
        reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)

        reject_xy_dec, _ = self.h(
            (reject_xy_pos, None),
            mask=reject_xy_attn_mask,
        )
        x_len = x_lens.max()
        reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:])

        # loss
        # from feiteng: the longer the duration, the more gradient signal each
        # step should contribute, hence reduction="sum" rather than "mean".

        loss_1 = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum")
        acc = self.ar_accuracy_metric(logits.permute(0, 2, 1).detach(), targets).item()

        # Sequence-level log-probs of the chosen vs. rejected continuations.
        A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
        # reference_free=True: no frozen reference model, so reference logps are 0.
        loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)

        loss = loss_1 + loss_2

        return loss, acc
def forward_old(self, x, x_lens, y, y_lens, bert_feature):
185+
"""
186+
x: phoneme_ids
187+
y: semantic_ids
188+
"""
93189
x = self.ar_text_embedding(x)
94190
x = x + self.bert_proj(bert_feature.transpose(1, 2))
95191
x = self.ar_text_position(x)
@@ -231,6 +327,7 @@ def infer_panel(
231327
prompts, ####参考音频token
232328
bert_feature,
233329
top_k: int = -100,
330+
top_p: int = 100,
234331
early_stop_num: int = -1,
235332
temperature: float = 1.0,
236333
):
@@ -305,7 +402,7 @@ def infer_panel(
305402
if(idx==0):###第一次跑不能EOS否则没有了
306403
logits = logits[:, :-1] ###刨除1024终止符号的概率
307404
samples = sample(
308-
logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35
405+
logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.05, temperature=temperature
309406
)[0].unsqueeze(0)
310407
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
311408
print("use early stop num:", early_stop_num)

GPT_SoVITS/AR/models/utils.py

+68-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/utils.py\
22
import torch
33
import torch.nn.functional as F
4-
4+
from typing import Tuple
55

66
def sequence_mask(length, max_length=None):
77
if max_length is None:
@@ -158,3 +158,70 @@ def sample(
158158
)
159159
idx_next = multinomial_sample_one_no_sync(probs)
160160
return idx_next, probs
161+
162+
def dpo_loss(policy_chosen_logps: torch.FloatTensor,
             policy_rejected_logps: torch.FloatTensor,
             reference_chosen_logps: torch.FloatTensor,
             reference_rejected_logps: torch.FloatTensor,
             beta: float,
             reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Direct Preference Optimization loss.

    Args:
        policy_chosen_logps / policy_rejected_logps: per-sample sequence
            log-probs of the preferred / dispreferred completions under the
            policy model.
        reference_chosen_logps / reference_rejected_logps: the same quantities
            under the reference model (ignored when ``reference_free``).
        beta: temperature on the implicit reward.
        reference_free: treat the reference margin as 0 (uniform reference).

    Returns:
        (mean loss, chosen rewards, rejected rewards); the reward tensors are
        detached and scaled by ``beta``.
    """
    policy_margin = policy_chosen_logps - policy_rejected_logps
    if reference_free:
        reference_margin = 0
    else:
        reference_margin = reference_chosen_logps - reference_rejected_logps

    # Implicit reward gap between chosen and rejected completions.
    reward_gap = policy_margin - reference_margin
    losses = -F.logsigmoid(beta * reward_gap)

    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()

    return losses.mean(), chosen_rewards, rejected_rewards
181+
182+
def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Sum per-token log-probabilities of the given label sequences.

    Returns a pair of (batch,) tensors: the sequence log-probs of the
    target labels under ``logits_target`` and of the reject labels under
    ``logits_reject``. (``average_log_prob`` is accepted for interface
    compatibility but not used; sums are always returned.)
    """
    def _sequence_logps(logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor:
        # Gather the log-prob assigned to each label token, then sum over time.
        token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
        return token_logps.sum(-1)

    return _sequence_logps(logits_target, labels_target), _sequence_logps(logits_reject, labels_reject)
190+
191+
def make_reject_y(y_o, y_lens):
    """Synthesize a "rejected" batch for DPO by corrupting each sequence.

    For every sample one corruption is drawn uniformly at random:
      * repeat_P -- duplicate a random sub-span (simulates stutter/repetition)
      * lost_P   -- delete a random sub-span (simulates dropped content)

    Args:
        y_o: (batch, time) tensor of semantic token ids.
        y_lens: (batch,) tensor of valid sequence lengths.

    Returns:
        (reject_y, reject_y_lens): corrupted sequences zero-padded to a common
        length, and their true lengths as a tensor on ``y_lens``'s device.
    """
    def repeat_P(y):
        # Sample a span [lo, hi) and splice a second copy of it in place.
        lo_hi, _ = torch.randint(0, len(y), size=(2,)).sort()
        pre = y[:lo_hi[0]]
        post = y[lo_hi[1]:]
        span = y[lo_hi[0]:lo_hi[1]]
        return torch.cat([pre, span, span, post])

    def lost_P(y):
        # Sample a span [lo, hi) and drop it entirely.
        lo_hi, _ = torch.randint(0, len(y), size=(2,)).sort()
        return torch.cat([y[:lo_hi[0]], y[lo_hi[1]:]])

    bs = len(y_lens)
    reject_y = []
    reject_y_lens = []
    for b in range(bs):
        # BUG FIX: the original used torch.randint(0, 1, ...), which always
        # yields 0 (the high bound is exclusive), so the lost_P branch was
        # unreachable dead code. randint(0, 2, ...) picks both strategies.
        corrupt = repeat_P if torch.randint(0, 2, size=(1,))[0] == 0 else lost_P
        new_y = corrupt(y_o[b])
        reject_y.append(new_y)
        reject_y_lens.append(len(new_y))

    # Right-pad every corrupted sequence with zeros to the batch maximum.
    max_length = max(reject_y_lens)
    for b in range(bs):
        pad_length = max_length - reject_y_lens[b]
        reject_y[b] = torch.cat(
            [reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0
        )

    reject_y = torch.stack(reject_y, dim=0)
    reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)

    return reject_y, reject_y_lens

GPT_SoVITS/inference_webui.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def merge_short_text_in_array(texts, threshold):
365365
result[len(result) - 1] += text
366366
return result
367367

368-
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切")):
368+
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6):
369369
t0 = ttime()
370370
prompt_language = dict_language[prompt_language]
371371
text_language = dict_language[text_language]
@@ -444,7 +444,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
444444
prompt,
445445
bert,
446446
# prompt_phone_len=ph_offset,
447-
top_k=config["inference"]["top_k"],
447+
top_k=top_k,
448+
top_p=top_p,
449+
temperature=temperature,
448450
early_stop_num=hz * max_sec,
449451
)
450452
t3 = ttime()
@@ -621,6 +623,10 @@ def get_weights_names():
621623
value=i18n("凑四句一切"),
622624
interactive=True,
623625
)
626+
with gr.Row():
627+
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=20,interactive=True)
628+
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=0.6,interactive=True)
629+
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=0.6,interactive=True)
624630
inference_button = gr.Button(i18n("合成语音"), variant="primary")
625631
output = gr.Audio(label=i18n("输出的语音"))
626632

0 commit comments

Comments
 (0)