Skip to content

Commit 4c5671e

Browse files
authored
Add DnCNN3 for JPEG image deblocking
1 parent 9107cb7 commit 4c5671e

File tree

1 file changed

+140
-0
lines changed

1 file changed

+140
-0
lines changed

main_test_dncnn3_deblocking.py

+140
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import os.path
2+
import logging
3+
4+
import numpy as np
5+
from datetime import datetime
6+
from collections import OrderedDict
7+
8+
import torch
9+
10+
from utils import utils_logger
11+
from utils import utils_model
12+
from utils import utils_image as util
13+
#import os
14+
#os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
15+
16+
17+
'''
18+
Spyder (Python 3.6)
19+
PyTorch 1.1.0
20+
Windows 10 or Linux
21+
22+
Kai Zhang ([email protected])
23+
github: https://github.com/cszn/KAIR
24+
https://github.com/cszn/DnCNN
25+
26+
@article{zhang2017beyond,
27+
title={Beyond a gaussian denoiser: Residual learning of deep cnn for image denoising},
28+
author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
29+
journal={IEEE Transactions on Image Processing},
30+
volume={26},
31+
number={7},
32+
pages={3142--3155},
33+
year={2017},
34+
publisher={IEEE}
35+
}
36+
37+
% If you have any question, please feel free to contact with me.
38+
% Kai Zhang (e-mail: [email protected]; github: https://github.com/cszn)
39+
40+
by Kai Zhang (12/Dec./2019)
41+
'''
42+
43+
"""
44+
# --------------------------------------------
45+
|--model_zoo # model_zoo
46+
|--dncnn3 # model_name
47+
|--testset # testsets
48+
|--set12 # testset_name
49+
|--bsd68
50+
|--results # results
51+
|--set12_dncnn3 # result_name = testset_name + '_' + model_name
52+
# --------------------------------------------
53+
"""
54+
55+
56+
def main():
57+
58+
# ----------------------------------------
59+
# Preparation
60+
# ----------------------------------------
61+
62+
model_name = 'dncnn3' # 'dncnn3'- can be used for blind Gaussian denoising, JPEG deblocking (quality factor 5-100) and super-resolution (x234)
63+
64+
# important!
65+
testset_name = 'bsd68' # test set, low-quality grayscale/color JPEG images
66+
n_channels = 1 # set 1 for grayscale image, set 3 for color image
67+
68+
69+
x8 = False # default: False, x8 to boost performance
70+
testsets = 'testsets' # fixed
71+
results = 'results' # fixed
72+
result_name = testset_name + '_' + model_name # fixed
73+
L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality grayscale/Y-channel JPEG images
74+
E_path = os.path.join(results, result_name) # E_path, for Estimated images
75+
util.mkdir(E_path)
76+
77+
model_pool = 'model_zoo' # fixed
78+
model_path = os.path.join(model_pool, model_name+'.pth')
79+
logger_name = result_name
80+
utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
81+
logger = logging.getLogger(logger_name)
82+
83+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
84+
85+
# ----------------------------------------
86+
# load model
87+
# ----------------------------------------
88+
89+
from models.network_dncnn import DnCNN as net
90+
model = net(in_nc=1, out_nc=1, nc=64, nb=20, act_mode='R')
91+
model.load_state_dict(torch.load(model_path), strict=True)
92+
model.eval()
93+
for k, v in model.named_parameters():
94+
v.requires_grad = False
95+
model = model.to(device)
96+
logger.info('Model path: {:s}'.format(model_path))
97+
number_parameters = sum(map(lambda x: x.numel(), model.parameters()))
98+
logger.info('Params number: {}'.format(number_parameters))
99+
100+
logger.info(L_path)
101+
L_paths = util.get_image_paths(L_path)
102+
103+
for idx, img in enumerate(L_paths):
104+
105+
# ------------------------------------
106+
# (1) img_L
107+
# ------------------------------------
108+
img_name, ext = os.path.splitext(os.path.basename(img))
109+
logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext))
110+
img_L = util.imread_uint(img, n_channels=n_channels)
111+
img_L = util.uint2single(img_L)
112+
if n_channels == 3:
113+
ycbcr = util.rgb2ycbcr(img_L, False)
114+
img_L = ycbcr[..., 0:1]
115+
img_L = util.single2tensor4(img_L)
116+
img_L = img_L.to(device)
117+
118+
# ------------------------------------
119+
# (2) img_E
120+
# ------------------------------------
121+
if not x8:
122+
img_E = model(img_L)
123+
else:
124+
img_E = utils_model.test_mode(model, img_L, mode=3)
125+
126+
img_E = util.tensor2single(img_E)
127+
if n_channels == 3:
128+
ycbcr[..., 0] = img_E
129+
img_E = util.ycbcr2rgb(ycbcr)
130+
img_E = util.single2uint(img_E)
131+
132+
# ------------------------------------
133+
# save results
134+
# ------------------------------------
135+
util.imsave(img_E, os.path.join(E_path, img_name+'.png'))
136+
137+
138+
if __name__ == '__main__':
139+
140+
main()

0 commit comments

Comments
 (0)