1
1
# Copyright (c) OpenMMLab. All rights reserved.
2
2
from collections import namedtuple
3
+ from copy import deepcopy
3
4
from itertools import product
4
5
from typing import Any , List , Optional , Tuple
5
6
6
7
import numpy as np
7
8
import torch
9
+ from mmengine import dump
8
10
from munkres import Munkres
9
11
from torch import Tensor
10
12
@@ -75,7 +77,9 @@ def _init_group():
75
77
tag_list = [])
76
78
return _group
77
79
78
- for i in keypoint_order :
80
+ group_history = []
81
+
82
+ for idx , i in enumerate (keypoint_order ):
79
83
# Get all valid candidates of the i-th keypoint
80
84
valid = vals [i ] > val_thr
81
85
if not valid .any ():
@@ -87,12 +91,22 @@ def _init_group():
87
91
88
92
if len (groups ) == 0 : # Initialize the group pool
89
93
for tag , val , loc in zip (tags_i , vals_i , locs_i ):
94
+
95
+ # Check if the keypoint belongs to existing groups
96
+ if len (groups ):
97
+ prev_tags = np .stack ([g .tag_list [0 ] for g in groups ])
98
+ dists = np .linalg .norm (prev_tags - tag , ord = 2 , axis = 1 )
99
+ if dists .min () < 1 :
100
+ continue
101
+
90
102
group = _init_group ()
91
103
group .kpts [i ] = loc
92
104
group .scores [i ] = val
93
105
group .tag_list .append (tag )
94
106
95
107
groups .append (group )
108
+ costs_copy = None
109
+ matches = None
96
110
97
111
else : # Match keypoints to existing groups
98
112
groups = groups [:max_groups ]
@@ -101,17 +115,18 @@ def _init_group():
101
115
# Calculate distance matrix between group tags and tag candidates
102
116
# of the i-th keypoint
103
117
# Shape: (M', 1, L) , (1, G, L) -> (M', G, L)
104
- diff = tags_i [:, None ] - np .array (group_tags )[None ]
118
+ diff = (tags_i [:, None ] -
119
+ np .array (group_tags )[None ]).astype (np .float64 )
105
120
dists = np .linalg .norm (diff , ord = 2 , axis = 2 )
106
121
num_kpts , num_groups = dists .shape [:2 ]
107
122
108
- # Experimental cost function for keypoint-group matching
123
+ # Experimental cost function for keypoint-group matching
109
124
costs = np .round (dists ) * 100 - vals_i [..., None ]
125
+
110
126
if num_kpts > num_groups :
111
- padding = np .full ((num_kpts , num_kpts - num_groups ),
112
- 1e10 ,
113
- dtype = np .float32 )
127
+ padding = np .full ((num_kpts , num_kpts - num_groups ), 1e10 )
114
128
costs = np .concatenate ((costs , padding ), axis = 1 )
129
+ costs_copy = costs .copy ()
115
130
116
131
# Match keypoints and groups by Munkres algorithm
117
132
matches = munkres .compute (costs )
@@ -121,13 +136,30 @@ def _init_group():
121
136
# Add the keypoint to the matched group
122
137
group = groups [group_idx ]
123
138
else :
124
- # Initialize a new group with unmatched keypoint
125
- group = _init_group ()
126
- groups .append (group )
127
-
128
- group .kpts [i ] = locs_i [kpt_idx ]
129
- group .scores [i ] = vals_i [kpt_idx ]
130
- group .tag_list .append (tags_i [kpt_idx ])
139
+ # if dists[kpt_idx].min() < 0.2:
140
+ if False :
141
+ group = None
142
+ else :
143
+ # Initialize a new group with unmatched keypoint
144
+ group = _init_group ()
145
+ groups .append (group )
146
+ if group is not None :
147
+ group .kpts [i ] = locs_i [kpt_idx ]
148
+ group .scores [i ] = vals_i [kpt_idx ]
149
+ group .tag_list .append (tags_i [kpt_idx ])
150
+
151
+ out = {
152
+ 'idx' : idx ,
153
+ 'i' : i ,
154
+ 'costs' : costs_copy ,
155
+ 'matches' : matches ,
156
+ 'kpts' : np .array ([g .kpts for g in groups ]),
157
+ 'scores' : np .array ([g .scores for g in groups ]),
158
+ 'tag_list' : [np .array (g .tag_list ) for g in groups ],
159
+ }
160
+ group_history .append (deepcopy (out ))
161
+
162
+ dump (group_history , 'group_history.pkl' )
131
163
132
164
groups = groups [:max_groups ]
133
165
if groups :
@@ -210,7 +242,7 @@ def __init__(
210
242
decode_gaussian_kernel : int = 3 ,
211
243
decode_keypoint_thr : float = 0.1 ,
212
244
decode_tag_thr : float = 1.0 ,
213
- decode_topk : int = 20 ,
245
+ decode_topk : int = 30 ,
214
246
decode_max_instances : Optional [int ] = None ,
215
247
) -> None :
216
248
super ().__init__ ()
@@ -336,6 +368,12 @@ def _get_batch_topk(self, batch_heatmaps: Tensor, batch_tags: Tensor,
336
368
B , K , H , W = batch_heatmaps .shape
337
369
L = batch_tags .shape [1 ] // K
338
370
371
+ # Heatmap NMS
372
+ dump (batch_heatmaps .cpu ().numpy (), 'heatmaps.pkl' )
373
+ batch_heatmaps = batch_heatmap_nms (batch_heatmaps ,
374
+ self .decode_nms_kernel )
375
+ dump (batch_heatmaps .cpu ().numpy (), 'heatmaps_nms.pkl' )
376
+
339
377
# shape of topk_val, top_indices: (B, K, TopK)
340
378
topk_vals , topk_indices = batch_heatmaps .flatten (- 2 , - 1 ).topk (
341
379
k , dim = - 1 )
@@ -433,9 +471,8 @@ def _fill_missing_keypoints(self, keypoints: np.ndarray,
433
471
cost_map = np .round (dist_map ) * 100 - heatmaps [k ] # H, W
434
472
y , x = np .unravel_index (np .argmin (cost_map ), shape = (H , W ))
435
473
keypoints [n , k ] = [x , y ]
436
- keypoint_scores [n , k ] = heatmaps [k , y , x ]
437
474
438
- return keypoints , keypoint_scores
475
+ return keypoints
439
476
440
477
def batch_decode (self , batch_heatmaps : Tensor , batch_tags : Tensor
441
478
) -> Tuple [List [np .ndarray ], List [np .ndarray ]]:
@@ -457,15 +494,12 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
457
494
batch, each is in shape (N, K). It usually represents the
458
495
confidence of the keypoint prediction
459
496
"""
497
+
460
498
B , _ , H , W = batch_heatmaps .shape
461
499
assert batch_tags .shape [0 ] == B and batch_tags .shape [2 :4 ] == (H , W ), (
462
500
f'Mismatched shapes of heatmap ({ batch_heatmaps .shape } ) and '
463
501
f'tagging map ({ batch_tags .shape } )' )
464
502
465
- # Heatmap NMS
466
- batch_heatmaps = batch_heatmap_nms (batch_heatmaps ,
467
- self .decode_nms_kernel )
468
-
469
503
# Get top-k in each heatmap and convert to numpy
470
504
batch_topk_vals , batch_topk_tags , batch_topk_locs = to_numpy (
471
505
self ._get_batch_topk (
@@ -489,7 +523,7 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
489
523
490
524
if keypoints .size > 0 :
491
525
# identify missing keypoints
492
- keypoints , scores = self ._fill_missing_keypoints (
526
+ keypoints = self ._fill_missing_keypoints (
493
527
keypoints , scores , heatmaps , tags )
494
528
495
529
# refine keypoint coordinates according to heatmap distribution
@@ -500,6 +534,8 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
500
534
blur_kernel_size = self .decode_gaussian_kernel )
501
535
else :
502
536
keypoints = refine_keypoints (keypoints , heatmaps )
537
+ # keypoints += 0.75
538
+ keypoints += 0.5
503
539
504
540
batch_keypoints .append (keypoints )
505
541
batch_keypoint_scores .append (scores )
0 commit comments