 import argparse
 import copy
 import math
-from dataclasses import asdict, dataclass, field
 from typing import Callable, List, Literal, Optional, Union

 import torch
 from inference_exp.models.rfdetr.backbone_builder import build_backbone
 from inference_exp.models.rfdetr.misc import NestedTensor
 from inference_exp.models.rfdetr.transformer import build_transformer
+from pydantic import BaseModel, ConfigDict
 from torch import Tensor, nn


-@dataclass
-class ModelConfig:
-    device: Union[Literal["cpu", "cuda", "mps"], torch.device]
+class ModelConfig(BaseModel):
     encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"]
     out_feature_indexes: List[int]
+    dec_layers: int
+    two_stage: bool = True
     projector_scale: List[Literal["P3", "P4", "P5"]]
     hidden_dim: int
+    patch_size: int
+    num_windows: int
     sa_nheads: int
     ca_nheads: int
     dec_n_points: int
@@ -29,44 +31,75 @@ class ModelConfig:
     amp: bool = True
     num_classes: int = 90
     pretrain_weights: Optional[str] = None
-    dec_layers: int = 3
-    two_stage: bool = True
-    resolution: int = 560
+    device: torch.device
+    resolution: int
     group_detr: int = 13
     gradient_checkpointing: bool = False
+    positional_encoding_size: int
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)


-@dataclass
 class RFDETRBaseConfig(ModelConfig):
-    encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = field(
-        default="dinov2_windowed_small"
-    )
-    hidden_dim: int = field(default=256)
-    sa_nheads: int = field(default=8)
-    ca_nheads: int = field(default=16)
-    dec_n_points: int = field(default=2)
-    num_queries: int = field(default=300)
-    num_select: int = field(default=300)
-    projector_scale: List[Literal["P3", "P4", "P5"]] = field(
-        default_factory=lambda: ["P4"]
+    encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = (
+        "dinov2_windowed_small"
     )
-    out_feature_indexes: List[int] = field(default_factory=lambda: [2, 5, 8, 11])
-    pretrain_weights: Optional[str] = field(default="rf-detr-base.pth")
+    hidden_dim: int = 256
+    patch_size: int = 14
+    num_windows: int = 4
+    dec_layers: int = 3
+    sa_nheads: int = 8
+    ca_nheads: int = 16
+    dec_n_points: int = 2
+    num_queries: int = 300
+    num_select: int = 300
+    projector_scale: List[Literal["P3", "P4", "P5"]] = ["P4"]
+    out_feature_indexes: List[int] = [2, 5, 8, 11]
+    pretrain_weights: Optional[str] = "rf-detr-base.pth"
+    resolution: int = 560
+    positional_encoding_size: int = 37


-@dataclass
 class RFDETRLargeConfig(RFDETRBaseConfig):
-    encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = field(
-        default="dinov2_windowed_base"
-    )
-    hidden_dim: int = field(default=384)
-    sa_nheads: int = field(default=12)
-    ca_nheads: int = field(default=24)
-    dec_n_points: int = field(default=4)
-    projector_scale: List[Literal["P3", "P4", "P5"]] = field(
-        default_factory=lambda: ["P3", "P5"]
+    encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = (
+        "dinov2_windowed_base"
     )
-    pretrain_weights: Optional[str] = field(default="rf-detr-large.pth")
+    hidden_dim: int = 384
+    sa_nheads: int = 12
+    ca_nheads: int = 24
+    dec_n_points: int = 4
+    projector_scale: List[Literal["P3", "P4", "P5"]] = ["P3", "P5"]
+    pretrain_weights: Optional[str] = "rf-detr-large.pth"
+
+
+class RFDETRNanoConfig(RFDETRBaseConfig):
+    out_feature_indexes: List[int] = [3, 6, 9, 12]
+    num_windows: int = 2
+    dec_layers: int = 2
+    patch_size: int = 16
+    resolution: int = 384
+    positional_encoding_size: int = 24
+    pretrain_weights: Optional[str] = "rf-detr-nano.pth"
+
+
+class RFDETRSmallConfig(RFDETRBaseConfig):
+    out_feature_indexes: List[int] = [3, 6, 9, 12]
+    num_windows: int = 2
+    dec_layers: int = 3
+    patch_size: int = 16
+    resolution: int = 512
+    positional_encoding_size: int = 32
+    pretrain_weights: Optional[str] = "rf-detr-small.pth"
+
+
+class RFDETRMediumConfig(RFDETRBaseConfig):
+    out_feature_indexes: List[int] = [3, 6, 9, 12]
+    num_windows: int = 2
+    dec_layers: int = 4
+    patch_size: int = 16
+    resolution: int = 576
+    positional_encoding_size: int = 36
+    pretrain_weights: Optional[str] = "rf-detr-medium.pth"


 class LWDETR(nn.Module):
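A note on the migration above: the configs are now pydantic models, so fields are validated at construction time rather than silently stored as with dataclasses. A minimal usage sketch (hypothetical, not part of the diff; it assumes these classes are imported from this module):

    import torch

    # `device` has no default on ModelConfig, so it must be supplied;
    # arbitrary_types_allowed is what lets the non-pydantic torch.device
    # type pass validation.
    config = RFDETRNanoConfig(device=torch.device("cpu"))
    assert config.patch_size == 16 and config.resolution == 384

    # Ill-typed values now raise pydantic.ValidationError, e.g. a plain
    # string fails the isinstance() check used for arbitrary types:
    # RFDETRNanoConfig(device="cpu")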
@@ -212,24 +245,27 @@ def forward(self, samples: NestedTensor, targets=None):
             srcs, masks, poss, refpoint_embed_weight, query_feat_weight
         )

-        if self.bbox_reparam:
-            outputs_coord_delta = self.bbox_embed(hs)
-            outputs_coord_cxcy = (
-                outputs_coord_delta[..., :2] * ref_unsigmoid[..., 2:]
-                + ref_unsigmoid[..., :2]
-            )
-            outputs_coord_wh = (
-                outputs_coord_delta[..., 2:].exp() * ref_unsigmoid[..., 2:]
-            )
-            outputs_coord = torch.concat([outputs_coord_cxcy, outputs_coord_wh], dim=-1)
-        else:
-            outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid()
+        if hs is not None:
+            if self.bbox_reparam:
+                outputs_coord_delta = self.bbox_embed(hs)
+                outputs_coord_cxcy = (
+                    outputs_coord_delta[..., :2] * ref_unsigmoid[..., 2:]
+                    + ref_unsigmoid[..., :2]
+                )
+                outputs_coord_wh = (
+                    outputs_coord_delta[..., 2:].exp() * ref_unsigmoid[..., 2:]
+                )
+                outputs_coord = torch.concat(
+                    [outputs_coord_cxcy, outputs_coord_wh], dim=-1
+                )
+            else:
+                outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid()

-        outputs_class = self.class_embed(hs)
+            outputs_class = self.class_embed(hs)

-        out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
-        if self.aux_loss:
-            out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord)
+            out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
+            if self.aux_loss:
+                out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord)

         if self.two_stage:
             group_detr = self.group_detr if self.training else 1
@@ -241,7 +277,11 @@ def forward(self, samples: NestedTensor, targets=None):
                 )
                 cls_enc.append(cls_enc_gidx)
             cls_enc = torch.cat(cls_enc, dim=1)
-            out["enc_outputs"] = {"pred_logits": cls_enc, "pred_boxes": ref_enc}
+            if hs is not None:
+                out["enc_outputs"] = {"pred_logits": cls_enc, "pred_boxes": ref_enc}
+            else:
+                out = {"pred_logits": cls_enc, "pred_boxes": ref_enc}
+
         return out

     def forward_export(self, tensors):
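For intuition, a standalone sketch of the `bbox_reparam` decoding applied in both `forward` branches above (illustrative numbers only; boxes are assumed to be in normalized cxcywh format):

    import torch

    ref = torch.tensor([[0.5, 0.5, 0.2, 0.4]])       # reference box (cx, cy, w, h)
    delta = torch.tensor([[0.1, -0.1, 0.0, 0.693]])  # predicted offsets

    # The center shifts proportionally to the reference size; width/height
    # are rescaled multiplicatively via exp(), which keeps them positive.
    cxcy = delta[..., :2] * ref[..., 2:] + ref[..., :2]  # [[0.52, 0.46]]
    wh = delta[..., 2:].exp() * ref[..., 2:]             # [[0.20, 0.80]], exp(0.693) ≈ 2
    box = torch.cat([cxcy, wh], dim=-1)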
@@ -254,19 +294,27 @@ def forward_export(self, tensors):
             srcs, None, poss, refpoint_embed_weight, query_feat_weight
         )

-        if self.bbox_reparam:
-            outputs_coord_delta = self.bbox_embed(hs)
-            outputs_coord_cxcy = (
-                outputs_coord_delta[..., :2] * ref_unsigmoid[..., 2:]
-                + ref_unsigmoid[..., :2]
-            )
-            outputs_coord_wh = (
-                outputs_coord_delta[..., 2:].exp() * ref_unsigmoid[..., 2:]
-            )
-            outputs_coord = torch.concat([outputs_coord_cxcy, outputs_coord_wh], dim=-1)
+        if hs is not None:
+            if self.bbox_reparam:
+                outputs_coord_delta = self.bbox_embed(hs)
+                outputs_coord_cxcy = (
+                    outputs_coord_delta[..., :2] * ref_unsigmoid[..., 2:]
+                    + ref_unsigmoid[..., :2]
+                )
+                outputs_coord_wh = (
+                    outputs_coord_delta[..., 2:].exp() * ref_unsigmoid[..., 2:]
+                )
+                outputs_coord = torch.concat(
+                    [outputs_coord_cxcy, outputs_coord_wh], dim=-1
+                )
+            else:
+                outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid()
+            outputs_class = self.class_embed(hs)
         else:
-            outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid()
-        outputs_class = self.class_embed(hs)
+            assert self.two_stage, "if not using decoder, two_stage must be True"
+            outputs_class = self.transformer.enc_out_class_embed[0](hs_enc)
+            outputs_coord = ref_enc
+
         return outputs_coord, outputs_class

     @torch.jit.unused
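In both the decoder path and the new decoder-less fallback (where the assertion requires `two_stage`), `forward_export` returns a `(outputs_coord, outputs_class)` pair. A hypothetical consumer sketch using the usual DETR-family top-k selection; this helper is not part of the diff:

    import torch

    def select_topk(outputs_coord, outputs_class, num_select=300):
        # outputs_class: (B, Q, C) logits; outputs_coord: (B, Q, 4) boxes.
        prob = outputs_class.sigmoid()
        scores, flat_idx = torch.topk(prob.flatten(1), num_select, dim=1)
        query_idx = flat_idx // prob.shape[2]  # which query each hit came from
        labels = flat_idx % prob.shape[2]      # which class it scored
        boxes = torch.gather(
            outputs_coord, 1, query_idx.unsqueeze(-1).repeat(1, 1, 4)
        )
        return scores, labels, boxes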
@@ -399,7 +447,7 @@ def build_model(config: ModelConfig) -> LWDETR:
     # you should pass `num_classes` to be 2 (max_obj_id + 1).
     # For more details on this, check the following discussion
     # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
-    args = populate_args(**asdict(config))
+    args = populate_args(**config.dict())
     num_classes = args.num_classes + 1
     backbone = build_backbone(
         encoder=args.encoder,
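One caveat on `config.dict()` in the hunk above: under pydantic v2, which the `ConfigDict` import implies, `.dict()` still works but is deprecated in favor of `model_dump()`. A forward-compatible equivalent, assuming nothing else about `populate_args`:

    args = populate_args(**config.model_dump())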
@@ -429,6 +477,9 @@ def build_model(config: ModelConfig) -> LWDETR:
         force_no_pretrain=args.force_no_pretrain,
         gradient_checkpointing=args.gradient_checkpointing,
         load_dinov2_weights=args.pretrain_weights is None,
+        patch_size=config.patch_size,
+        num_windows=config.num_windows,
+        positional_encoding_size=config.positional_encoding_size,
     )
     if args.encoder_only:
         return backbone[0].encoder, None, None