# Maps an ``encoding`` name to the URL of its pre-computed text-embedding weights.
# The key spelling is part of the runtime contract (callers pass it as ``encoding``), so it is kept as-is.
url_map = {
    "clip_encoding_univeral_model_32": (
        "https://github.com/Project-MONAI/MONAI-extra-test-data/"
        "releases/download/0.8.1/clip_encoding_univeral_model.pth"
    )
}


class TextEncoder(nn.Module):
    """
    Text to vision encoding by Contrastive Language-Image Pre-training (CLIP) or random embedding.
    The text to vision encoder loads the pre-trained or randomly initialized weights and exposes them
    in a shape that can be combined with 2D/3D vision models.

    Contrastive Language-Image Pre-training (CLIP), based on: "Radford et al.,
    Learning Transferable Visual Models From Natural Language Supervision"

    Connecting text and medical 3D image, based on: "Liu et al.,
    CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection"
    """

    def __init__(
        self,
        out_channels: int,
        spatial_dims: int = 3,
        text_dim: int = 512,
        hidden_size: int = 256,
        encoding: str = "clip_encoding_univeral_model_32",
        pretrained: bool = True,
    ) -> None:
        """
        Args:
            out_channels: number of output channels, to control text-based embedding for classes.
            spatial_dims: number of spatial dims, must be 2 or 3.
            text_dim: dimension of text embeddings.
            hidden_size: dimension of hidden features, compatible to different vision feature dimensions.
            encoding: the text embedding type, default to use clip text pretrained weights.
                Use ``"rand_embedding"`` for a learnable, randomly initialized ``nn.Embedding``.
            pretrained: whether to load pretrained weights from e.g., (CLIP) to initialize text embeddings,
                default to True. Ignored when ``encoding`` is ``"rand_embedding"``.

        Raises:
            ValueError: when ``spatial_dims`` is not 2 or 3.
            KeyError: when ``pretrained`` is True and ``encoding`` has no entry in ``url_map``.
        """
        super().__init__()
        self.encoding = encoding

        self.spatial_dims = spatial_dims
        if spatial_dims not in (2, 3):
            raise ValueError("spatial dimension should be 2 or 3.")

        if self.encoding == "rand_embedding":
            self.text_embedding = nn.Embedding(out_channels, hidden_size)
        else:
            # A buffer (not a parameter): saved in the state dict and moved with the module, but not trained.
            self.register_buffer("text_embedding", torch.randn(out_channels, text_dim))

            if pretrained:
                if self.encoding not in url_map:
                    # Same exception type as the plain dict lookup, but with an actionable message.
                    raise KeyError(
                        f"no pretrained weights are registered for encoding '{self.encoding}', "
                        f"available encodings: {sorted(url_map)}; set pretrained=False to load your own."
                    )
                model_url = url_map[self.encoding]
                pretrain_state_dict = model_zoo.load_url(model_url, map_location="cpu")
                # The downloaded object is used directly as the embedding tensor.
                self.text_embedding.data = pretrain_state_dict.float()  # type: ignore
            else:
                print(f"{self.encoding} is not implemented, and can not be downloaded, please load your own")

        # Created for every encoding so the module's state-dict keys stay the same;
        # it is only applied in ``forward`` for non-random encodings.
        self.text_to_vision = nn.Linear(text_dim, hidden_size)

    def forward(self) -> torch.Tensor:
        """
        Compute the text embedding, expanded with trailing singleton spatial dimensions.

        Returns:
            For ``"rand_embedding"``, the embedding weight of shape ``(out_channels, hidden_size)``;
            otherwise ``relu(text_to_vision(text_embedding))``. In both cases ``spatial_dims``
            singleton dimensions are inserted at index 2, e.g. ``(out_channels, hidden_size, 1, 1, 1)``
            for a 2D embedding tensor and ``spatial_dims=3``.
        """
        if self.encoding == "rand_embedding":
            # text embedding as random initialized 'rand_embedding'
            text_embedding = self.text_embedding.weight
        else:
            text_embedding = nn.functional.relu(self.text_to_vision(self.text_embedding))

        # One singleton dimension per spatial dimension (2 or 3, validated in __init__).
        for _ in range(self.spatial_dims):
            text_embedding = text_embedding.unsqueeze(2)

        return text_embedding
class TestTextEncoder(unittest.TestCase):
    """Output-shape checks for ``TextEncoder`` with pretrained and random text embeddings."""

    def test_test_encoding_shape(self):
        """Pretrained CLIP encoding yields ``(32, 256, 1, ...)`` for 2D and 3D."""
        expected_shapes = {2: (32, 256, 1, 1), 3: (32, 256, 1, 1, 1)}
        # Constructing with ``pretrained=True`` fetches weights from a URL.
        # NOTE(review): ``skip_if_downloading_fails`` presumably turns download errors into a skip - confirm.
        with skip_if_downloading_fails():
            for spatial_dims, expected in expected_shapes.items():
                text_encoder = TextEncoder(
                    spatial_dims=spatial_dims,
                    out_channels=32,
                    encoding="clip_encoding_univeral_model_32",
                    pretrained=True,
                )
                text_encoding = text_encoder()
                self.assertEqual(text_encoding.shape, expected)

    def test_rand_embedding_shape(self):
        """Random embedding needs no download, so it is checked independently of the pretrained case."""
        expected_shapes = {3: (32, 256, 1, 1, 1), 2: (32, 256, 1, 1)}
        for spatial_dims, expected in expected_shapes.items():
            with self.subTest(spatial_dims=spatial_dims):
                text_encoder = TextEncoder(
                    spatial_dims=spatial_dims, out_channels=32, encoding="rand_embedding", pretrained=True
                )
                text_encoding = text_encoder()
                self.assertEqual(text_encoding.shape, expected)

    def test_invalid_spatial_dims(self):
        """Only 2 and 3 spatial dimensions are accepted."""
        with self.assertRaises(ValueError):
            TextEncoder(spatial_dims=1, out_channels=32, encoding="rand_embedding", pretrained=False)


if __name__ == "__main__":
    unittest.main()