@@ -113,14 +113,13 @@ def train(rank, a, h):
113
113
for i , batch in enumerate (train_loader ):
114
114
if rank == 0 :
115
115
start_b = time .time ()
116
- x , y , _ , mel_spec = batch
116
+ x , y , _ , y_mel = batch
117
117
x = torch .autograd .Variable (x .to (device , non_blocking = True ))
118
118
y = torch .autograd .Variable (y .to (device , non_blocking = True ))
119
- mel_spec = torch .autograd .Variable (mel_spec .to (device , non_blocking = True ))
119
+ y_mel = torch .autograd .Variable (y_mel .to (device , non_blocking = True ))
120
120
y = y .unsqueeze (1 )
121
121
122
122
y_g_hat = generator (x )
123
- y_mel = mel_spec
124
123
y_g_hat_mel = mel_spectrogram (y_g_hat .squeeze (1 ), h .n_fft , h .num_mels , h .sampling_rate , h .hop_size , h .win_size ,
125
124
h .fmin , h .fmax_for_loss )
126
125
@@ -188,23 +187,32 @@ def train(rank, a, h):
188
187
if steps % a .validation_interval == 0 : # and steps != 0:
189
188
generator .eval ()
190
189
torch .cuda .empty_cache ()
190
+ val_err_tot = 0
191
191
with torch .no_grad ():
192
- for i , batch in enumerate (validation_loader ):
193
- x , y , _ , _ = batch
192
+ for j , batch in enumerate (validation_loader ):
193
+ x , y , _ , y_mel = batch
194
194
y_g_hat = generator (x .to (device ))
195
+ y_mel = torch .autograd .Variable (y_mel .to (device , non_blocking = True ))
196
+ y_g_hat_mel = mel_spectrogram (y_g_hat .squeeze (1 ), h .n_fft , h .num_mels , h .sampling_rate ,
197
+ h .hop_size , h .win_size ,
198
+ h .fmin , h .fmax_for_loss )
199
+ val_err_tot += F .l1_loss (y_mel , y_g_hat_mel ).item ()
200
+
201
+ if j <= 4 :
202
+ if steps == 0 :
203
+ sw .add_audio ('gt/y_{}' .format (j ), y [0 ], steps , h .sampling_rate )
204
+ sw .add_figure ('gt/y_spec_{}' .format (j ), plot_spectrogram (x [0 ]), steps )
205
+
206
+ sw .add_audio ('generated/y_hat_{}' .format (j ), y_g_hat [0 ], steps , h .sampling_rate )
207
+ y_hat_spec = mel_spectrogram (y_g_hat .squeeze (1 ), h .n_fft , h .num_mels ,
208
+ h .sampling_rate , h .hop_size , h .win_size ,
209
+ h .fmin , h .fmax )
210
+ sw .add_figure ('generated/y_hat_spec_{}' .format (j ),
211
+ plot_spectrogram (y_hat_spec .squeeze (0 ).cpu ().numpy ()), steps )
212
+
213
+ val_err = val_err_tot / (j + 1 )
214
+ sw .add_scalar ("validation/mel_spec_error" , val_err , steps )
195
215
196
- if steps == 0 :
197
- sw .add_audio ('gt/y_{}' .format (i ), y [0 ], steps , h .sampling_rate )
198
- sw .add_figure ('gt/y_spec_{}' .format (i ), plot_spectrogram (x [0 ]), steps )
199
-
200
- sw .add_audio ('generated/y_hat_{}' .format (i ), y_g_hat [0 ], steps , h .sampling_rate )
201
- y_hat_spec = mel_spectrogram (y_g_hat .squeeze (1 ), h .n_fft , h .num_mels ,
202
- h .sampling_rate , h .hop_size , h .win_size ,
203
- h .fmin , h .fmax )
204
- sw .add_figure ('generated/y_hat_spec_{}' .format (i ),
205
- plot_spectrogram (y_hat_spec .squeeze (0 ).cpu ().numpy ()), steps )
206
- if i == 4 :
207
- break
208
216
generator .train ()
209
217
210
218
steps += 1
0 commit comments