from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F

from kornia.core import Module, Tensor
from kornia.core.check import KORNIA_CHECK_SHAPE
from kornia.enhance.normalize import Normalize
from kornia.geometry.conversions import denormalize_pixel_coordinates
from kornia.utils.helpers import map_location_to_cpu

from .dedode_models import DeDoDeDescriptor, DeDoDeDetector, get_descriptor, get_detector
from .utils import sample_keypoints

urls: Dict[str, Dict[str, str]] = {
    "detector": {
        "L-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_detector_L.pth",
        "L-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/dedode_detector_C4.pth",
        "L-SO2": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/dedode_detector_SO2.pth",
    },
    "descriptor": {
        "B-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_B.pth",
        "B-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/B_C4_Perm_descriptor_setting_C.pth",
        "B-SO2": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/B_SO2_Spread_descriptor_setting_C.pth",
        "G-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_G.pth",
        "G-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/G_C4_Perm_descriptor_setting_C.pth",
    },
}


class DeDoDe(Module):
    r"""Module which detects and/or describes local features in an image using the DeDoDe method.

    See :cite:`edstedt2024dedode` for details.

    .. note:: DeDoDe expects ImageNet-normalized images (not in range [0, 1]). By default
        (``apply_imagenet_normalization=True``) the module applies this normalization itself.

    Example:
        >>> dedode = DeDoDe.from_pretrained(detector_weights="L-upright", descriptor_weights="B-upright")
        >>> images = torch.randn(1, 3, 256, 256)
        >>> keypoints, scores = dedode.detect(images)
        >>> descriptions = dedode.describe(images, keypoints=keypoints)
        >>> keypoints, scores, features = dedode(images)  # alternatively do both
    """

    # TODO: implement steerers and mnn matchers
    def __init__(
        self, detector_model: str = "L", descriptor_model: str = "G", amp_dtype: torch.dtype = torch.float16
    ) -> None:
        super().__init__()
        self.detector: DeDoDeDetector = get_detector(detector_model, amp_dtype)
        self.descriptor: DeDoDeDescriptor = get_descriptor(descriptor_model, amp_dtype)
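        # ImageNet channel statistics, applied when `apply_imagenet_normalization=True`.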
        self.normalizer = Normalize(torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))

    def forward(
        self,
        images: Tensor,
        n: Optional[int] = 10_000,
        apply_imagenet_normalization: bool = True,
        pad_if_not_divisible: bool = True,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Detects and describes keypoints in the input images.

        Args:
            images: A tensor of shape :math:`(B, 3, H, W)` containing the input images.
            n: The number of keypoints to detect.
            apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.
            pad_if_not_divisible: Whether to pad the images so that height and width are divisible by 14.

        Returns:
            keypoints: A tensor of shape :math:`(B, N, 2)` containing the detected keypoints in pixel
                coordinates, unlike :func:`detect`, which returns normalized coordinates.
            scores: A tensor of shape :math:`(B, N)` containing the scores of the detected keypoints.
            descriptions: A tensor of shape :math:`(B, N, DIM)` containing the descriptions of the
                detected keypoints, where DIM is 256 for descriptor B and 512 for descriptor G.
        """
        if apply_imagenet_normalization:
            images = self.normalizer(images)
        B, C, H, W = images.shape
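        # Pad on the bottom/right so H and W become multiples of 14, which the
        # backbone requires (presumably the 14-pixel ViT patch size).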
        if pad_if_not_divisible:
            h, w = images.shape[2:]
            pd_h = 14 - h % 14 if h % 14 > 0 else 0
            pd_w = 14 - w % 14 if w % 14 > 0 else 0
            images = torch.nn.functional.pad(images, (0, pd_w, 0, pd_h), value=0.0)
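        # Detect on the (possibly padded) image, cropping scores back to the original
        # H x W. `detect` returns coordinates normalized to [-1, 1]; they are mapped
        # back to pixel coordinates on return.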
        keypoints, scores = self.detect(images, n=n, apply_imagenet_normalization=False, crop_h=H, crop_w=W)
        descriptions = self.describe(images, keypoints, apply_imagenet_normalization=False)
        return denormalize_pixel_coordinates(keypoints, H, W), scores, descriptions

    @torch.inference_mode()
    def detect(
        self,
        images: Tensor,
        n: Optional[int] = 10_000,
        apply_imagenet_normalization: bool = True,
        pad_if_not_divisible: bool = True,
        crop_h: Optional[int] = None,
        crop_w: Optional[int] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Detects keypoints in the input images.

        Args:
            images: A tensor of shape :math:`(B, 3, H, W)` containing the input images.
            n: The number of keypoints to detect.
            apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.
            pad_if_not_divisible: Whether to pad the images so that height and width are divisible by 14.
            crop_h: The height of the crop to be used for detection. If None, the full image is used.
            crop_w: The width of the crop to be used for detection. If None, the full image is used.

        Returns:
            keypoints: A tensor of shape :math:`(B, N, 2)` containing the detected keypoints,
                normalized to the range :math:`[-1, 1]`.
            scores: A tensor of shape :math:`(B, N)` containing the scores of the detected keypoints.
        """
        KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"])
        self.train(False)
        if pad_if_not_divisible:
            h, w = images.shape[2:]
            pd_h = 14 - h % 14 if h % 14 > 0 else 0
            pd_w = 14 - w % 14 if w % 14 > 0 else 0
            images = torch.nn.functional.pad(images, (0, pd_w, 0, pd_h), value=0.0)
        if apply_imagenet_normalization:
            images = self.normalizer(images)
        B, C, H, W = images.shape
        logits = self.detector.forward(images)
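        # Discard logits over any padded border so keypoints are only sampled
        # within the original image area.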
        if crop_h is not None and crop_w is not None:
            logits = logits[..., :crop_h, :crop_w]
            H, W = crop_h, crop_w
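        # A softmax over all H * W locations turns the logits into one spatial
        # probability map per image, from which the keypoints are sampled.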
        scoremap = logits.reshape(B, H * W).softmax(dim=-1).reshape(B, H, W)
        keypoints, confidence = sample_keypoints(scoremap, num_samples=n)
        return keypoints, confidence

    @torch.inference_mode()
    def describe(
        self, images: Tensor, keypoints: Optional[Tensor] = None, apply_imagenet_normalization: bool = True
    ) -> Tensor:
        """Describes keypoints in the input images. If keypoints are not provided, returns the dense descriptors.

        Args:
            images: A tensor of shape :math:`(B, 3, H, W)` containing the input images.
            keypoints: An optional tensor of shape :math:`(B, N, 2)` containing the detected keypoints.
            apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.

        Returns:
            descriptions: A tensor of shape :math:`(B, N, DIM)` containing the descriptions of the detected
                keypoints. If dense descriptors are requested, the shape is :math:`(B, DIM, H, W)`.
        """
        KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"])
        B, C, H, W = images.shape
        if keypoints is not None:
            KORNIA_CHECK_SHAPE(keypoints, ["B", "N", "2"])
        if apply_imagenet_normalization:
            images = self.normalizer(images)
        self.train(False)
        descriptions = self.descriptor.forward(images)
        if keypoints is not None:
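            # Keypoints are normalized [-1, 1] coordinates (as returned by `detect`),
            # so they can be used directly as a `grid_sample` sampling grid.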
            described_keypoints = F.grid_sample(
                descriptions.float(), keypoints[:, None], mode="bilinear", align_corners=False
            )[:, :, 0].mT
            return described_keypoints
        return descriptions

    @classmethod
    def from_pretrained(
        cls,
        detector_weights: str = "L-upright",
        descriptor_weights: str = "G-upright",
        amp_dtype: torch.dtype = torch.float16,
    ) -> Module:
        r"""Loads a pretrained model.

        Args:
            detector_weights: The weights to load for the detector. One of 'L-upright', 'L-C4', 'L-SO2'.
            descriptor_weights: The weights to load for the descriptor.
                One of 'B-upright', 'B-C4', 'B-SO2', 'G-upright', 'G-C4'.
            amp_dtype: The dtype to use for the model. One of torch.float16 or torch.float32.
                Default is torch.float16, which is suitable for CUDA. Use torch.float32 for CPU or MPS.

        Returns:
            The pretrained model.
        """
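        # The first letter of the weight name ("L", "B" or "G") selects the model size.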
        model: DeDoDe = cls(
            detector_model=detector_weights[0], descriptor_model=descriptor_weights[0], amp_dtype=amp_dtype
        )
        model.detector.load_state_dict(
            torch.hub.load_state_dict_from_url(urls["detector"][detector_weights], map_location=map_location_to_cpu)
        )
        model.descriptor.load_state_dict(
            torch.hub.load_state_dict_from_url(urls["descriptor"][descriptor_weights], map_location=map_location_to_cpu)
        )
        model.eval()
        return model
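

# The TODO above mentions MNN matchers. A minimal sketch (not part of this module,
# and not an existing kornia API) of mutual-nearest-neighbor matching between two
# unbatched description sets, assuming `descriptions0` is (N0, DIM) and
# `descriptions1` is (N1, DIM):
def match_mnn(descriptions0: Tensor, descriptions1: Tensor) -> Tensor:
    """Return a :math:`(M, 2)` tensor of mutually-nearest-neighbor index pairs."""
    sim = descriptions0 @ descriptions1.mT  # (N0, N1) similarity matrix
    nn01 = sim.argmax(dim=1)  # best match in set 1 for each descriptor in set 0
    nn10 = sim.argmax(dim=0)  # best match in set 0 for each descriptor in set 1
    idx0 = torch.arange(descriptions0.shape[0], device=descriptions0.device)
    mutual = nn10[nn01] == idx0  # keep only pairs that agree in both directions
    return torch.stack([idx0[mutual], nn01[mutual]], dim=1)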