@@ -26,6 +26,8 @@ def main():
26
26
default = 'model_zoo/swinir/001_classicalSR_DIV2K_s48w8_SwinIR-M_x2.pth' )
27
27
parser .add_argument ('--folder_lq' , type = str , default = None , help = 'input low-quality test image folder' )
28
28
parser .add_argument ('--folder_gt' , type = str , default = None , help = 'input ground-truth test image folder' )
29
+ parser .add_argument ('--tile' , type = int , default = None , help = 'Tile size, None for no tile during testing (testing as a whole)' )
30
+ parser .add_argument ('--tile_overlap' , type = int , default = 32 , help = 'Overlapping of different tiles' )
29
31
args = parser .parse_args ()
30
32
31
33
device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
@@ -68,7 +70,7 @@ def main():
68
70
w_pad = (w_old // window_size + 1 ) * window_size - w_old
69
71
img_lq = torch .cat ([img_lq , torch .flip (img_lq , [2 ])], 2 )[:, :, :h_old + h_pad , :]
70
72
img_lq = torch .cat ([img_lq , torch .flip (img_lq , [3 ])], 3 )[:, :, :, :w_old + w_pad ]
71
- output = model (img_lq )
73
+ output = test (img_lq , model , args , window_size )
72
74
output = output [..., :h_old * args .scale , :w_old * args .scale ]
73
75
74
76
# save image
@@ -145,7 +147,7 @@ def define_model(args):
145
147
else :
146
148
# larger model size; use '3conv' to save parameters and memory; use ema for GAN training
147
149
model = net (upscale = 4 , in_chans = 3 , img_size = 64 , window_size = 8 ,
148
- img_range = 1. , depths = [6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 ], embed_dim = 248 ,
150
+ img_range = 1. , depths = [6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 ], embed_dim = 240 ,
149
151
num_heads = [8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 ],
150
152
mlp_ratio = 2 , upsampler = 'nearest+conv' , resi_connection = '3conv' )
151
153
param_key_g = 'params_ema'
@@ -189,6 +191,8 @@ def setup(args):
189
191
# 003 real-world image sr
190
192
elif args .task in ['real_sr' ]:
191
193
save_dir = f'results/swinir_{ args .task } _x{ args .scale } '
194
+ if args .large_model :
195
+ save_dir += '_large'
192
196
folder = args .folder_lq
193
197
border = 0
194
198
window_size = 8
@@ -249,5 +253,35 @@ def get_image_pair(args, path):
249
253
return imgname , img_lq , img_gt
250
254
251
255
256
+ def test (img_lq , model , args , window_size ):
257
+ if args .tile is None :
258
+ # test the image as a whole
259
+ output = model (img_lq )
260
+ else :
261
+ # test the image tile by tile
262
+ b , c , h , w = img_lq .size ()
263
+ tile = min (args .tile , h , w )
264
+ assert tile % window_size == 0 , "tile size should be a multiple of window_size"
265
+ tile_overlap = args .tile_overlap
266
+ sf = args .scale
267
+
268
+ stride = tile - tile_overlap
269
+ h_idx_list = list (range (0 , h - tile , stride )) + [h - tile ]
270
+ w_idx_list = list (range (0 , w - tile , stride )) + [w - tile ]
271
+ E = torch .zeros (b , c , h * sf , w * sf ).type_as (img_lq )
272
+ W = torch .zeros_like (E )
273
+
274
+ for h_idx in h_idx_list :
275
+ for w_idx in w_idx_list :
276
+ in_patch = img_lq [..., h_idx :h_idx + tile , w_idx :w_idx + tile ]
277
+ out_patch = model (in_patch )
278
+ out_patch_mask = torch .ones_like (out_patch )
279
+
280
+ E [..., h_idx * sf :(h_idx + tile )* sf , w_idx * sf :(w_idx + tile )* sf ].add_ (out_patch )
281
+ W [..., h_idx * sf :(h_idx + tile )* sf , w_idx * sf :(w_idx + tile )* sf ].add_ (out_patch_mask )
282
+ output = E .div_ (W )
283
+
284
+ return output
285
+
252
286
if __name__ == '__main__' :
253
287
main ()
0 commit comments