-
Notifications
You must be signed in to change notification settings.
Fork 51
Open
Description
# Evaluation Model
model.eval()
# validation() returns (SSIM_mean, PSNR_mean) in that order -- SSIM first,
# PSNR second -- so unpack in the same order to avoid swapping the metrics.
SSIM_mean, PSNR_mean = validation(model, val_loader)
def validation(model, val_loader, *, device="cuda"):
    """Evaluate ``model`` on ``val_loader`` and report mean SSIM and PSNR.

    Each batch is moved to ``device`` and passed through ``model``. The last
    of the model's three outputs is then scored against the second batch
    element using the project's ``SSIM`` and ``PSNR`` metrics. The per-batch
    scores are averaged, printed, and returned.

    Args:
        model: Network under evaluation. Its forward pass must return a
            3-tuple whose last element is the enhanced image. This function
            does not call ``model.eval()`` or move the model; the caller must
            do both.
        val_loader: Iterable of batches. ``batch[0]`` is the model input
            (``low_img``) and ``batch[1]`` is the image the output is compared
            against (``high_img``).
        device: Device the batch tensors are moved to. Defaults to
            ``"cuda"``, which matches the previous hard-coded ``.cuda()``
            behaviour.

    Returns:
        ``(SSIM_mean, PSNR_mean)``. Note the order: SSIM first, PSNR second.

    Raises:
        ValueError: If ``val_loader`` yields no batches. Without this check
            ``np.mean`` of an empty list would silently produce NaN.
    """
    ssim = SSIM()
    psnr = PSNR()
    ssim_list = []
    psnr_list = []

    # Inference only: disable autograd once for the whole pass instead of
    # re-entering the context manager for every batch.
    with torch.no_grad():
        for imgs in val_loader:
            low_img = imgs[0].to(device)
            high_img = imgs[1].to(device)
            # Only the last of the model's three outputs is scored.
            _, _, enhanced_img = model(low_img)
            # `.item()` requires a one-element tensor, so each metric yields a
            # single scalar per batch. The final mean therefore weights every
            # batch equally. With batch_size > 1 and a smaller last batch, the
            # result is not exactly a per-image mean.
            # NOTE(review): the meaning of `as_loss=False` is defined by the
            # project's SSIM class -- presumably it returns the raw similarity
            # rather than a loss value; confirm.
            ssim_list.append(ssim(enhanced_img, high_img, as_loss=False).item())
            psnr_list.append(psnr(enhanced_img, high_img).item())

    if not ssim_list:
        raise ValueError("val_loader yielded no batches; cannot compute SSIM/PSNR")

    SSIM_mean = np.mean(ssim_list)
    PSNR_mean = np.mean(psnr_list)
    print('The SSIM Value is:', SSIM_mean)
    print('The PSNR Value is:', PSNR_mean)
    return SSIM_mean, PSNR_mean
Metadata
Assignees
Labels
No labels