|
8 | 8 | sample,
|
9 | 9 | logits_to_probs,
|
10 | 10 | multinomial_sample_one_no_sync,
|
| 11 | + dpo_loss, |
| 12 | + make_reject_y, |
| 13 | + get_batch_logps |
11 | 14 | )
|
12 | 15 | from AR.modules.embedding import SinePositionalEmbedding
|
13 | 16 | from AR.modules.embedding import TokenEmbedding
|
@@ -85,11 +88,104 @@ def __init__(self, config, norm_first=False, top_k=3):
|
85 | 88 | ignore_index=self.EOS,
|
86 | 89 | )
|
87 | 90 |
|
    def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
        """Build the joint (text, audio) decoder input and its attention mask.

        Embeds the phoneme ids ``x`` (fused with projected BERT features) and
        the EOS-padded semantic ids ``y``, concatenates them along the time
        axis, and constructs the combined causal + padding attention mask used
        by the AR decoder.

        Args:
            x: phoneme id tensor; presumably (batch, text_len) — TODO confirm.
            x_lens: per-sample valid lengths of ``x``.
            y: semantic id tensor; presumably (batch, audio_len) — TODO confirm.
            y_lens: per-sample valid lengths of ``y``.
            bert_feature: BERT features; transposed (1, 2) before projection,
                so presumably (batch, dim, text_len) on input — verify caller.

        Returns:
            Tuple ``(xy_pos, xy_attn_mask, targets)``: the positioned input
            sequence, the additive float attention mask (``-inf`` = blocked),
            and the EOS-padded prediction targets from ``pad_y_eos``.
        """
        # Text branch: phoneme embedding + projected BERT features, then
        # positional encoding.
        x = self.ar_text_embedding(x)
        x = x + self.bert_proj(bert_feature.transpose(1, 2))
        x = self.ar_text_position(x)
        x_mask = make_pad_mask(x_lens)

        y_mask = make_pad_mask(y_lens)
        y_mask_int = y_mask.type(torch.int64)
        # Zero out the padded positions of the code sequence.
        codes = y.type(torch.int64) * (1 - y_mask_int)

        # Training
        # AR Decoder
        y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
        x_len = x_lens.max()
        y_len = y_lens.max()
        y_emb = self.ar_audio_embedding(y)
        y_pos = self.ar_audio_position(y_emb)

        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)

        ar_xy_padding_mask = xy_padding_mask

        # Text rows: may attend to all text positions, never to audio
        # positions (True = masked).
        x_attn_mask = F.pad(
            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
            (0, y_len),
            value=True,
        )

        # Audio rows: causal (strictly upper-triangular masked) over audio,
        # with full visibility of the text prefix.
        y_attn_mask = F.pad(
            torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            ),
            (x_len, 0),
            value=False,
        )

        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
        bsz, src_len = x.shape[0], x_len + y_len
        # Replicate the per-sample padding mask once per attention head so it
        # can be OR-ed with the (src_len, src_len) structural mask.
        _xy_padding_mask = (
            ar_xy_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, self.num_head, -1, -1)
            .reshape(bsz * self.num_head, 1, src_len)
        )
        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
        # Convert the boolean mask into an additive float mask
        # (0 = visible, -inf = blocked).
        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
        xy_attn_mask = new_attn_mask
        # x and the complete y are fed to the model in a single pass
        xy_pos = torch.concat([x, y_pos], dim=1)

        return xy_pos, xy_attn_mask, targets
| 143 | + |
88 | 144 | def forward(self, x, x_lens, y, y_lens, bert_feature):
|
89 | 145 | """
|
90 | 146 | x: phoneme_ids
|
91 | 147 | y: semantic_ids
|
92 | 148 | """
|
| 149 | + |
| 150 | + reject_y, reject_y_lens = make_reject_y(y, y_lens) |
| 151 | + |
| 152 | + xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature) |
| 153 | + |
| 154 | + xy_dec, _ = self.h( |
| 155 | + (xy_pos, None), |
| 156 | + mask=xy_attn_mask, |
| 157 | + ) |
| 158 | + x_len = x_lens.max() |
| 159 | + logits = self.ar_predict_layer(xy_dec[:, x_len:]) |
| 160 | + |
| 161 | + ###### DPO ############# |
| 162 | + reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature) |
| 163 | + |
| 164 | + reject_xy_dec, _ = self.h( |
| 165 | + (reject_xy_pos, None), |
| 166 | + mask=reject_xy_attn_mask, |
| 167 | + ) |
| 168 | + x_len = x_lens.max() |
| 169 | + reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:]) |
| 170 | + |
| 171 | + # loss |
| 172 | + # from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum |
| 173 | + |
| 174 | + loss_1 = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum") |
| 175 | + acc = self.ar_accuracy_metric(logits.permute(0, 2, 1).detach(), targets).item() |
| 176 | + |
| 177 | + A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets) |
| 178 | + loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True) |
| 179 | + |
| 180 | + loss = loss_1 + loss_2 |
| 181 | + |
| 182 | + return loss, acc |
| 183 | + |
| 184 | + def forward_old(self, x, x_lens, y, y_lens, bert_feature): |
| 185 | + """ |
| 186 | + x: phoneme_ids |
| 187 | + y: semantic_ids |
| 188 | + """ |
93 | 189 | x = self.ar_text_embedding(x)
|
94 | 190 | x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
95 | 191 | x = self.ar_text_position(x)
|
@@ -231,6 +327,7 @@ def infer_panel(
|
231 | 327 | prompts, ####参考音频token
|
232 | 328 | bert_feature,
|
233 | 329 | top_k: int = -100,
|
| 330 | + top_p: int = 100, |
234 | 331 | early_stop_num: int = -1,
|
235 | 332 | temperature: float = 1.0,
|
236 | 333 | ):
|
@@ -305,7 +402,7 @@ def infer_panel(
|
305 | 402 | if(idx==0):###第一次跑不能EOS否则没有了
|
306 | 403 | logits = logits[:, :-1] ###刨除1024终止符号的概率
|
307 | 404 | samples = sample(
|
308 |
| - logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35 |
| 405 | + logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.05, temperature=temperature |
309 | 406 | )[0].unsqueeze(0)
|
310 | 407 | if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
311 | 408 | print("use early stop num:", early_stop_num)
|
|
0 commit comments