Skip to content

Commit 14f531a

Browse files
committed
remove center rounding in bottom-up affine
1 parent 4740c5c commit 14f531a

File tree

3 files changed

+28
-23
lines changed

3 files changed

+28
-23
lines changed

configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@
147147
type=dataset_type,
148148
data_root=data_root,
149149
data_mode=data_mode,
150-
ann_file='annotations/person_keypoints_val2017_tiny_clean.json',
150+
ann_file='annotations/person_keypoints_val2017.json',
151151
data_prefix=dict(img='val2017/'),
152152
test_mode=True,
153153
pipeline=val_pipeline,
@@ -157,8 +157,7 @@
157157
# evaluators
158158
val_evaluator = dict(
159159
type='CocoMetric',
160-
ann_file=data_root +
161-
'annotations/person_keypoints_val2017_tiny_clean.json',
160+
ann_file=data_root + 'annotations/person_keypoints_val2017.json',
162161
nms_mode='none',
163162
score_mode='keypoint',
164163
)

mmpose/codecs/associative_embedding.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from collections import namedtuple
3-
from copy import deepcopy
3+
# from copy import deepcopy
44
from itertools import product
55
from typing import Any, List, Optional, Tuple
66

77
import numpy as np
88
import torch
9-
from mmengine import dump
9+
# from mmengine import dump
1010
from munkres import Munkres
1111
from torch import Tensor
1212

@@ -77,7 +77,7 @@ def _init_group():
7777
tag_list=[])
7878
return _group
7979

80-
group_history = []
80+
# group_history = []
8181

8282
for idx, i in enumerate(keypoint_order):
8383
# Get all valid candidate of the i-th keypoints
@@ -105,7 +105,7 @@ def _init_group():
105105
group.tag_list.append(tag)
106106

107107
groups.append(group)
108-
costs_copy = None
108+
# costs_copy = None
109109
matches = None
110110

111111
else: # Match keypoints to existing groups
@@ -126,7 +126,7 @@ def _init_group():
126126
if num_kpts > num_groups:
127127
padding = np.full((num_kpts, num_kpts - num_groups), 1e10)
128128
costs = np.concatenate((costs, padding), axis=1)
129-
costs_copy = costs.copy()
129+
# costs_copy = costs.copy()
130130

131131
# Match keypoints and groups by Munkres algorithm
132132
matches = munkres.compute(costs)
@@ -148,18 +148,18 @@ def _init_group():
148148
group.scores[i] = vals_i[kpt_idx]
149149
group.tag_list.append(tags_i[kpt_idx])
150150

151-
out = {
152-
'idx': idx,
153-
'i': i,
154-
'costs': costs_copy,
155-
'matches': matches,
156-
'kpts': np.array([g.kpts for g in groups]),
157-
'scores': np.array([g.scores for g in groups]),
158-
'tag_list': [np.array(g.tag_list) for g in groups],
159-
}
160-
group_history.append(deepcopy(out))
151+
# out = {
152+
# 'idx': idx,
153+
# 'i': i,
154+
# 'costs': costs_copy,
155+
# 'matches': matches,
156+
# 'kpts': np.array([g.kpts for g in groups]),
157+
# 'scores': np.array([g.scores for g in groups]),
158+
# 'tag_list': [np.array(g.tag_list) for g in groups],
159+
# }
160+
# group_history.append(deepcopy(out))
161161

162-
dump(group_history, 'group_history.pkl')
162+
# dump(group_history, 'group_history.pkl')
163163

164164
groups = groups[:max_groups]
165165
if groups:
@@ -369,10 +369,10 @@ def _get_batch_topk(self, batch_heatmaps: Tensor, batch_tags: Tensor,
369369
L = batch_tags.shape[1] // K
370370

371371
# Heatmap NMS
372-
dump(batch_heatmaps.cpu().numpy(), 'heatmaps.pkl')
372+
# dump(batch_heatmaps.cpu().numpy(), 'heatmaps.pkl')
373373
batch_heatmaps = batch_heatmap_nms(batch_heatmaps,
374374
self.decode_nms_kernel)
375-
dump(batch_heatmaps.cpu().numpy(), 'heatmaps_nms.pkl')
375+
# dump(batch_heatmaps.cpu().numpy(), 'heatmaps_nms.pkl')
376376

377377
# shape of topk_val, top_indices: (B, K, TopK)
378378
topk_vals, topk_indices = batch_heatmaps.flatten(-2, -1).topk(
@@ -534,7 +534,13 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
534534
blur_kernel_size=self.decode_gaussian_kernel)
535535
else:
536536
keypoints = refine_keypoints(keypoints, heatmaps)
537-
# keypoints += 0.75
537+
# The following 0.5-pixel shift is adapted from mmpose 0.x
538+
# where the heatmap center is calculated by a biased
539+
# rounding ``mu=[int(x), int(y)]``. We keep this shift
540+
# operation for now to to compatible with 0.x checkpoints
541+
# In mmpose 1.x, AE heatmap center is calculated by the
542+
# unbiased rounding ``mu=[int(x+0.5), int(y+0.5)], so the
543+
# following shift will be removed in the future.
538544
keypoints += 0.5
539545

540546
batch_keypoints.append(keypoints)

mmpose/datasets/transforms/bottomup_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def transform(self, results: Dict) -> Optional[dict]:
478478
output_size=actual_input_size)
479479
else:
480480
center = np.array([img_w / 2, img_h / 2], dtype=np.float32)
481-
center = np.round(center)
481+
# center = np.round(center)
482482
scale = np.array([
483483
img_w * padded_input_size[0] / actual_input_size[0],
484484
img_h * padded_input_size[1] / actual_input_size[1]

0 commit comments

Comments
 (0)