Skip to content

Commit 59d6cc4

Browse files
Merge pull request #1459 from roboflow/feature/rfdetr-upgrade-in-inference-exp
Add changes to align `inference-exp` with reference RFDetr implementation
2 parents cc2b3d2 + 47281b0 commit 59d6cc4

File tree

7 files changed

+211
-106
lines changed

7 files changed

+211
-106
lines changed

inference_experimental/inference_exp/models/common/roboflow/post_processing.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,9 +306,13 @@ def align_instance_segmentation_results(
306306
masks = masks[
307307
:, mask_pad_top : mh - mask_pad_bottom, mask_pad_left : mw - mask_pad_right
308308
]
309-
masks = functional.resize(
310-
masks,
311-
[original_size.height, original_size.width],
312-
interpolation=functional.InterpolationMode.BILINEAR,
313-
).gt_(0.0).to(dtype=torch.bool)
309+
masks = (
310+
functional.resize(
311+
masks,
312+
[original_size.height, original_size.width],
313+
interpolation=functional.InterpolationMode.BILINEAR,
314+
)
315+
.gt_(0.0)
316+
.to(dtype=torch.bool)
317+
)
314318
return image_bboxes, masks

inference_experimental/inference_exp/models/rfdetr/backbone_builder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ def build_backbone(
6161
force_no_pretrain,
6262
gradient_checkpointing,
6363
load_dinov2_weights,
64+
patch_size,
65+
num_windows,
66+
positional_encoding_size,
6467
):
6568
"""
6669
Useful args:
@@ -88,6 +91,9 @@ def build_backbone(
8891
backbone_lora=backbone_lora,
8992
gradient_checkpointing=gradient_checkpointing,
9093
load_dinov2_weights=load_dinov2_weights,
94+
patch_size=patch_size,
95+
num_windows=num_windows,
96+
positional_encoding_size=positional_encoding_size,
9197
)
9298

9399
model = Joiner(backbone, position_embedding)

inference_experimental/inference_exp/models/rfdetr/rfdetr_backbone_pytorch.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ def __init__(
5555
use_windowed_attn=True,
5656
gradient_checkpointing=False,
5757
load_dinov2_weights=True,
58+
patch_size=14,
59+
num_windows=4,
60+
positional_encoding_size=37,
5861
):
5962
super().__init__()
6063

@@ -65,6 +68,8 @@ def __init__(
6568
)
6669

6770
self.shape = shape
71+
self.patch_size = patch_size
72+
self.num_windows = num_windows
6873

6974
# Create the encoder
7075

@@ -89,18 +94,33 @@ def __init__(
8994

9095
dino_config["return_dict"] = False
9196
dino_config["out_features"] = [f"stage{i}" for i in out_feature_indexes]
97+
implied_resolution = positional_encoding_size * patch_size
98+
99+
if implied_resolution != dino_config["image_size"]:
100+
print(
101+
f"Using a different number of positional encodings than DINOv2, which means we're not loading DINOv2 backbone weights. This is not a problem if finetuning a pretrained RF-DETR model."
102+
)
103+
dino_config["image_size"] = implied_resolution
104+
load_dinov2_weights = False
105+
106+
if patch_size != 14:
107+
print(
108+
f"Using patch size {patch_size} instead of 14, which means we're not loading DINOv2 backbone weights. This is not a problem if finetuning a pretrained RF-DETR model."
109+
)
110+
dino_config["patch_size"] = patch_size
111+
load_dinov2_weights = False
92112

93113
if use_registers:
94114
windowed_dino_config = WindowedDinov2WithRegistersConfig(
95115
**dino_config,
96-
num_windows=4,
116+
num_windows=num_windows,
97117
window_block_indexes=window_block_indexes,
98118
gradient_checkpointing=gradient_checkpointing,
99119
)
100120
else:
101121
windowed_dino_config = WindowedDinov2WithRegistersConfig(
102122
**dino_config,
103-
num_windows=4,
123+
num_windows=num_windows,
104124
window_block_indexes=window_block_indexes,
105125
num_register_tokens=0,
106126
gradient_checkpointing=gradient_checkpointing,
@@ -179,9 +199,10 @@ def new_interpolate_pos_encoding(self_mod, embeddings, height, width):
179199
)
180200

181201
def forward(self, x):
202+
block_size = self.patch_size * self.num_windows
182203
assert (
183-
x.shape[2] % 14 == 0 and x.shape[3] % 14 == 0
184-
), f"Dinov2 requires input shape to be divisible by 14, but got {x.shape}"
204+
x.shape[2] % block_size == 0 and x.shape[3] % block_size == 0
205+
), f"Backbone requires input shape to be divisible by {block_size}, but got {x.shape}"
185206
x = self.encoder(x)
186207
return list(x[0])
187208

@@ -214,6 +235,9 @@ def __init__(
214235
backbone_lora: bool = False,
215236
gradient_checkpointing: bool = False,
216237
load_dinov2_weights: bool = True,
238+
patch_size: int = 14,
239+
num_windows: int = 4,
240+
positional_encoding_size: bool = False,
217241
):
218242
super().__init__()
219243
# an example name here would be "dinov2_base" or "dinov2_registers_windowed_base"
@@ -243,6 +267,9 @@ def __init__(
243267
use_windowed_attn=use_windowed_attn,
244268
gradient_checkpointing=gradient_checkpointing,
245269
load_dinov2_weights=load_dinov2_weights,
270+
patch_size=patch_size,
271+
num_windows=num_windows,
272+
positional_encoding_size=positional_encoding_size,
246273
)
247274
# build encoder + projector as backbone module
248275
if freeze_encoder:

inference_experimental/inference_exp/models/rfdetr/rfdetr_base_pytorch.py

Lines changed: 113 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import argparse
22
import copy
33
import math
4-
from dataclasses import asdict, dataclass, field
54
from typing import Callable, List, Literal, Optional, Union
65

76
import torch
@@ -10,16 +9,19 @@
109
from inference_exp.models.rfdetr.backbone_builder import build_backbone
1110
from inference_exp.models.rfdetr.misc import NestedTensor
1211
from inference_exp.models.rfdetr.transformer import build_transformer
12+
from pydantic import BaseModel, ConfigDict
1313
from torch import Tensor, nn
1414

1515

16-
@dataclass
17-
class ModelConfig:
18-
device: Union[Literal["cpu", "cuda", "mps"], torch.device]
16+
class ModelConfig(BaseModel):
1917
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"]
2018
out_feature_indexes: List[int]
19+
dec_layers: int
20+
two_stage: bool = True
2121
projector_scale: List[Literal["P3", "P4", "P5"]]
2222
hidden_dim: int
23+
patch_size: int
24+
num_windows: int
2325
sa_nheads: int
2426
ca_nheads: int
2527
dec_n_points: int
@@ -29,44 +31,75 @@ class ModelConfig:
2931
amp: bool = True
3032
num_classes: int = 90
3133
pretrain_weights: Optional[str] = None
32-
dec_layers: int = 3
33-
two_stage: bool = True
34-
resolution: int = 560
34+
device: torch.device
35+
resolution: int
3536
group_detr: int = 13
3637
gradient_checkpointing: bool = False
38+
positional_encoding_size: int
39+
40+
model_config = ConfigDict(arbitrary_types_allowed=True)
3741

3842

39-
@dataclass
4043
class RFDETRBaseConfig(ModelConfig):
41-
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = field(
42-
default="dinov2_windowed_small"
43-
)
44-
hidden_dim: int = field(default=256)
45-
sa_nheads: int = field(default=8)
46-
ca_nheads: int = field(default=16)
47-
dec_n_points: int = field(default=2)
48-
num_queries: int = field(default=300)
49-
num_select: int = field(default=300)
50-
projector_scale: List[Literal["P3", "P4", "P5"]] = field(
51-
default_factory=lambda: ["P4"]
44+
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = (
45+
"dinov2_windowed_small"
5246
)
53-
out_feature_indexes: List[int] = field(default_factory=lambda: [2, 5, 8, 11])
54-
pretrain_weights: Optional[str] = field(default="rf-detr-base.pth")
47+
hidden_dim: int = 256
48+
patch_size: int = 14
49+
num_windows: int = 4
50+
dec_layers: int = 3
51+
sa_nheads: int = 8
52+
ca_nheads: int = 16
53+
dec_n_points: int = 2
54+
num_queries: int = 300
55+
num_select: int = 300
56+
projector_scale: List[Literal["P3", "P4", "P5"]] = ["P4"]
57+
out_feature_indexes: List[int] = [2, 5, 8, 11]
58+
pretrain_weights: Optional[str] = "rf-detr-base.pth"
59+
resolution: int = 560
60+
positional_encoding_size: int = 37
5561

5662

57-
@dataclass
5863
class RFDETRLargeConfig(RFDETRBaseConfig):
59-
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = field(
60-
default="dinov2_windowed_base"
61-
)
62-
hidden_dim: int = field(default=384)
63-
sa_nheads: int = field(default=12)
64-
ca_nheads: int = field(default=24)
65-
dec_n_points: int = field(default=4)
66-
projector_scale: List[Literal["P3", "P4", "P5"]] = field(
67-
default_factory=lambda: ["P3", "P5"]
64+
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = (
65+
"dinov2_windowed_base"
6866
)
69-
pretrain_weights: Optional[str] = field(default="rf-detr-large.pth")
67+
hidden_dim: int = 384
68+
sa_nheads: int = 12
69+
ca_nheads: int = 24
70+
dec_n_points: int = 4
71+
projector_scale: List[Literal["P3", "P4", "P5"]] = ["P3", "P5"]
72+
pretrain_weights: Optional[str] = "rf-detr-large.pth"
73+
74+
75+
class RFDETRNanoConfig(RFDETRBaseConfig):
76+
out_feature_indexes: List[int] = [3, 6, 9, 12]
77+
num_windows: int = 2
78+
dec_layers: int = 2
79+
patch_size: int = 16
80+
resolution: int = 384
81+
positional_encoding_size: int = 24
82+
pretrain_weights: Optional[str] = "rf-detr-nano.pth"
83+
84+
85+
class RFDETRSmallConfig(RFDETRBaseConfig):
86+
out_feature_indexes: List[int] = [3, 6, 9, 12]
87+
num_windows: int = 2
88+
dec_layers: int = 3
89+
patch_size: int = 16
90+
resolution: int = 512
91+
positional_encoding_size: int = 32
92+
pretrain_weights: Optional[str] = "rf-detr-small.pth"
93+
94+
95+
class RFDETRMediumConfig(RFDETRBaseConfig):
96+
out_feature_indexes: List[int] = [3, 6, 9, 12]
97+
num_windows: int = 2
98+
dec_layers: int = 4
99+
patch_size: int = 16
100+
resolution: int = 576
101+
positional_encoding_size: int = 36
102+
pretrain_weights: Optional[str] = "rf-detr-medium.pth"
70103

71104

72105
class LWDETR(nn.Module):
@@ -212,24 +245,27 @@ def forward(self, samples: NestedTensor, targets=None):
212245
srcs, masks, poss, refpoint_embed_weight, query_feat_weight
213246
)
214247

215-
if self.bbox_reparam:
216-
outputs_coord_delta = self.bbox_embed(hs)
217-
outputs_coord_cxcy = (
218-
outputs_coord_delta[..., :2] * ref_unsigmoid[..., 2:]
219-
+ ref_unsigmoid[..., :2]
220-
)
221-
outputs_coord_wh = (
222-
outputs_coord_delta[..., 2:].exp() * ref_unsigmoid[..., 2:]
223-
)
224-
outputs_coord = torch.concat([outputs_coord_cxcy, outputs_coord_wh], dim=-1)
225-
else:
226-
outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid()
248+
if hs is not None:
249+
if self.bbox_reparam:
250+
outputs_coord_delta = self.bbox_embed(hs)
251+
outputs_coord_cxcy = (
252+
outputs_coord_delta[..., :2] * ref_unsigmoid[..., 2:]
253+
+ ref_unsigmoid[..., :2]
254+
)
255+
outputs_coord_wh = (
256+
outputs_coord_delta[..., 2:].exp() * ref_unsigmoid[..., 2:]
257+
)
258+
outputs_coord = torch.concat(
259+
[outputs_coord_cxcy, outputs_coord_wh], dim=-1
260+
)
261+
else:
262+
outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid()
227263

228-
outputs_class = self.class_embed(hs)
264+
outputs_class = self.class_embed(hs)
229265

230-
out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
231-
if self.aux_loss:
232-
out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord)
266+
out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
267+
if self.aux_loss:
268+
out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord)
233269

234270
if self.two_stage:
235271
group_detr = self.group_detr if self.training else 1
@@ -241,7 +277,11 @@ def forward(self, samples: NestedTensor, targets=None):
241277
)
242278
cls_enc.append(cls_enc_gidx)
243279
cls_enc = torch.cat(cls_enc, dim=1)
244-
out["enc_outputs"] = {"pred_logits": cls_enc, "pred_boxes": ref_enc}
280+
if hs is not None:
281+
out["enc_outputs"] = {"pred_logits": cls_enc, "pred_boxes": ref_enc}
282+
else:
283+
out = {"pred_logits": cls_enc, "pred_boxes": ref_enc}
284+
245285
return out
246286

247287
def forward_export(self, tensors):
@@ -254,19 +294,27 @@ def forward_export(self, tensors):
254294
srcs, None, poss, refpoint_embed_weight, query_feat_weight
255295
)
256296

257-
if self.bbox_reparam:
258-
outputs_coord_delta = self.bbox_embed(hs)
259-
outputs_coord_cxcy = (
260-
outputs_coord_delta[..., :2] * ref_unsigmoid[..., 2:]
261-
+ ref_unsigmoid[..., :2]
262-
)
263-
outputs_coord_wh = (
264-
outputs_coord_delta[..., 2:].exp() * ref_unsigmoid[..., 2:]
265-
)
266-
outputs_coord = torch.concat([outputs_coord_cxcy, outputs_coord_wh], dim=-1)
297+
if hs is not None:
298+
if self.bbox_reparam:
299+
outputs_coord_delta = self.bbox_embed(hs)
300+
outputs_coord_cxcy = (
301+
outputs_coord_delta[..., :2] * ref_unsigmoid[..., 2:]
302+
+ ref_unsigmoid[..., :2]
303+
)
304+
outputs_coord_wh = (
305+
outputs_coord_delta[..., 2:].exp() * ref_unsigmoid[..., 2:]
306+
)
307+
outputs_coord = torch.concat(
308+
[outputs_coord_cxcy, outputs_coord_wh], dim=-1
309+
)
310+
else:
311+
outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid()
312+
outputs_class = self.class_embed(hs)
267313
else:
268-
outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid()
269-
outputs_class = self.class_embed(hs)
314+
assert self.two_stage, "if not using decoder, two_stage must be True"
315+
outputs_class = self.transformer.enc_out_class_embed[0](hs_enc)
316+
outputs_coord = ref_enc
317+
270318
return outputs_coord, outputs_class
271319

272320
@torch.jit.unused
@@ -399,7 +447,7 @@ def build_model(config: ModelConfig) -> LWDETR:
399447
# you should pass `num_classes` to be 2 (max_obj_id + 1).
400448
# For more details on this, check the following discussion
401449
# https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
402-
args = populate_args(**asdict(config))
450+
args = populate_args(**config.dict())
403451
num_classes = args.num_classes + 1
404452
backbone = build_backbone(
405453
encoder=args.encoder,
@@ -429,6 +477,9 @@ def build_model(config: ModelConfig) -> LWDETR:
429477
force_no_pretrain=args.force_no_pretrain,
430478
gradient_checkpointing=args.gradient_checkpointing,
431479
load_dinov2_weights=args.pretrain_weights is None,
480+
patch_size=config.patch_size,
481+
num_windows=config.num_windows,
482+
positional_encoding_size=config.positional_encoding_size,
432483
)
433484
if args.encoder_only:
434485
return backbone[0].encoder, None, None

inference_experimental/inference_exp/models/rfdetr/rfdetr_object_detection_pytorch.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
LWDETR,
2525
RFDETRBaseConfig,
2626
RFDETRLargeConfig,
27+
RFDETRMediumConfig,
28+
RFDETRNanoConfig,
29+
RFDETRSmallConfig,
2730
build_model,
2831
)
2932

@@ -33,6 +36,9 @@
3336
pass
3437

3538
CONFIG_FOR_MODEL_TYPE = {
39+
"rfdetr-nano": RFDETRNanoConfig,
40+
"rfdetr-small": RFDETRSmallConfig,
41+
"rfdetr-medium": RFDETRMediumConfig,
3642
"rfdetr-base": RFDETRBaseConfig,
3743
"rfdetr-large": RFDETRLargeConfig,
3844
}

0 commit comments

Comments (0)