Skip to content

Commit e447a4c

Browse files
i think i fixed it
1 parent f989545 commit e447a4c

File tree

1 file changed

+19
-13
lines changed
  • crates/sketchy_pix2pix/src/pix2pix

1 file changed

+19
-13
lines changed

crates/sketchy_pix2pix/src/pix2pix/gan.rs

+19-13
Original file line numberDiff line numberDiff line change
@@ -111,36 +111,42 @@ impl<B: Backend> Pix2PixModel<B> {
111111

112112
let real_result = self
113113
.discriminator
114-
.forward(item.photos.clone(), item.sketches.clone());
115-
let fake_result = self
114+
.forward(item.sketches.clone(), item.photos.clone());
115+
let fake_result_for_discriminator = self
116116
.discriminator
117+
// IMPORTANT: detatch generated sketch from generator.
118+
.forward(generated_sketches.clone().detach(), item.photos.clone());
119+
let fake_result_for_generator = self
120+
.discriminator
121+
// IMPORTANT the discriminator should not be included in autograd path.
122+
.clone().no_grad()
117123
.forward(generated_sketches.clone(), item.sketches.clone());
118-
// Erstelle Ziel-Tensoren: echte Bilder sollen als 1 klassifiziert werden, gefälschte als 0
119-
// let true_labels: Tensor<B, 4, Int> = Tensor::ones(real_result.shape(), &real_result.device());
120-
// let loss_d_real = self.bce_loss.forward(
121-
// real_result.clone(),
122-
// true_labels.clone(),
123-
// );
124124
let true_labels = Tensor::ones_like(&real_result);
125+
let fake_labels: Tensor<B, 4> = Tensor::zeros_like(&real_result);
125126
let loss_d_real = self.mse_loss.forward(
126127
real_result.clone(),
127128
true_labels.clone(),
128129
nn::loss::Reduction::Mean,
129130
);
131+
let loss_d_fake = self.mse_loss.forward(
132+
fake_result_for_discriminator,
133+
fake_labels,
134+
nn::loss::Reduction::Mean,
135+
);
130136

131137
let fake_labels: Tensor<B, 4> = Tensor::zeros_like(&real_result);
132138
let loss_g_fake = self.mse_loss.forward(
133-
fake_result.clone(),
139+
fake_result_for_generator.clone(),
134140
fake_labels,
135141
nn::loss::Reduction::Mean,
136142
);
137-
let dis_loss = loss_d_real + loss_g_fake.clone().detach();
138-
let gen_loss = Tensor::ones_like(&loss_g_fake) - loss_g_fake;
143+
let gen_loss = Tensor::ones_like(&loss_g_fake) - loss_g_fake.clone();
144+
let dis_loss = loss_d_real + loss_d_fake;
139145
GanOutput {
140146
train_sketches: item.sketches,
141147
fake_sketches: generated_sketches,
142148
real_sketch_output: real_result,
143-
fake_sketch_output: fake_result,
149+
fake_sketch_output: fake_result_for_generator,
144150
loss_discriminator: dis_loss/2,
145151
loss_generator: gen_loss,
146152
}
@@ -259,7 +265,7 @@ pub fn train_gan<B: AutodiffBackend>(
259265
epoch_bar.set_message("Epochs");
260266

261267
// Iterate over our training and validation loop for X epochs.
262-
let log_image_interval = 1;
268+
let log_image_interval = 10;
263269
for _epoch in 1..config.num_epochs + 1 {
264270
epoch_bar.inc(1);
265271

0 commit comments

Comments
 (0)