Commit 281a187

Add files via upload
1 parent 407ac79 commit 281a187

File tree

1 file changed, +230 -0 lines changed

main_train_usrnet.py

import os.path
import math
import argparse
import time
import random
import numpy as np
from collections import OrderedDict
import logging
from torch.utils.data import DataLoader
import torch

from utils import utils_logger
from utils import utils_image as util
from utils import utils_option as option
from utils import utils_sisr as sisr

from data.select_dataset import define_Dataset
from models.select_model import define_Model


'''
# --------------------------------------------
# training code for USRNet
# --------------------------------------------
# Kai Zhang ([email protected])
# github: https://github.com/cszn/KAIR
# https://github.com/cszn/USRNet
#
# Reference:
@inproceedings{zhang2020deep,
  title={Deep unfolding network for image super-resolution},
  author={Zhang, Kai and Van Gool, Luc and Timofte, Radu},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
  pages={3217--3226},
  year={2020}
}
# --------------------------------------------
'''


def main(json_path='options/train_usrnet.json'):

    '''
    # ----------------------------------------
    # Step--1 (prepare opt)
    # ----------------------------------------
    '''

    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, default=json_path, help='Path to option JSON file.')
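    # Typical invocation, assuming the default option file exists:
    #   python main_train_usrnet.py -opt options/train_usrnet.json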

    opt = option.parse(parser.parse_args().opt, is_train=True)
    util.mkdirs((path for key, path in opt['path'].items() if 'pretrained' not in key))

    # ----------------------------------------
    # update opt
    # ----------------------------------------
    # -->-->-->-->-->-->-->-->-->-->-->-->-->-
    init_iter, init_path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G')
    opt['path']['pretrained_netG'] = init_path_G
    current_step = init_iter

    border = opt['scale']
    # --<--<--<--<--<--<--<--<--<--<--<--<--<-
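    # Note: find_last_checkpoint() above makes the run resume from the newest saved
    # 'G' checkpoint (if one exists), and 'border' is the number of boundary pixels
    # (the scale factor) cropped off when PSNR is computed during testing below.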

    # ----------------------------------------
    # save opt to a '../option.json' file
    # ----------------------------------------
    option.save(opt)

    # ----------------------------------------
    # return None for missing key
    # ----------------------------------------
    opt = option.dict_to_nonedict(opt)

    # ----------------------------------------
    # configure logger
    # ----------------------------------------
    logger_name = 'train'
    utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name+'.log'))
    logger = logging.getLogger(logger_name)
    logger.info(option.dict2str(opt))


    # ----------------------------------------
    # seed
    # ----------------------------------------
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
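    # Reproducibility note: Python's random, NumPy and PyTorch (CPU and all CUDA
    # devices) are all seeded from the same 'manual_seed' value logged above.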

    '''
    # ----------------------------------------
    # Step--2 (create dataloader)
    # ----------------------------------------
    '''

    # ----------------------------------------
    # 1) create_dataset
    # 2) create_dataloader for train and test
    # ----------------------------------------
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = define_Dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['dataloader_batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size))
            train_loader = DataLoader(train_set,
                                      batch_size=dataset_opt['dataloader_batch_size'],
                                      shuffle=dataset_opt['dataloader_shuffle'],
                                      num_workers=dataset_opt['dataloader_num_workers'],
                                      drop_last=True,
                                      pin_memory=True)
        elif phase == 'test':
            test_set = define_Dataset(dataset_opt)
            test_loader = DataLoader(test_set, batch_size=1,
                                     shuffle=False, num_workers=1,
                                     drop_last=False, pin_memory=True)
        else:
            raise NotImplementedError("Phase [%s] is not recognized." % phase)

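    # Note: drop_last=True keeps every training batch at the full batch size, while
    # the test loader feeds one full-size image at a time (batch_size=1, no shuffling).
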
    '''
    # ----------------------------------------
    # Step--3 (initialize model)
    # ----------------------------------------
    '''

    model = define_Model(opt)

    logger.info(model.info_network())
    model.init_train()
    logger.info(model.info_params())
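    # Note: define_Model() returns the model wrapper selected in the option file
    # (holding the USRNet network, loss, optimizer and scheduler), and init_train()
    # is assumed to load the pretrained 'G' weights found above, if any.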

    '''
    # ----------------------------------------
    # Step--4 (main training)
    # ----------------------------------------
    '''

    for epoch in range(1000000):  # keep running
        for i, train_data in enumerate(train_loader):

            current_step += 1

            # -------------------------------
            # 1) update learning rate
            # -------------------------------
            model.update_learning_rate(current_step)

            # -------------------------------
            # 2) feed patch pairs
            # -------------------------------
            model.feed_data(train_data)

            # -------------------------------
            # 3) optimize parameters
            # -------------------------------
            model.optimize_parameters(current_step)

            # -------------------------------
            # 4) training information
            # -------------------------------
            if current_step % opt['train']['checkpoint_print'] == 0:
                logs = model.current_log()  # such as loss
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(epoch, current_step, model.current_learning_rate())
                for k, v in logs.items():  # merge log information into message
                    message += '{:s}: {:.3e} '.format(k, v)
                logger.info(message)

            # -------------------------------
            # 5) save model
            # -------------------------------
            if current_step % opt['train']['checkpoint_save'] == 0:
                logger.info('Saving the model.')
                model.save(current_step)

            # -------------------------------
            # 6) testing
            # -------------------------------
            if current_step % opt['train']['checkpoint_test'] == 0:

                avg_psnr = 0.0
                idx = 0

                for test_data in test_loader:
                    idx += 1
                    image_name_ext = os.path.basename(test_data['L_path'][0])
                    img_name, ext = os.path.splitext(image_name_ext)

                    img_dir = os.path.join(opt['path']['images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(test_data)
                    model.test()

                    visuals = model.current_visuals()
                    E_img = util.tensor2uint(visuals['E'])
                    H_img = util.tensor2uint(visuals['H'])
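                    # Naming convention: 'E' is the estimated (restored) image, 'H' the
                    # high-quality ground truth; the low-quality input 'L' supplied the
                    # file name via 'L_path' above.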

                    # -----------------------
                    # save estimated image E
                    # -----------------------
                    save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
                    util.imsave(E_img, save_img_path)

                    # -----------------------
                    # calculate PSNR
                    # -----------------------
                    current_psnr = util.calculate_psnr(E_img, H_img, border=border)

                    logger.info('{:->4d}--> {:>10s} | {:<4.2f}dB'.format(idx, image_name_ext, current_psnr))

                    avg_psnr += current_psnr

                avg_psnr = avg_psnr / idx

                # testing log
                logger.info('<epoch:{:3d}, iter:{:8,d}, Average PSNR : {:<.2f}dB\n'.format(epoch, current_step, avg_psnr))

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')


if __name__ == '__main__':
    main()
