Pretraining Deformable Image Registration Networks with Random Images

keywords: Image Registration, Self-supervised Learning, Pretraining

This is a PyTorch implementation of my paper:

Chen, Junyu, et al. "Pretraining Deformable Image Registration Networks with Random Images." Medical Imaging with Deep Learning. 2025.

The core idea of this work is to leverage randomly generated images to initialize (or pretrain) the encoder of an image registration network. To achieve this, we designed temporary lightweight decoders that are attached to the encoder of the registration DNN, and pretrained the resulting network using a standard image registration loss function on the task of aligning pairs of random images.
This approach is implemented in the MIR package. The source code for generating random images can be found here, and the lightweight decoder is implemented here. The repository also includes training and inference scripts to reproduce the results reported in the paper.

Pretraining and Fine-tuning Pipeline

Step 1: Pretraining the encoder on a proxy task of registering random images

Run python -u train_SSL.py to initiate pretraining. We first extract the encoder from the registration DNN and attach a lightweight decoder to it:

H, W, D = 224, 224, 224
scale_factor, win_factor = 2, 4  # assumed values, not shown in the original snippet; set to match your setup
config = CONFIGS_TM.get_3DTransMorphDWin3Lvl_config()
config.img_size = (H // scale_factor, W // scale_factor, D // scale_factor)
config.window_size = (H // win_factor, W // win_factor, D // win_factor)
config.out_chan = 3
# Build the full registration DNN, then keep only its encoder.
tm = TransMorphTVF(config, time_steps=7, SVF=True, composition='addition', swin_type='swin')
encoder = tm.transformer.cuda()
# Attach the temporary lightweight decoder to form the pretraining network.
model = SSLHeadNLvl(encoder, img_size=(H // scale_factor, W // scale_factor, D // scale_factor),
                    channels=(config.embed_dim * 4, config.embed_dim * 2, config.embed_dim),
                    if_upsamp=True, encoder_output_type='single')
model.cuda()
del tm  # the rest of the registration network is not needed during pretraining
In each iteration, a pair of random images is generated with data = rs.gen_shapes(.): data[0] and data[1] contain the moving and fixed random images, and data[2] and data[3] contain their corresponding binary label maps.
data = rs.gen_shapes((H, W, D), res=(H // 32, W // 32, D // 32))
x = data[0].cuda()  # moving image
y = data[1].cuda()  # fixed image
We then compute the standard registration loss for pretraining.
loss = loss_ncc + loss_kl + loss_reg  # image similarity + KL + deformation regularization
loss_all.update(loss.item(), y.numel())
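For reference, here is a minimal, self-contained sketch of what two of these terms might look like in plain PyTorch. The actual loss implementations live in the MIR package; ncc_loss and diffusion_reg below are illustrative names, the global (rather than windowed) NCC is a simplification, and the KL term tied to the SVF parameterization is omitted.

import torch

def ncc_loss(pred, target, eps=1e-5):
    # Global normalized cross-correlation (registration losses typically use a local/windowed NCC).
    pred = pred - pred.mean()
    target = target - target.mean()
    cc = (pred * target).sum() / (pred.norm() * target.norm() + eps)
    return -cc  # maximizing correlation = minimizing its negative

def diffusion_reg(flow):
    # Diffusion regularizer: mean squared spatial gradient of a (B, 3, H, W, D) flow field.
    dx = (flow[:, :, 1:, :, :] - flow[:, :, :-1, :, :]) ** 2
    dy = (flow[:, :, :, 1:, :] - flow[:, :, :, :-1, :]) ** 2
    dz = (flow[:, :, :, :, 1:] - flow[:, :, :, :, :-1]) ** 2
    return (dx.mean() + dy.mean() + dz.mean()) / 3.0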

Step 2: Fine-tuning the DNN on a downstream registration task

Run python -u train_TransMorph_SSL.py to initiate the fine-tuning. We first load the encoder weights saved during pretraining.

import os
import torch
from natsort import natsorted

model = TransMorphTVF(config, SVF=True, composition='composition', swin_type='swin')
# Locate the latest pretraining checkpoint.
pretrained_dir = '/scratch/jchen/python_projects/Registration_SSL/experiments/TransMorphSSL_HWD_160_192_224_Scale_2_Wsize_64_PreTrain_RS_dice_1_diffusion_1_sKL_1e-07/'
pretrained = torch.load(pretrained_dir + natsorted(os.listdir(pretrained_dir))[-1])['state_dict']
# Rebuild the SSL head around the new model's encoder, restore the pretrained
# weights, then copy the encoder weights into the registration network.
sslencoder = SSLHeadNLvl(model.transformer, img_size=(H // 2, W // 2, D // 2),
                         channels=(config.embed_dim * 4, config.embed_dim * 2, config.embed_dim),
                         if_upsamp=True)
sslencoder.load_state_dict(pretrained)
model.transformer.load_state_dict(sslencoder.encoder.state_dict())
print('model: pretrained.pth.tar loaded!')
del sslencoder  # the temporary decoder is no longer needed
model.cuda()
We then train the model as usual.
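For illustration, the fine-tuning loop might look like the following minimal sketch. The model's output signature (a warped moving image plus a flow field) and the unweighted loss combination are assumptions; train_TransMorph_SSL.py is the authoritative reference.

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for x, y in train_loader:  # moving/fixed pairs from the downstream dataset
    x, y = x.cuda(), y.cuda()
    output = model((x, y))  # assumed to return (warped_moving, flow)
    warped, flow = output[0], output[1]
    loss = ncc_loss(warped, y) + diffusion_reg(flow)  # sketched losses from Step 1
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()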

Pretraining Strategy Overview

Generating Paired Random Images

Pretraining Reduces Amount of Data Needed to Achieve Competitive Performance
