Skip to content

Commit 43fbc65

Browse files
committed
add --tile for large image testing
1 parent 014142d commit 43fbc65

File tree

2 files changed

+37
-3
lines changed

2 files changed

+37
-3
lines changed

docs/README_SwinIR.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ python main_test_swinir.py --task lightweight_sr --scale 3 --model_path model_zo
100100
python main_test_swinir.py --task lightweight_sr --scale 4 --model_path model_zoo/swinir/002_lightweightSR_DIV2K_s64w8_SwinIR-S_x4.pth --folder_lq testsets/set5/LR_bicubic/X4 --folder_gt testsets/set5/HR
101101

102102

103-
# 003 Real-World Image Super-Resolution
103+
# 003 Real-World Image Super-Resolution (use --tile 400 if you run out of memory)
104104
# (middle size)
105105
python main_test_swinir.py --task real_sr --scale 4 --model_path model_zoo/swinir/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth --folder_lq testsets/RealSRSet+5images
106106

main_test_swinir.py

+36-2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def main():
2626
default='model_zoo/swinir/001_classicalSR_DIV2K_s48w8_SwinIR-M_x2.pth')
2727
parser.add_argument('--folder_lq', type=str, default=None, help='input low-quality test image folder')
2828
parser.add_argument('--folder_gt', type=str, default=None, help='input ground-truth test image folder')
29+
parser.add_argument('--tile', type=int, default=None, help='Tile size, None for no tile during testing (testing as a whole)')
30+
parser.add_argument('--tile_overlap', type=int, default=32, help='Overlapping of different tiles')
2931
args = parser.parse_args()
3032

3133
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -68,7 +70,7 @@ def main():
6870
w_pad = (w_old // window_size + 1) * window_size - w_old
6971
img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
7072
img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
71-
output = model(img_lq)
73+
output = test(img_lq, model, args, window_size)
7274
output = output[..., :h_old * args.scale, :w_old * args.scale]
7375

7476
# save image
@@ -145,7 +147,7 @@ def define_model(args):
145147
else:
146148
# larger model size; use '3conv' to save parameters and memory; use ema for GAN training
147149
model = net(upscale=4, in_chans=3, img_size=64, window_size=8,
148-
img_range=1., depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=248,
150+
img_range=1., depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=240,
149151
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
150152
mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv')
151153
param_key_g = 'params_ema'
@@ -189,6 +191,8 @@ def setup(args):
189191
# 003 real-world image sr
190192
elif args.task in ['real_sr']:
191193
save_dir = f'results/swinir_{args.task}_x{args.scale}'
194+
if args.large_model:
195+
save_dir += '_large'
192196
folder = args.folder_lq
193197
border = 0
194198
window_size = 8
@@ -249,5 +253,35 @@ def get_image_pair(args, path):
249253
return imgname, img_lq, img_gt
250254

251255

256+
def test(img_lq, model, args, window_size):
    """Run ``model`` on ``img_lq``, either as a whole or tile by tile.

    Tiled mode is selected when ``args.tile`` is not None. The input is cut
    into square tiles, each tile is passed through ``model`` separately, and
    the outputs are averaged wherever tiles overlap.

    Args:
        img_lq: input tensor. In tiled mode its size is unpacked as
            ``(b, c, h, w)``.
        model: callable applied to the whole input or to each tile. In tiled
            mode its output is accumulated into a ``(b, c, h*sf, w*sf)``
            buffer, so it is expected to keep the channel count and enlarge
            both spatial dimensions by ``args.scale``.
        args: namespace providing ``tile`` (int or None), ``tile_overlap``
            (int) and ``scale`` (int).
        window_size: the effective tile size must be a multiple of this.

    Returns:
        The model output for the full input.

    Raises:
        ValueError: if the effective tile size is not a positive multiple of
            ``window_size``.
    """
    if args.tile is None:
        # test the image as a whole
        return model(img_lq)

    # test the image tile by tile
    b, c, h, w = img_lq.size()
    # A tile can never be larger than the image itself.
    tile = min(args.tile, h, w)
    if tile <= 0 or tile % window_size != 0:
        # Raised explicitly (not asserted) so the check survives ``python -O``.
        raise ValueError("tile size should be a positive multiple of window_size")
    sf = args.scale

    stride = tile - args.tile_overlap
    if stride <= 0:
        # The overlap swallows the whole tile (typically because the tile was
        # clamped to a small image). A non-positive stride would make range()
        # fail or leave parts of the image uncovered (NaNs after the division
        # below), so fall back to non-overlapping tiles.
        stride = tile

    # Regular grid of top-left corners, plus one last tile flush with the
    # bottom/right border so that every pixel is covered.
    h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
    w_idx_list = list(range(0, w - tile, stride)) + [w - tile]

    # E accumulates the tile outputs, W counts how many tiles hit each pixel.
    E = img_lq.new_zeros(b, c, h * sf, w * sf)
    W = torch.zeros_like(E)

    for h_idx in h_idx_list:
        for w_idx in w_idx_list:
            in_patch = img_lq[..., h_idx:h_idx + tile, w_idx:w_idx + tile]
            out_patch = model(in_patch)

            rows = slice(h_idx * sf, (h_idx + tile) * sf)
            cols = slice(w_idx * sf, (w_idx + tile) * sf)
            E[..., rows, cols].add_(out_patch)
            W[..., rows, cols].add_(1)

    # Average the overlapping contributions.
    return E.div_(W)


# ``test`` is an inference helper, not a unit test: opt out of pytest
# collection in case this module's names are imported into a test file.
test.__test__ = False
285+
252286
if __name__ == '__main__':
253287
main()

0 commit comments

Comments
 (0)