-
Notifications
You must be signed in to change notification settings - Fork 8
Open
Description
Hi,
I liked your work on MultiImageSteganography, and I need your help resolving an error that I encountered while executing the code.
I am getting a TypeError when I run the code below, which is taken from your repository.
I am also attaching an image; please take a look and help me out.
#1
def train(model, epochs, decoder_criterion, full_model_optimizer,
          full_model_criterion, learning_rate, training_iterator,
          valid_iterator, print_every=50):
    """Train ``model`` with a full-model update and a decoder update per batch.

    For every batch the cover image and three secret images are moved to the
    module-level ``device``; the model is then run in ``'encoder'``,
    ``'full'`` and ``'decoder'`` modes, and ``full_model_optimizer`` is
    stepped once on the full-model loss and once on the decoder loss.

    Args:
        model: Callable invoked as
            ``model(cover, secret_1, secret_2, secret_3, extra, mode)``.
        epochs: Number of passes over ``training_iterator``.
        decoder_criterion: Called as
            ``decoder_criterion(reveal_1, reveal_2, reveal_3,
            secret_1, secret_2, secret_3)``.
        full_model_optimizer: Optimizer used for both update steps.
        full_model_criterion: Called with (output, target) pairs for the
            hidden image and each revealed image.
        learning_rate: Accepted for interface compatibility; not used here.
        training_iterator: Iterable of dicts with keys ``'cover_image'``,
            ``'secret_image_1'``, ``'secret_image_2'``, ``'secret_image_3'``.
        valid_iterator: Accepted for interface compatibility; not used here.
        print_every: Losses are printed for epochs divisible by this value.

    Returns:
        ``(model, training_full_model_loss_list, decoder_loss_list)``; the
        two lists hold one Python float per processed batch.
    """
    # NOTE(review): the snippet this was taken from had lost its indentation;
    # the nesting below (per-batch appends, per-epoch print, return after all
    # epochs) is the most plausible reading -- confirm against the original.
    training_full_model_loss_list = []
    decoder_loss_list = []

    for epoch in range(epochs):
        # Last loss values seen in this epoch; stay None if no batch ran.
        full_model_loss_value = None
        decoder_loss_value = None

        for training_dict in training_iterator:
            cover_image = training_dict['cover_image'].to(device)
            secret_image_1 = training_dict['secret_image_1'].to(device)
            secret_image_2 = training_dict['secret_image_2'].to(device)
            secret_image_3 = training_dict['secret_image_3'].to(device)

            # ---- Full-model update ----
            full_model_optimizer.zero_grad()
            # The fifth argument appears to be a placeholder in 'encoder' and
            # 'full' modes (secret_image_3 is simply passed again) -- TODO
            # confirm against the model's forward().
            encoder_output = model(cover_image, secret_image_1,
                                   secret_image_2, secret_image_3,
                                   secret_image_3, 'encoder')
            hidden_image, reveal_image_1, reveal_image_2, reveal_image_3 = model(
                cover_image, secret_image_1, secret_image_2,
                secret_image_3, secret_image_3, 'full')
            full_model_loss = full_model_criterion(
                hidden_image, cover_image,
                reveal_image_1, secret_image_1,
                reveal_image_2, secret_image_2,
                reveal_image_3, secret_image_3,
            )
            full_model_loss.backward()
            full_model_optimizer.step()

            # ---- Decoder update ----
            full_model_optimizer.zero_grad()
            # NOTE(review): encoder_output was computed before the optimizer
            # step above; confirm that 'decoder' mode does not need to
            # backpropagate through that earlier encoder pass.
            reveal_output1, reveal_output2, reveal_output3 = model(
                cover_image, secret_image_1, secret_image_2,
                secret_image_3, encoder_output, 'decoder')
            decoder_loss = decoder_criterion(
                reveal_output1, reveal_output2, reveal_output3,
                secret_image_1, secret_image_2, secret_image_3)
            decoder_loss.backward()
            full_model_optimizer.step()

            # Record plain floats rather than the loss tensors so the history
            # lists do not keep autograd graphs (and device memory) alive.
            full_model_loss_value = full_model_loss.item()
            decoder_loss_value = decoder_loss.item()
            training_full_model_loss_list.append(full_model_loss_value)
            decoder_loss_list.append(decoder_loss_value)

        # Skip reporting when the iterator produced no batches this epoch.
        if epoch % print_every == 0 and full_model_loss_value is not None:
            print("Training full model loss at {} epochs is: {}".format(epoch, full_model_loss_value))
            print("Training decoder loss at {} epochs is: {}".format(epoch, decoder_loss_value))

    return model, training_full_model_loss_list, decoder_loss_list
#2
# Run training; arguments are passed positionally in the order train() declares them.
model, training_full_model_loss_list, decoder_loss_list = train(
    model,
    EPOCHS,
    decoder_criterion,
    full_model_optimizer,
    full_model_criterion,
    LEARNING_RATE,
    train_data_loader,
    valid_data_loader,
    50,
)
Metadata
Metadata
Assignees
Labels
No labels