Skip to content

Commit 9acd013

Browse files
committed
updated
1 parent 2691493 commit 9acd013

File tree

3 files changed

+27
-17
lines changed

3 files changed

+27
-17
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `
3838
Checkpoints and copy of the configuration file are saved in `cp_hifigan` directory by default.<br>
3939
You can change the path by adding `--checkpoint_path` option.
4040

41+
Validation loss during training with V1 generator.<br>
42+
![validation loss](./validation_loss.png)
4143

4244
## Pretrained Model
4345
You can also use pretrained models we provide.<br/>

train.py

+25-17
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,13 @@ def train(rank, a, h):
113113
for i, batch in enumerate(train_loader):
114114
if rank == 0:
115115
start_b = time.time()
116-
x, y, _, mel_spec = batch
116+
x, y, _, y_mel = batch
117117
x = torch.autograd.Variable(x.to(device, non_blocking=True))
118118
y = torch.autograd.Variable(y.to(device, non_blocking=True))
119-
mel_spec = torch.autograd.Variable(mel_spec.to(device, non_blocking=True))
119+
y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
120120
y = y.unsqueeze(1)
121121

122122
y_g_hat = generator(x)
123-
y_mel = mel_spec
124123
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
125124
h.fmin, h.fmax_for_loss)
126125

@@ -188,23 +187,32 @@ def train(rank, a, h):
188187
if steps % a.validation_interval == 0: # and steps != 0:
189188
generator.eval()
190189
torch.cuda.empty_cache()
190+
val_err_tot = 0
191191
with torch.no_grad():
192-
for i, batch in enumerate(validation_loader):
193-
x, y, _, _ = batch
192+
for j, batch in enumerate(validation_loader):
193+
x, y, _, y_mel = batch
194194
y_g_hat = generator(x.to(device))
195+
y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
196+
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
197+
h.hop_size, h.win_size,
198+
h.fmin, h.fmax_for_loss)
199+
val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
200+
201+
if j <= 4:
202+
if steps == 0:
203+
sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
204+
sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
205+
206+
sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
207+
y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
208+
h.sampling_rate, h.hop_size, h.win_size,
209+
h.fmin, h.fmax)
210+
sw.add_figure('generated/y_hat_spec_{}'.format(j),
211+
plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
212+
213+
val_err = val_err_tot / (j+1)
214+
sw.add_scalar("validation/mel_spec_error", val_err, steps)
195215

196-
if steps == 0:
197-
sw.add_audio('gt/y_{}'.format(i), y[0], steps, h.sampling_rate)
198-
sw.add_figure('gt/y_spec_{}'.format(i), plot_spectrogram(x[0]), steps)
199-
200-
sw.add_audio('generated/y_hat_{}'.format(i), y_g_hat[0], steps, h.sampling_rate)
201-
y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
202-
h.sampling_rate, h.hop_size, h.win_size,
203-
h.fmin, h.fmax)
204-
sw.add_figure('generated/y_hat_spec_{}'.format(i),
205-
plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
206-
if i == 4:
207-
break
208216
generator.train()
209217

210218
steps += 1

validation_loss.png

11.4 KB
Loading

0 commit comments

Comments (0)