keywords: Image Registration, Self-supervised Learning, Pretraining
This is a PyTorch implementation of my paper:
The core idea of this work is to leverage randomly generated images to initialize (i.e., pretrain) the encoder of an image registration network. To achieve this, we design temporary lightweight decoders that are attached to the encoder of the registration DNN and pretrain the resulting network with a standard image registration loss on the task of aligning pairs of random images.
This approach is implemented in the MIR package. The source code for generating random images can be found here, and the lightweight decoder is implemented here. The repository also includes training and inference scripts to reproduce the results reported in the paper.
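A minimal sketch of this setup, assuming hypothetical names (`LightweightDecoder`, `PretrainNet`, the channel sizes, and the single-feature-map encoder interface are all illustrative, not the actual MIR implementation):

```python
import torch
import torch.nn as nn

class LightweightDecoder(nn.Module):
    """Temporary decoder: maps encoder features to a dense 3-D displacement field."""
    def __init__(self, in_channels, out_channels=3):
        super().__init__()
        self.head = nn.Sequential(
            nn.Conv3d(in_channels, 16, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv3d(16, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, feat, out_shape):
        flow = self.head(feat)
        # Upsample the coarse flow to full image resolution.
        return nn.functional.interpolate(flow, size=out_shape,
                                         mode='trilinear', align_corners=False)

class PretrainNet(nn.Module):
    """Encoder taken from the registration DNN plus a throwaway lightweight decoder."""
    def __init__(self, encoder, feat_channels):
        super().__init__()
        self.encoder = encoder                            # reused later for registration
        self.decoder = LightweightDecoder(feat_channels)  # discarded after pretraining

    def forward(self, moving, fixed):
        x = torch.cat([moving, fixed], dim=1)  # channel-wise concat of the image pair
        feat = self.encoder(x)
        return self.decoder(feat, moving.shape[2:])
```

After pretraining, the decoder is discarded and only the encoder weights are carried over to the full registration network.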
Run `python -u train_SSL.py` to initiate the pretraining. We first extract the encoder from the registration DNN and connect it to a lightweight decoder for pretraining.
See `Pretraining_Image_Registration_DNNs/scripts/train_SSL.py`, lines 36 to 45 at commit `88a330b`.
The random training pairs are produced by `data = rs.gen_shapes(...)`, where `data[0]` and `data[1]` contain the moving and fixed random images, and their binary label maps are stored in `data[2]` and `data[3]`.
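For orientation, the outputs can be unpacked as follows (the import path is a guess and the generator's arguments are elided; see the linked source for the real call):

```python
import MIR.random_shapes as rs  # hypothetical import path for the random-image generator

data = rs.gen_shapes(...)            # arguments elided; see the actual call in train_SSL.py
x_mov, x_fix = data[0], data[1]      # moving and fixed random images
lbl_mov, lbl_fix = data[2], data[3]  # their binary label maps
```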
See `Pretraining_Image_Registration_DNNs/scripts/train_SSL.py`, lines 84 to 86 and lines 105 to 106 at commit `88a330b`.
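The embedded snippets above cover the training step. As a rough sketch, a standard registration objective of the kind described here (image similarity plus a Dice term on the warped label maps plus a diffusion regularizer on the flow) might look like the following; the weights and the use of MSE in place of NCC are illustrative assumptions, not the paper's exact loss:

```python
import torch
import torch.nn.functional as F

def registration_loss(warped, fixed, warped_lbl, fixed_lbl, flow,
                      w_sim=1.0, w_dice=1.0, w_reg=1.0):
    """Generic registration objective; weights are illustrative, not the paper's values."""
    # Image similarity (MSE stands in for NCC here).
    sim = F.mse_loss(warped, fixed)
    # Soft Dice on the warped binary label maps.
    inter = (warped_lbl * fixed_lbl).sum()
    dice = 1 - 2 * inter / (warped_lbl.sum() + fixed_lbl.sum() + 1e-5)
    # Diffusion regularizer: penalize spatial gradients of the displacement field
    # (flow has shape [B, 3, D, H, W]).
    reg = (flow[:, :, 1:] - flow[:, :, :-1]).abs().mean() \
        + (flow[:, :, :, 1:] - flow[:, :, :, :-1]).abs().mean() \
        + (flow[:, :, :, :, 1:] - flow[:, :, :, :, :-1]).abs().mean()
    return w_sim * sim + w_dice * dice + w_reg * reg
```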
Run `python -u train_TransMorph_SSL.py` to initiate the fine-tuning. We first load the pretrained encoder weights produced by the pretraining stage.
See `Pretraining_Image_Registration_DNNs/scripts/train_TransMorph_SSL.py`, lines 53 to 62 at commit `adae0b4`.
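A common pattern for this kind of partial weight loading, sketched with assumed names (the checkpoint path, the `'state_dict'` key, and the `'encoder.'` prefix are guesses, not necessarily what the script uses):

```python
import torch

model = TransMorph(config)  # full registration network; class and config come from the repo
ckpt = torch.load('pretrained_SSL.pth', map_location='cpu')  # hypothetical checkpoint path

# Keep only the encoder weights and strip the pretraining wrapper's prefix.
enc_state = {k.replace('encoder.', '', 1): v
             for k, v in ckpt['state_dict'].items() if k.startswith('encoder.')}
missing, unexpected = model.encoder.load_state_dict(enc_state, strict=False)
print(f'loaded encoder weights; missing={len(missing)}, unexpected={len(unexpected)}')
```

`strict=False` lets the load succeed even if the fine-tuning model carries extra parameters (e.g., its full decoder) that have no counterpart in the pretraining checkpoint.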