
Commit c746227

Update GPT-SoVITS batch inference
Update split_alpha_nonalpha; chore: update index.html
1 parent a2fce72 commit c746227

16 files changed: +1,284 −825 lines

config.py

+2

@@ -180,6 +180,8 @@ class GPTSoVitsConfig(AsDictMixin):
     top_p: float = 1.0
     temperature: float = 1.0
     use_streaming: bool = False
+    batch_size: int = 5
+    speed: float = 1.0
     presets: Dict[str, GPTSoVitsPreset] = field(default_factory=lambda: {"default": GPTSoVitsPreset(),
                                                                          "default2": GPTSoVitsPreset()})
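The two new fields give batch inference a default bucket size and a global speed factor. A minimal sketch of reading them, assuming the dataclass can be constructed from its defaults as the diff suggests; the comments on what each field controls are inferred from the commit title, not documented here:

from config import GPTSoVitsConfig

cfg = GPTSoVitsConfig()   # picks up the new defaults
print(cfg.batch_size)     # 5   -> presumably sentences grouped per decoder forward pass
print(cfg.speed)          # 1.0 -> presumably a synthesis speed multiplier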

docker-compose.yaml

+1

@@ -16,3 +16,4 @@ services:
       - ./gunicorn_config.py:/app/gunicorn_config.py # gunicorn configuration
       - ./pyopenjtalk/open_jtalk_dic_utf_8-1.11:/usr/local/lib/python3.10/site-packages/pyopenjtalk/open_jtalk_dic_utf_8-1.11 # pyopenjtalk
       - ./nltk_data:/usr/local/share/nltk_data
+      - ./phrases_dict.txt:/app/phrases_dict.txt # mount the polyphonic-character dictionary
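The mounted file supplies extra polyphonic-character readings for the Chinese G2P front end. pypinyin, which TTS front ends of this kind commonly build on, accepts custom entries through load_phrases_dict; the file format and the loader below are assumptions, not part of this commit:

from pypinyin import load_phrases_dict

def load_polyphone_dict(path="/app/phrases_dict.txt"):
    # Assumed format: one "word reading1 reading2 ..." entry per line.
    phrases = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if len(parts) >= 2:
                # pypinyin expects {word: [[pinyin] for each character]}
                phrases[parts[0]] = [[p] for p in parts[1:]]
    load_phrases_dict(phrases)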

gpt_sovits/AR/models/t2s_lightning_module.py

+2 −2

@@ -13,11 +13,11 @@


 class Text2SemanticLightningModule(LightningModule):
-    def __init__(self, config, output_dir, is_train=True):
+    def __init__(self, config, output_dir, is_train=True, flash_attn_enabled: bool = False):
         super().__init__()
         self.config = config
         self.top_k = 3
-        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
+        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k, flash_attn_enabled=flash_attn_enabled)
         pretrained_s1 = config.get("pretrained_s1")
         if pretrained_s1 and is_train:
             # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
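The new keyword is threaded straight through to Text2SemanticDecoder. A hedged sketch of constructing the module for inference; the config object and output_dir value are placeholders:

t2s = Text2SemanticLightningModule(
    config=config,
    output_dir="logs/s1",        # placeholder path
    is_train=False,              # per the diff, skips loading pretrained_s1
    flash_attn_enabled=True,     # forwarded to Text2SemanticDecoder
)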

gpt_sovits/AR/models/t2s_model.py

+273 −45
Large diffs are not rendered by default.

gpt_sovits/AR/models/utils.py

+33 −24

@@ -3,6 +3,7 @@
 import torch.nn.functional as F
 from typing import Tuple

+
 def sequence_mask(length, max_length=None):
     if max_length is None:
         max_length = length.max()
@@ -40,7 +41,7 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:

 # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
 def top_k_top_p_filtering(
-    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
+        logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
 ):
     """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
     Args:
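This hunk only re-indents the signature, but the helper is worth a usage note: it masks logits outside the top-k set and outside the smallest nucleus with cumulative probability top_p to filter_value. A hedged example with an illustrative toy vocabulary:

import torch

logits = torch.randn(1, 32)                      # one step over a toy 32-token vocab
filtered = top_k_top_p_filtering(logits.clone(), top_k=10, top_p=0.9)
next_id = torch.multinomial(torch.softmax(filtered, dim=-1), num_samples=1)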
@@ -100,66 +101,67 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):


 def multinomial_sample_one_no_sync(
-    probs_sort,
+        probs_sort,
 ):  # Does multinomial sampling without a cuda synchronization
     q = torch.empty_like(probs_sort).exponential_(1)
     return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


 def logits_to_probs(
-    logits,
-    previous_tokens: Optional[torch.Tensor] = None,
-    temperature: float = 1.0,
-    top_k: Optional[int] = None,
-    top_p: Optional[float] = None,
-    repetition_penalty: float = 1.0,
+        logits,
+        previous_tokens: Optional[torch.Tensor] = None,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[int] = None,
+        repetition_penalty: float = 1.0,
 ):
-    if previous_tokens is not None:
-        previous_tokens = previous_tokens.squeeze()
+    # if previous_tokens is not None:
+    #     previous_tokens = previous_tokens.squeeze()
     # print(logits.shape,previous_tokens.shape)
     # pdb.set_trace()
     if previous_tokens is not None and repetition_penalty != 1.0:
         previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=0, index=previous_tokens)
+        score = torch.gather(logits, dim=1, index=previous_tokens)
         score = torch.where(
             score < 0, score * repetition_penalty, score / repetition_penalty
         )
-        logits.scatter_(dim=0, index=previous_tokens, src=score)
+        logits.scatter_(dim=1, index=previous_tokens, src=score)

     if top_p is not None and top_p < 1.0:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
         cum_probs = torch.cumsum(
             torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
         )
         sorted_indices_to_remove = cum_probs > top_p
-        sorted_indices_to_remove[0] = False  # keep at least one option
+        sorted_indices_to_remove[:, 0] = False  # keep at least one option
         indices_to_remove = sorted_indices_to_remove.scatter(
-            dim=0, index=sorted_indices, src=sorted_indices_to_remove
+            dim=1, index=sorted_indices, src=sorted_indices_to_remove
         )
         logits = logits.masked_fill(indices_to_remove, -float("Inf"))

     logits = logits / max(temperature, 1e-5)

     if top_k is not None:
         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v.select(-1, -1).unsqueeze(-1)
+        pivot = v[:, -1].unsqueeze(-1)
         logits = torch.where(logits < pivot, -float("Inf"), logits)

     probs = torch.nn.functional.softmax(logits, dim=-1)
     return probs


 def sample(
-    logits,
-    previous_tokens: Optional[torch.Tensor] = None,
-    **sampling_kwargs,
+        logits,
+        previous_tokens: Optional[torch.Tensor] = None,
+        **sampling_kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     probs = logits_to_probs(
         logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
     )
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs

+
 def dpo_loss(policy_chosen_logps: torch.FloatTensor,
              policy_rejected_logps: torch.FloatTensor,
              reference_chosen_logps: torch.FloatTensor,
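Everything substantive in this hunk moves the sampler from single-sequence [vocab] tensors to batched [batch, vocab] tensors: gather/scatter for the repetition penalty switch to dim=1, the nucleus mask keeps column 0 of every row, and the top-k pivot becomes v[:, -1]. (The top_p annotation also changes to Optional[int], though fractional values still work at runtime.) A minimal shape check of the batched path, with made-up sizes and a hypothetical penalty of 1.35:

import torch

B, V, T = 2, 6, 3
logits = torch.randn(B, V)
previous_tokens = torch.randint(0, V, (B, T))   # per-sequence token history

# Repetition penalty, now one row per sequence along the vocab dim:
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(score < 0, score * 1.35, score / 1.35)
logits.scatter_(dim=1, index=previous_tokens, src=score)

# Gumbel-max sampling (multinomial_sample_one_no_sync) still avoids a
# host-device sync and now yields one token id per row:
probs = torch.softmax(logits, dim=-1)
q = torch.empty_like(probs).exponential_(1)
idx_next = torch.argmax(probs / q, dim=-1, keepdim=True)   # shape [B, 1]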
@@ -180,15 +182,20 @@ def dpo_loss(policy_chosen_logps: torch.FloatTensor,

     return losses.mean(), chosen_rewards, rejected_rewards

-def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+
+def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor,
+                    labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[
+    torch.FloatTensor, torch.FloatTensor]:
     # dummy token; we'll ignore the losses on these tokens later

-    per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2)
-    per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2)
+    per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2,
+                                          index=labels_target.unsqueeze(2)).squeeze(2)
+    per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2,
+                                          index=labels_reject.unsqueeze(2)).squeeze(2)

     return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)

+
 def make_reject_y(y_o, y_lens):
     def repeat_P(y):
         range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
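The reflowed get_batch_logps takes [batch, seq, vocab] logits for the chosen and rejected continuations and returns one summed log-probability per sequence, which dpo_loss turns into rewards. A toy-shape sanity check; all values are illustrative:

import torch

B, T, V = 2, 4, 8
logits_target = torch.randn(B, T, V)
logits_reject = torch.randn(B, T, V)
labels_target = torch.randint(0, V, (B, T))
labels_reject = torch.randint(0, V, (B, T))

lp_target, lp_reject = get_batch_logps(logits_target, logits_reject,
                                       labels_target, labels_reject)
print(lp_target.shape, lp_reject.shape)   # torch.Size([2]) twice: summed over T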
@@ -197,23 +204,25 @@ def repeat_P(y):
         range_text = y[range_idx[0]:range_idx[1]]
         new_y = torch.cat([pre, range_text, range_text, shf])
         return new_y
+
     def lost_P(y):
         range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
         pre = y[:range_idx[0]]
         shf = y[range_idx[1]:]
         range_text = y[range_idx[0]:range_idx[1]]
         new_y = torch.cat([pre, shf])
         return new_y
+
     bs = len(y_lens)
     reject_y = []
     reject_y_lens = []
     for b in range(bs):
-        process_item_idx = torch.randint(0, 1, size=(1, ))[0]
+        process_item_idx = torch.randint(0, 1, size=(1,))[0]
         if process_item_idx == 0:
             new_y = repeat_P(y_o[b])
             reject_y.append(new_y)
             reject_y_lens.append(len(new_y))
-        elif process_item_idx==1:
+        elif process_item_idx == 1:
             new_y = lost_P(y_o[b])
             reject_y.append(new_y)
             reject_y_lens.append(len(new_y))
@@ -222,7 +231,7 @@ def lost_P(y):
         pad_length = max_length - reject_y_lens[b]
         reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)

-    reject_y = torch.stack(reject_y, dim = 0)
+    reject_y = torch.stack(reject_y, dim=0)
     reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)

     return reject_y, reject_y_lens
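Apart from spacing, make_reject_y is unchanged: it builds DPO "rejected" targets by duplicating a random span (repeat_P) or dropping one (lost_P), then zero-pads the batch to a common length. A hedged usage sketch; the tensors are made up:

import torch

y_o = torch.arange(1, 13).reshape(2, 6)        # two toy semantic-token sequences
y_lens = torch.tensor([6, 6])
reject_y, reject_y_lens = make_reject_y(y_o, y_lens)
# Each row is a corrupted copy, zero-padded to the batch max length. Note
# that torch.randint(0, 1, ...) can only return 0, so only repeat_P is
# ever applied in this revision; lost_P is effectively dead code.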
