Skip to content

Add changes to align inference-exp with refrence RFDetr implementation #1459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,13 @@ def align_instance_segmentation_results(
masks = masks[
:, mask_pad_top : mh - mask_pad_bottom, mask_pad_left : mw - mask_pad_right
]
masks = functional.resize(
masks,
[original_size.height, original_size.width],
interpolation=functional.InterpolationMode.BILINEAR,
).gt_(0.0).to(dtype=torch.bool)
masks = (
functional.resize(
masks,
[original_size.height, original_size.width],
interpolation=functional.InterpolationMode.BILINEAR,
)
.gt_(0.0)
.to(dtype=torch.bool)
)
return image_bboxes, masks
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def build_backbone(
force_no_pretrain,
gradient_checkpointing,
load_dinov2_weights,
patch_size,
num_windows,
positional_encoding_size,
):
"""
Useful args:
Expand Down Expand Up @@ -88,6 +91,9 @@ def build_backbone(
backbone_lora=backbone_lora,
gradient_checkpointing=gradient_checkpointing,
load_dinov2_weights=load_dinov2_weights,
patch_size=patch_size,
num_windows=num_windows,
positional_encoding_size=positional_encoding_size,
)

model = Joiner(backbone, position_embedding)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def __init__(
use_windowed_attn=True,
gradient_checkpointing=False,
load_dinov2_weights=True,
patch_size=14,
num_windows=4,
positional_encoding_size=37,
):
super().__init__()

Expand All @@ -65,6 +68,8 @@ def __init__(
)

self.shape = shape
self.patch_size = patch_size
self.num_windows = num_windows

# Create the encoder

Expand All @@ -89,18 +94,33 @@ def __init__(

dino_config["return_dict"] = False
dino_config["out_features"] = [f"stage{i}" for i in out_feature_indexes]
implied_resolution = positional_encoding_size * patch_size

if implied_resolution != dino_config["image_size"]:
print(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should think about how we want to do logging long term. Thinking both for general dev-ex and potentially structured logging or log files for analytics / debugging purposes.

f"Using a different number of positional encodings than DINOv2, which means we're not loading DINOv2 backbone weights. This is not a problem if finetuning a pretrained RF-DETR model."
)
dino_config["image_size"] = implied_resolution
load_dinov2_weights = False

if patch_size != 14:
print(
f"Using patch size {patch_size} instead of 14, which means we're not loading DINOv2 backbone weights. This is not a problem if finetuning a pretrained RF-DETR model."
)
dino_config["patch_size"] = patch_size
load_dinov2_weights = False

if use_registers:
windowed_dino_config = WindowedDinov2WithRegistersConfig(
**dino_config,
num_windows=4,
num_windows=num_windows,
window_block_indexes=window_block_indexes,
gradient_checkpointing=gradient_checkpointing,
)
else:
windowed_dino_config = WindowedDinov2WithRegistersConfig(
**dino_config,
num_windows=4,
num_windows=num_windows,
window_block_indexes=window_block_indexes,
num_register_tokens=0,
gradient_checkpointing=gradient_checkpointing,
Expand Down Expand Up @@ -179,9 +199,10 @@ def new_interpolate_pos_encoding(self_mod, embeddings, height, width):
)

def forward(self, x):
block_size = self.patch_size * self.num_windows
assert (
x.shape[2] % 14 == 0 and x.shape[3] % 14 == 0
), f"Dinov2 requires input shape to be divisible by 14, but got {x.shape}"
x.shape[2] % block_size == 0 and x.shape[3] % block_size == 0
), f"Backbone requires input shape to be divisible by {block_size}, but got {x.shape}"
x = self.encoder(x)
return list(x[0])

Expand Down Expand Up @@ -214,6 +235,9 @@ def __init__(
backbone_lora: bool = False,
gradient_checkpointing: bool = False,
load_dinov2_weights: bool = True,
patch_size: int = 14,
num_windows: int = 4,
positional_encoding_size: bool = False,
):
super().__init__()
# an example name here would be "dinov2_base" or "dinov2_registers_windowed_base"
Expand Down Expand Up @@ -243,6 +267,9 @@ def __init__(
use_windowed_attn=use_windowed_attn,
gradient_checkpointing=gradient_checkpointing,
load_dinov2_weights=load_dinov2_weights,
patch_size=patch_size,
num_windows=num_windows,
positional_encoding_size=positional_encoding_size,
)
# build encoder + projector as backbone module
if freeze_encoder:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import argparse
import copy
import math
from dataclasses import asdict, dataclass, field
from typing import Callable, List, Literal, Optional, Union

import torch
Expand All @@ -10,16 +9,19 @@
from inference_exp.models.rfdetr.backbone_builder import build_backbone
from inference_exp.models.rfdetr.misc import NestedTensor
from inference_exp.models.rfdetr.transformer import build_transformer
from pydantic import BaseModel, ConfigDict
from torch import Tensor, nn


@dataclass
class ModelConfig:
device: Union[Literal["cpu", "cuda", "mps"], torch.device]
class ModelConfig(BaseModel):
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"]
out_feature_indexes: List[int]
dec_layers: int
two_stage: bool = True
projector_scale: List[Literal["P3", "P4", "P5"]]
hidden_dim: int
patch_size: int
num_windows: int
sa_nheads: int
ca_nheads: int
dec_n_points: int
Expand All @@ -29,44 +31,75 @@ class ModelConfig:
amp: bool = True
num_classes: int = 90
pretrain_weights: Optional[str] = None
dec_layers: int = 3
two_stage: bool = True
resolution: int = 560
device: torch.device
resolution: int
group_detr: int = 13
gradient_checkpointing: bool = False
positional_encoding_size: int

model_config = ConfigDict(arbitrary_types_allowed=True)


@dataclass
class RFDETRBaseConfig(ModelConfig):
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = field(
default="dinov2_windowed_small"
)
hidden_dim: int = field(default=256)
sa_nheads: int = field(default=8)
ca_nheads: int = field(default=16)
dec_n_points: int = field(default=2)
num_queries: int = field(default=300)
num_select: int = field(default=300)
projector_scale: List[Literal["P3", "P4", "P5"]] = field(
default_factory=lambda: ["P4"]
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = (
"dinov2_windowed_small"
)
out_feature_indexes: List[int] = field(default_factory=lambda: [2, 5, 8, 11])
pretrain_weights: Optional[str] = field(default="rf-detr-base.pth")
hidden_dim: int = 256
patch_size: int = 14
num_windows: int = 4
dec_layers: int = 3
sa_nheads: int = 8
ca_nheads: int = 16
dec_n_points: int = 2
num_queries: int = 300
num_select: int = 300
projector_scale: List[Literal["P3", "P4", "P5"]] = ["P4"]
out_feature_indexes: List[int] = [2, 5, 8, 11]
pretrain_weights: Optional[str] = "rf-detr-base.pth"
resolution: int = 560
positional_encoding_size: int = 37


@dataclass
class RFDETRLargeConfig(RFDETRBaseConfig):
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = field(
default="dinov2_windowed_base"
)
hidden_dim: int = field(default=384)
sa_nheads: int = field(default=12)
ca_nheads: int = field(default=24)
dec_n_points: int = field(default=4)
projector_scale: List[Literal["P3", "P4", "P5"]] = field(
default_factory=lambda: ["P3", "P5"]
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = (
"dinov2_windowed_base"
)
pretrain_weights: Optional[str] = field(default="rf-detr-large.pth")
hidden_dim: int = 384
sa_nheads: int = 12
ca_nheads: int = 24
dec_n_points: int = 4
projector_scale: List[Literal["P3", "P4", "P5"]] = ["P3", "P5"]
pretrain_weights: Optional[str] = "rf-detr-large.pth"


class RFDETRNanoConfig(RFDETRBaseConfig):
out_feature_indexes: List[int] = [3, 6, 9, 12]
num_windows: int = 2
dec_layers: int = 2
patch_size: int = 16
resolution: int = 384
positional_encoding_size: int = 24
pretrain_weights: Optional[str] = "rf-detr-nano.pth"


class RFDETRSmallConfig(RFDETRBaseConfig):
out_feature_indexes: List[int] = [3, 6, 9, 12]
num_windows: int = 2
dec_layers: int = 3
patch_size: int = 16
resolution: int = 512
positional_encoding_size: int = 32
pretrain_weights: Optional[str] = "rf-detr-small.pth"


class RFDETRMediumConfig(RFDETRBaseConfig):
out_feature_indexes: List[int] = [3, 6, 9, 12]
num_windows: int = 2
dec_layers: int = 4
patch_size: int = 16
resolution: int = 576
positional_encoding_size: int = 36
pretrain_weights: Optional[str] = "rf-detr-medium.pth"


class LWDETR(nn.Module):
Expand Down Expand Up @@ -212,24 +245,27 @@ def forward(self, samples: NestedTensor, targets=None):
srcs, masks, poss, refpoint_embed_weight, query_feat_weight
)

if self.bbox_reparam:
outputs_coord_delta = self.bbox_embed(hs)
outputs_coord_cxcy = (
outputs_coord_delta[..., :2] * ref_unsigmoid[..., 2:]
+ ref_unsigmoid[..., :2]
)
outputs_coord_wh = (
outputs_coord_delta[..., 2:].exp() * ref_unsigmoid[..., 2:]
)
outputs_coord = torch.concat([outputs_coord_cxcy, outputs_coord_wh], dim=-1)
else:
outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid()
if hs is not None:
if self.bbox_reparam:
outputs_coord_delta = self.bbox_embed(hs)
outputs_coord_cxcy = (
outputs_coord_delta[..., :2] * ref_unsigmoid[..., 2:]
+ ref_unsigmoid[..., :2]
)
outputs_coord_wh = (
outputs_coord_delta[..., 2:].exp() * ref_unsigmoid[..., 2:]
)
outputs_coord = torch.concat(
[outputs_coord_cxcy, outputs_coord_wh], dim=-1
)
else:
outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid()

outputs_class = self.class_embed(hs)
outputs_class = self.class_embed(hs)

out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
if self.aux_loss:
out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord)
out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
if self.aux_loss:
out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord)

if self.two_stage:
group_detr = self.group_detr if self.training else 1
Expand All @@ -241,7 +277,11 @@ def forward(self, samples: NestedTensor, targets=None):
)
cls_enc.append(cls_enc_gidx)
cls_enc = torch.cat(cls_enc, dim=1)
out["enc_outputs"] = {"pred_logits": cls_enc, "pred_boxes": ref_enc}
if hs is not None:
out["enc_outputs"] = {"pred_logits": cls_enc, "pred_boxes": ref_enc}
else:
out = {"pred_logits": cls_enc, "pred_boxes": ref_enc}

return out

def forward_export(self, tensors):
Expand All @@ -254,19 +294,27 @@ def forward_export(self, tensors):
srcs, None, poss, refpoint_embed_weight, query_feat_weight
)

if self.bbox_reparam:
outputs_coord_delta = self.bbox_embed(hs)
outputs_coord_cxcy = (
outputs_coord_delta[..., :2] * ref_unsigmoid[..., 2:]
+ ref_unsigmoid[..., :2]
)
outputs_coord_wh = (
outputs_coord_delta[..., 2:].exp() * ref_unsigmoid[..., 2:]
)
outputs_coord = torch.concat([outputs_coord_cxcy, outputs_coord_wh], dim=-1)
if hs is not None:
if self.bbox_reparam:
outputs_coord_delta = self.bbox_embed(hs)
outputs_coord_cxcy = (
outputs_coord_delta[..., :2] * ref_unsigmoid[..., 2:]
+ ref_unsigmoid[..., :2]
)
outputs_coord_wh = (
outputs_coord_delta[..., 2:].exp() * ref_unsigmoid[..., 2:]
)
outputs_coord = torch.concat(
[outputs_coord_cxcy, outputs_coord_wh], dim=-1
)
else:
outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid()
outputs_class = self.class_embed(hs)
else:
outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid()
outputs_class = self.class_embed(hs)
assert self.two_stage, "if not using decoder, two_stage must be True"
outputs_class = self.transformer.enc_out_class_embed[0](hs_enc)
outputs_coord = ref_enc

return outputs_coord, outputs_class

@torch.jit.unused
Expand Down Expand Up @@ -399,7 +447,7 @@ def build_model(config: ModelConfig) -> LWDETR:
# you should pass `num_classes` to be 2 (max_obj_id + 1).
# For more details on this, check the following discussion
# https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
args = populate_args(**asdict(config))
args = populate_args(**config.dict())
num_classes = args.num_classes + 1
backbone = build_backbone(
encoder=args.encoder,
Expand Down Expand Up @@ -429,6 +477,9 @@ def build_model(config: ModelConfig) -> LWDETR:
force_no_pretrain=args.force_no_pretrain,
gradient_checkpointing=args.gradient_checkpointing,
load_dinov2_weights=args.pretrain_weights is None,
patch_size=config.patch_size,
num_windows=config.num_windows,
positional_encoding_size=config.positional_encoding_size,
)
if args.encoder_only:
return backbone[0].encoder, None, None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
LWDETR,
RFDETRBaseConfig,
RFDETRLargeConfig,
RFDETRMediumConfig,
RFDETRNanoConfig,
RFDETRSmallConfig,
build_model,
)

Expand All @@ -33,6 +36,9 @@
pass

CONFIG_FOR_MODEL_TYPE = {
"rfdetr-nano": RFDETRNanoConfig,
"rfdetr-small": RFDETRSmallConfig,
"rfdetr-medium": RFDETRMediumConfig,
"rfdetr-base": RFDETRBaseConfig,
"rfdetr-large": RFDETRLargeConfig,
}
Expand Down
Loading