Commit b6ca41c

Add DeDoDe (clean version) (#2835)
* Add dedode

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a14319c · commit b6ca41c

27 files changed: +1991 −12 lines

docs/source/feature.rst (+3)

@@ -58,6 +58,9 @@ Local Features (Detector and Descriptors together)
 .. autoclass:: SOLD2_detector
     :members: forward
 
+.. autoclass:: DeDoDe
+    :members: forward, from_pretrained, describe, detect
+
 .. autoclass:: DISK
     :members: forward, from_pretrained, heatmap_and_dense_descriptors

docs/source/references.bib (+14 −6)

@@ -325,6 +325,13 @@ @article{tyszkiewicz2020disk
   year={2020}
 }
 
+@inproceedings{edstedt2024dedode,
+  title={{DeDoDe: Detect, Don't Describe --- Describe, Don't Detect for Local Feature Matching}},
+  author = {Johan Edstedt and Georg Bökman and Mårten Wadenbäck and Michael Felsberg},
+  booktitle={2024 International Conference on 3D Vision (3DV)},
+  year={2024}
+}
+
 @inproceedings{he2010guided,
   title = {Guided Image Filtering},
   booktitle = {Proceedings of the 11th European Conference on Computer Vision: Part I},

@@ -368,12 +375,6 @@ @inproceedings{barath2020magsac++
   year={2020}
 }
 
-@inproceedings{wei2023generalized,
-  author = {Wei, Tong and Patel, Yash and Shekhovtsov, Alexander and Matas, Jiri and Barath, Daniel},
-  title = {Generalized Differentiable RANSAC},
-  booktitle = {ICCV},
-  year = {2023}
-}
 
 @inproceedings{shin2017,
   title={JPEG-resistant Adversarial Images},

@@ -390,3 +391,10 @@ @inproceedings{reich2024
   booktitle={IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
   year={2024}
 }
+
+@inproceedings{wei2023generalized,
+  author = {Wei, Tong and Patel, Yash and Shekhovtsov, Alexander and Matas, Jiri and Barath, Daniel},
+  title = {Generalized Differentiable RANSAC},
+  booktitle = {ICCV},
+  year = {2023}
+}

kornia/feature/__init__.py (+2)

@@ -1,4 +1,5 @@
 from .affine_shape import LAFAffineShapeEstimator, LAFAffNetShapeEstimator, PatchAffineShapeEstimator
+from .dedode import DeDoDe
 from .defmo import DeFMO
 from .disk import DISK, DISKFeatures
 from .hardnet import HardNet, HardNet8

@@ -156,6 +157,7 @@
     "perspective_transform_lafs",
     "SOLD2_detector",
     "SOLD2",
+    "DeDoDe",
     "DISK",
     "DISKFeatures",
     "LightGlue",

kornia/feature/dedode/__init__.py (new file, +3)

from .dedode import DeDoDe

__all__ = ["DeDoDe"]

kornia/feature/dedode/decoder.py (new file, +99)

from typing import Any, Optional, Tuple

import torch
from torch import nn

from kornia.core import Tensor


class Decoder(nn.Module):
    def __init__(self, layers: Any, *args, super_resolution: bool = False, num_prototypes: int = 1, **kwargs) -> None:  # type: ignore[no-untyped-def]
        super().__init__(*args, **kwargs)
        self.layers = layers
        self.scales = self.layers.keys()
        self.super_resolution = super_resolution
        self.num_prototypes = num_prototypes

    def forward(
        self, features: Tensor, context: Optional[Tensor] = None, scale: Optional[int] = None
    ) -> Tuple[Tensor, Optional[Tensor]]:
        if context is not None:
            features = torch.cat((features, context), dim=1)
        # The per-scale layer predicts `num_prototypes` logit channels; the remaining
        # channels are returned as context.
        stuff = self.layers[scale](features)
        logits, context = stuff[:, : self.num_prototypes], stuff[:, self.num_prototypes :]
        return logits, context


class ConvRefiner(nn.Module):
    def __init__(  # type: ignore[no-untyped-def]
        self,
        in_dim=6,
        hidden_dim=16,
        out_dim=2,
        dw=True,
        kernel_size=5,
        hidden_blocks=5,
        amp=True,
        residual=False,
        amp_dtype=torch.float16,
    ):
        super().__init__()
        self.block1 = self.create_block(
            in_dim,
            hidden_dim,
            dw=False,
            kernel_size=1,
        )
        self.hidden_blocks = nn.Sequential(
            *[
                self.create_block(
                    hidden_dim,
                    hidden_dim,
                    dw=dw,
                    kernel_size=kernel_size,
                )
                for hb in range(hidden_blocks)
            ]
        )
        self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
        self.amp = amp
        self.amp_dtype = amp_dtype
        self.residual = residual

    def create_block(  # type: ignore[no-untyped-def]
        self,
        in_dim,
        out_dim,
        dw=True,
        kernel_size=5,
        bias=True,
        norm_type=nn.BatchNorm2d,
    ):
        # The depthwise variant groups the first convolution by input channels.
        num_groups = 1 if not dw else in_dim
        if dw:
            if out_dim % in_dim != 0:
                raise Exception("out_dim must be divisible by in_dim for depthwise convolution")
        conv1 = nn.Conv2d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            groups=num_groups,
            bias=bias,
        )
        norm = norm_type(out_dim) if norm_type is nn.BatchNorm2d else norm_type(num_channels=out_dim)
        relu = nn.ReLU(inplace=True)
        conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
        return nn.Sequential(conv1, norm, relu, conv2)

    def forward(self, feats: Tensor) -> Tensor:
        b, c, hs, ws = feats.shape
        with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
            x0 = self.block1(feats)
            x = self.hidden_blocks(x0)
            if self.residual:
                x = (x + x0) / 1.4
            x = self.out_conv(x)
        return x
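
A minimal, illustrative use of ConvRefiner as defined above: a 1x1 input projection, a stack of (optionally depthwise) conv + BatchNorm + ReLU + 1x1 blocks, and a 1x1 output head, run under autocast when amp is enabled. The channel sizes and the direct module path in the sketch are assumptions for illustration, not the configuration the DeDoDe models actually use.

import torch

from kornia.feature.dedode.decoder import ConvRefiner  # module path as added by this commit

# Hypothetical channel sizes, chosen only to show the shapes involved.
refiner = ConvRefiner(in_dim=32, hidden_dim=64, out_dim=2, dw=True, kernel_size=5, amp=False)
feats = torch.randn(1, 32, 60, 80)  # (B, C, H, W) feature map
out = refiner(feats)
print(out.shape)  # torch.Size([1, 2, 60, 80]); spatial size is preserved by the padded convolutions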

kornia/feature/dedode/dedode.py (new file, +194)

from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F

from kornia.core import Module, Tensor
from kornia.core.check import KORNIA_CHECK_SHAPE
from kornia.enhance.normalize import Normalize
from kornia.geometry.conversions import denormalize_pixel_coordinates
from kornia.utils.helpers import map_location_to_cpu

from .dedode_models import DeDoDeDescriptor, DeDoDeDetector, get_descriptor, get_detector
from .utils import sample_keypoints

urls: Dict[str, Dict[str, str]] = {
    "detector": {
        "L-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_detector_L.pth",
        "L-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/dedode_detector_C4.pth",
        "L-SO2": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/dedode_detector_SO2.pth",
    },
    "descriptor": {
        "B-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_B.pth",
        "B-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/B_C4_Perm_descriptor_setting_C.pth",
        "B-SO2": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/B_SO2_Spread_descriptor_setting_C.pth",
        "G-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_G.pth",
        "G-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/G_C4_Perm_descriptor_setting_C.pth",
    },
}


class DeDoDe(Module):
    r"""Module which detects and/or describes local features in an image using the DeDoDe method.

    See :cite:`edstedt2024dedode` for details.

    .. note:: DeDoDe takes ImageNet-normalized images as input (not in range [0, 1]).

    Example:
        >>> dedode = DeDoDe.from_pretrained(detector_weights="L-upright", descriptor_weights="B-upright")
        >>> images = torch.randn(1, 3, 256, 256)
        >>> keypoints, scores = dedode.detect(images)
        >>> descriptions = dedode.describe(images, keypoints=keypoints)
        >>> keypoints, scores, features = dedode(images)  # alternatively do both
    """

    # TODO: implement steerers and mnn matchers
    def __init__(
        self, detector_model: str = "L", descriptor_model: str = "G", amp_dtype: torch.dtype = torch.float16
    ) -> None:
        super().__init__()
        self.detector: DeDoDeDetector = get_detector(detector_model, amp_dtype)
        self.descriptor: DeDoDeDescriptor = get_descriptor(descriptor_model, amp_dtype)
        self.normalizer = Normalize(torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))

    def forward(
        self,
        images: Tensor,
        n: Optional[int] = 10_000,
        apply_imagenet_normalization: bool = True,
        pad_if_not_divisible: bool = True,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Detects and describes keypoints in the input images.

        Args:
            images: A tensor of shape :math:`(B, 3, H, W)` containing the input images.
            n: The number of keypoints to detect.
            apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.
            pad_if_not_divisible: Whether to pad the images so that height and width are divisible by 14.

        Returns:
            keypoints: A tensor of shape :math:`(B, N, 2)` containing the detected keypoints in pixel
                coordinates (unlike :func:`detect`, which returns normalized coordinates).
            scores: A tensor of shape :math:`(B, N)` containing the scores of the detected keypoints.
            descriptions: A tensor of shape :math:`(B, N, DIM)` containing the descriptions of the detected
                keypoints. DIM is 256 for the B descriptor and 512 for the G descriptor.
        """
        if apply_imagenet_normalization:
            images = self.normalizer(images)
        B, C, H, W = images.shape
        h, w = H, W
        if pad_if_not_divisible:
            # Pad on the bottom/right so that H and W are divisible by 14.
            pd_h = 14 - h % 14 if h % 14 > 0 else 0
            pd_w = 14 - w % 14 if w % 14 > 0 else 0
            images = torch.nn.functional.pad(images, (0, pd_w, 0, pd_h), value=0.0)
        keypoints, scores = self.detect(images, n=n, apply_imagenet_normalization=False, crop_h=h, crop_w=w)
        descriptions = self.describe(images, keypoints, apply_imagenet_normalization=False)
        return denormalize_pixel_coordinates(keypoints, H, W), scores, descriptions

    @torch.inference_mode()
    def detect(
        self,
        images: Tensor,
        n: Optional[int] = 10_000,
        apply_imagenet_normalization: bool = True,
        pad_if_not_divisible: bool = True,
        crop_h: Optional[int] = None,
        crop_w: Optional[int] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Detects keypoints in the input images.

        Args:
            images: A tensor of shape :math:`(B, 3, H, W)` containing the input images.
            n: The number of keypoints to detect.
            apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.
            pad_if_not_divisible: Whether to pad the images so that height and width are divisible by 14.
            crop_h: The height of the crop to be used for detection. If None, the full image is used.
            crop_w: The width of the crop to be used for detection. If None, the full image is used.

        Returns:
            keypoints: A tensor of shape :math:`(B, N, 2)` containing the detected keypoints,
                normalized to the range :math:`[-1, 1]`.
            scores: A tensor of shape :math:`(B, N)` containing the scores of the detected keypoints.
        """
        KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"])
        self.train(False)
        if pad_if_not_divisible:
            h, w = images.shape[2:]
            pd_h = 14 - h % 14 if h % 14 > 0 else 0
            pd_w = 14 - w % 14 if w % 14 > 0 else 0
            images = torch.nn.functional.pad(images, (0, pd_w, 0, pd_h), value=0.0)
        if apply_imagenet_normalization:
            images = self.normalizer(images)
        B, C, H, W = images.shape
        logits = self.detector.forward(images)
        if crop_h is not None and crop_w is not None:
            # Drop the padded border before turning the logits into a probability map.
            logits = logits[..., :crop_h, :crop_w]
            H, W = crop_h, crop_w
        scoremap = logits.reshape(B, H * W).softmax(dim=-1).reshape(B, H, W)
        keypoints, confidence = sample_keypoints(scoremap, num_samples=n)
        return keypoints, confidence

    @torch.inference_mode()
    def describe(
        self, images: Tensor, keypoints: Optional[Tensor] = None, apply_imagenet_normalization: bool = True
    ) -> Tensor:
        """Describes keypoints in the input images. If keypoints are not provided, returns the dense descriptors.

        Args:
            images: A tensor of shape :math:`(B, 3, H, W)` containing the input images.
            keypoints: An optional tensor of shape :math:`(B, N, 2)` containing the detected keypoints,
                normalized to the range :math:`[-1, 1]`.
            apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.

        Returns:
            descriptions: A tensor of shape :math:`(B, N, DIM)` containing the descriptions of the detected
                keypoints. If dense descriptors are requested, the shape is :math:`(B, DIM, H, W)`.
        """
        KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"])
        B, C, H, W = images.shape
        if keypoints is not None:
            KORNIA_CHECK_SHAPE(keypoints, ["B", "N", "2"])
        if apply_imagenet_normalization:
            images = self.normalizer(images)
        self.train(False)
        descriptions = self.descriptor.forward(images)
        if keypoints is not None:
            # Sample the dense descriptor map at the (normalized) keypoint locations.
            described_keypoints = F.grid_sample(
                descriptions.float(), keypoints[:, None], mode="bilinear", align_corners=False
            )[:, :, 0].mT
            return described_keypoints
        return descriptions

    @classmethod
    def from_pretrained(
        cls,
        detector_weights: str = "L-upright",
        descriptor_weights: str = "G-upright",
        amp_dtype: torch.dtype = torch.float16,
    ) -> Module:
        r"""Loads a pretrained model.

        Args:
            detector_weights: The weights to load for the detector. One of 'L-upright', 'L-C4', 'L-SO2'.
            descriptor_weights: The weights to load for the descriptor.
                One of 'B-upright', 'B-C4', 'B-SO2', 'G-upright', 'G-C4'.
            amp_dtype: The dtype to use for the model. One of torch.float16 or torch.float32.
                Default is torch.float16, which is suitable for CUDA. Use torch.float32 for CPU or MPS.

        Returns:
            The pretrained model.
        """
        model: DeDoDe = cls(
            detector_model=detector_weights[0], descriptor_model=descriptor_weights[0], amp_dtype=amp_dtype
        )
        model.detector.load_state_dict(
            torch.hub.load_state_dict_from_url(urls["detector"][detector_weights], map_location=map_location_to_cpu)
        )
        model.descriptor.load_state_dict(
            torch.hub.load_state_dict_from_url(urls["descriptor"][descriptor_weights], map_location=map_location_to_cpu)
        )
        model.eval()
        return model
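
A minimal end-to-end usage sketch of the new module, based on the docstrings and the urls table above. The weight identifiers and the float32 recommendation for CPU come from this file; the random input tensor and the choice of 2,000 keypoints are illustrative assumptions, and running the sketch downloads the pretrained checkpoints.

import torch

from kornia.feature import DeDoDe

device = torch.device("cpu")
dedode = DeDoDe.from_pretrained(
    detector_weights="L-upright",
    descriptor_weights="B-upright",
    amp_dtype=torch.float32,  # the docstring suggests float32 for CPU/MPS
).to(device)

images = torch.rand(1, 3, 256, 256, device=device)  # RGB in [0, 1]; ImageNet normalization is applied internally

# One call does both stages; keypoints come back in pixel coordinates.
keypoints, scores, descriptions = dedode(images, n=2_000)

# Two-stage use: detect() returns keypoints normalized to [-1, 1], which is the
# convention describe() expects when sampling descriptors with grid_sample.
kps_norm, confidences = dedode.detect(images, n=2_000)
descs = dedode.describe(images, keypoints=kps_norm)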
