@@ -111,36 +111,42 @@ impl<B: Backend> Pix2PixModel<B> {
111
111
112
112
let real_result = self
113
113
. discriminator
114
- . forward ( item. photos . clone ( ) , item. sketches . clone ( ) ) ;
115
- let fake_result = self
114
+ . forward ( item. sketches . clone ( ) , item. photos . clone ( ) ) ;
115
+ let fake_result_for_discriminator = self
116
116
. discriminator
117
+ // IMPORTANT: detatch generated sketch from generator.
118
+ . forward ( generated_sketches. clone ( ) . detach ( ) , item. photos . clone ( ) ) ;
119
+ let fake_result_for_generator = self
120
+ . discriminator
121
+ // IMPORTANT the discriminator should not be included in autograd path.
122
+ . clone ( ) . no_grad ( )
117
123
. forward ( generated_sketches. clone ( ) , item. sketches . clone ( ) ) ;
118
- // Erstelle Ziel-Tensoren: echte Bilder sollen als 1 klassifiziert werden, gefälschte als 0
119
- // let true_labels: Tensor<B, 4, Int> = Tensor::ones(real_result.shape(), &real_result.device());
120
- // let loss_d_real = self.bce_loss.forward(
121
- // real_result.clone(),
122
- // true_labels.clone(),
123
- // );
124
124
let true_labels = Tensor :: ones_like ( & real_result) ;
125
+ let fake_labels: Tensor < B , 4 > = Tensor :: zeros_like ( & real_result) ;
125
126
let loss_d_real = self . mse_loss . forward (
126
127
real_result. clone ( ) ,
127
128
true_labels. clone ( ) ,
128
129
nn:: loss:: Reduction :: Mean ,
129
130
) ;
131
+ let loss_d_fake = self . mse_loss . forward (
132
+ fake_result_for_discriminator,
133
+ fake_labels,
134
+ nn:: loss:: Reduction :: Mean ,
135
+ ) ;
130
136
131
137
let fake_labels: Tensor < B , 4 > = Tensor :: zeros_like ( & real_result) ;
132
138
let loss_g_fake = self . mse_loss . forward (
133
- fake_result . clone ( ) ,
139
+ fake_result_for_generator . clone ( ) ,
134
140
fake_labels,
135
141
nn:: loss:: Reduction :: Mean ,
136
142
) ;
137
- let dis_loss = loss_d_real + loss_g_fake. clone ( ) . detach ( ) ;
138
- let gen_loss = Tensor :: ones_like ( & loss_g_fake ) - loss_g_fake ;
143
+ let gen_loss = Tensor :: ones_like ( & loss_g_fake ) - loss_g_fake. clone ( ) ;
144
+ let dis_loss = loss_d_real + loss_d_fake ;
139
145
GanOutput {
140
146
train_sketches : item. sketches ,
141
147
fake_sketches : generated_sketches,
142
148
real_sketch_output : real_result,
143
- fake_sketch_output : fake_result ,
149
+ fake_sketch_output : fake_result_for_generator ,
144
150
loss_discriminator : dis_loss/2 ,
145
151
loss_generator : gen_loss,
146
152
}
@@ -259,7 +265,7 @@ pub fn train_gan<B: AutodiffBackend>(
259
265
epoch_bar. set_message ( "Epochs" ) ;
260
266
261
267
// Iterate over our training and validation loop for X epochs.
262
- let log_image_interval = 1 ;
268
+ let log_image_interval = 10 ;
263
269
for _epoch in 1 ..config. num_epochs + 1 {
264
270
epoch_bar. inc ( 1 ) ;
265
271
0 commit comments