Skip to content

Commit 4740c5c

Browse files
committed
ae inference align with master
1 parent d83c4ba commit 4740c5c

File tree

5 files changed

+147
-78
lines changed

5 files changed

+147
-78
lines changed

configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,7 @@
9696
decoder=dict(codec, heatmap_size=codec['input_size'])),
9797
test_cfg=dict(
9898
multiscale_test=False,
99-
flip_test=True,
99+
flip_test=False,
100100
shift_heatmap=True,
101101
restore_heatmap_size=True,
102102
align_corners=False))
@@ -113,9 +113,14 @@
113113
dict(
114114
type='BottomupResize',
115115
input_size=codec['input_size'],
116-
size_factor=32,
116+
size_factor=64,
117117
resize_mode='expand'),
118-
dict(type='PackPoseInputs')
118+
dict(
119+
type='PackPoseInputs',
120+
meta_keys=('id', 'img_id', 'img_path', 'crowd_index', 'ori_shape',
121+
'img_shape', 'input_size', 'input_center', 'input_scale',
122+
'flip', 'flip_direction', 'flip_indices', 'raw_ann_info',
123+
'skeleton_links'))
119124
]
120125

121126
# data loaders
@@ -142,7 +147,7 @@
142147
type=dataset_type,
143148
data_root=data_root,
144149
data_mode=data_mode,
145-
ann_file='annotations/person_keypoints_val2017.json',
150+
ann_file='annotations/person_keypoints_val2017_tiny_clean.json',
146151
data_prefix=dict(img='val2017/'),
147152
test_mode=True,
148153
pipeline=val_pipeline,
@@ -152,7 +157,8 @@
152157
# evaluators
153158
val_evaluator = dict(
154159
type='CocoMetric',
155-
ann_file=data_root + 'annotations/person_keypoints_val2017.json',
160+
ann_file=data_root +
161+
'annotations/person_keypoints_val2017_tiny_clean.json',
156162
nms_mode='none',
157163
score_mode='keypoint',
158164
)

demo/bottomup_demo.py

Lines changed: 58 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
23
import mimetypes
34
import os
5+
import os.path as osp
46
import tempfile
57
from argparse import ArgumentParser
68

@@ -120,57 +122,64 @@ def main():
120122
visualizer = VISUALIZERS.build(model.cfg.visualizer)
121123
visualizer.set_dataset_meta(model.dataset_meta)
122124

123-
input_type = mimetypes.guess_type(args.input)[0].split('/')[0]
124-
if input_type == 'image':
125-
pred_instances = process_one_image(
126-
args, args.input, model, visualizer, show_interval=0)
127-
pred_instances_list = split_instances(pred_instances)
128-
129-
elif input_type == 'video':
130-
tmp_folder = tempfile.TemporaryDirectory()
131-
video = mmcv.VideoReader(args.input)
132-
progressbar = mmengine.ProgressBar(len(video))
133-
video.cvt2frames(tmp_folder.name, show_progress=False)
134-
output_root = args.output_root
135-
args.output_root = tmp_folder.name
136-
pred_instances_list = []
137-
138-
for frame_id, img_fname in enumerate(os.listdir(tmp_folder.name)):
139-
pred_instances = process_one_image(
140-
args,
141-
f'{tmp_folder.name}/{img_fname}',
142-
model,
143-
visualizer,
144-
show_interval=1)
145-
progressbar.update()
146-
pred_instances_list.append(
147-
dict(
148-
frame_id=frame_id,
149-
instances=split_instances(pred_instances)))
150-
151-
if output_root:
152-
mmcv.frames2video(
153-
tmp_folder.name,
154-
f'{output_root}/{os.path.basename(args.input)}',
155-
fps=video.fps,
156-
fourcc='mp4v',
157-
show_progress=False)
158-
tmp_folder.cleanup()
159-
125+
if osp.isfile(args.input):
126+
inputs = [args.input]
160127
else:
161-
args.save_predictions = False
162-
raise ValueError(
163-
f'file {os.path.basename(args.input)} has invalid format.')
128+
inputs = [osp.join(args.input, fn) for fn in os.listdir(args.input)]
164129

165-
if args.save_predictions:
166-
with open(args.pred_save_path, 'w') as f:
167-
json.dump(
168-
dict(
169-
meta_info=model.dataset_meta,
170-
instance_info=pred_instances_list),
171-
f,
172-
indent='\t')
173-
print(f'predictions have been saved at {args.pred_save_path}')
130+
for fn in inputs:
131+
132+
input_type = mimetypes.guess_type(fn)[0].split('/')[0]
133+
if input_type == 'image':
134+
pred_instances = process_one_image(
135+
args, fn, model, visualizer, show_interval=0)
136+
pred_instances_list = split_instances(pred_instances)
137+
138+
elif input_type == 'video':
139+
tmp_folder = tempfile.TemporaryDirectory()
140+
video = mmcv.VideoReader(fn)
141+
progressbar = mmengine.ProgressBar(len(video))
142+
video.cvt2frames(tmp_folder.name, show_progress=False)
143+
output_root = args.output_root
144+
args.output_root = tmp_folder.name
145+
pred_instances_list = []
146+
147+
for frame_id, img_fname in enumerate(os.listdir(tmp_folder.name)):
148+
pred_instances = process_one_image(
149+
args,
150+
f'{tmp_folder.name}/{img_fname}',
151+
model,
152+
visualizer,
153+
show_interval=1)
154+
progressbar.update()
155+
pred_instances_list.append(
156+
dict(
157+
frame_id=frame_id,
158+
instances=split_instances(pred_instances)))
159+
160+
if output_root:
161+
mmcv.frames2video(
162+
tmp_folder.name,
163+
f'{output_root}/{os.path.basename(fn)}',
164+
fps=video.fps,
165+
fourcc='mp4v',
166+
show_progress=False)
167+
tmp_folder.cleanup()
168+
169+
else:
170+
args.save_predictions = False
171+
raise ValueError(
172+
f'file {os.path.basename(fn)} has invalid format.')
173+
174+
if args.save_predictions:
175+
with open(args.pred_save_path, 'w') as f:
176+
json.dump(
177+
dict(
178+
meta_info=model.dataset_meta,
179+
instance_info=pred_instances_list),
180+
f,
181+
indent='\t')
182+
print(f'predictions have been saved at {args.pred_save_path}')
174183

175184

176185
if __name__ == '__main__':

mmpose/codecs/associative_embedding.py

Lines changed: 57 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,12 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from collections import namedtuple
3+
from copy import deepcopy
34
from itertools import product
45
from typing import Any, List, Optional, Tuple
56

67
import numpy as np
78
import torch
9+
from mmengine import dump
810
from munkres import Munkres
911
from torch import Tensor
1012

@@ -75,7 +77,9 @@ def _init_group():
7577
tag_list=[])
7678
return _group
7779

78-
for i in keypoint_order:
80+
group_history = []
81+
82+
for idx, i in enumerate(keypoint_order):
7983
# Get all valid candidate of the i-th keypoints
8084
valid = vals[i] > val_thr
8185
if not valid.any():
@@ -87,12 +91,22 @@ def _init_group():
8791

8892
if len(groups) == 0: # Initialize the group pool
8993
for tag, val, loc in zip(tags_i, vals_i, locs_i):
94+
95+
# Check if the keypoint belongs to existing groups
96+
if len(groups):
97+
prev_tags = np.stack([g.tag_list[0] for g in groups])
98+
dists = np.linalg.norm(prev_tags - tag, ord=2, axis=1)
99+
if dists.min() < 1:
100+
continue
101+
90102
group = _init_group()
91103
group.kpts[i] = loc
92104
group.scores[i] = val
93105
group.tag_list.append(tag)
94106

95107
groups.append(group)
108+
costs_copy = None
109+
matches = None
96110

97111
else: # Match keypoints to existing groups
98112
groups = groups[:max_groups]
@@ -101,17 +115,18 @@ def _init_group():
101115
# Calculate distance matrix between group tags and tag candidates
102116
# of the i-th keypoint
103117
# Shape: (M', 1, L) , (1, G, L) -> (M', G, L)
104-
diff = tags_i[:, None] - np.array(group_tags)[None]
118+
diff = (tags_i[:, None] -
119+
np.array(group_tags)[None]).astype(np.float64)
105120
dists = np.linalg.norm(diff, ord=2, axis=2)
106121
num_kpts, num_groups = dists.shape[:2]
107122

108-
# Experimental cost function for keypoint-group matching
123+
# Experimental cost function for keypoint-group matching
109124
costs = np.round(dists) * 100 - vals_i[..., None]
125+
110126
if num_kpts > num_groups:
111-
padding = np.full((num_kpts, num_kpts - num_groups),
112-
1e10,
113-
dtype=np.float32)
127+
padding = np.full((num_kpts, num_kpts - num_groups), 1e10)
114128
costs = np.concatenate((costs, padding), axis=1)
129+
costs_copy = costs.copy()
115130

116131
# Match keypoints and groups by Munkres algorithm
117132
matches = munkres.compute(costs)
@@ -121,13 +136,30 @@ def _init_group():
121136
# Add the keypoint to the matched group
122137
group = groups[group_idx]
123138
else:
124-
# Initialize a new group with unmatched keypoint
125-
group = _init_group()
126-
groups.append(group)
127-
128-
group.kpts[i] = locs_i[kpt_idx]
129-
group.scores[i] = vals_i[kpt_idx]
130-
group.tag_list.append(tags_i[kpt_idx])
139+
# if dists[kpt_idx].min() < 0.2:
140+
if False:
141+
group = None
142+
else:
143+
# Initialize a new group with unmatched keypoint
144+
group = _init_group()
145+
groups.append(group)
146+
if group is not None:
147+
group.kpts[i] = locs_i[kpt_idx]
148+
group.scores[i] = vals_i[kpt_idx]
149+
group.tag_list.append(tags_i[kpt_idx])
150+
151+
out = {
152+
'idx': idx,
153+
'i': i,
154+
'costs': costs_copy,
155+
'matches': matches,
156+
'kpts': np.array([g.kpts for g in groups]),
157+
'scores': np.array([g.scores for g in groups]),
158+
'tag_list': [np.array(g.tag_list) for g in groups],
159+
}
160+
group_history.append(deepcopy(out))
161+
162+
dump(group_history, 'group_history.pkl')
131163

132164
groups = groups[:max_groups]
133165
if groups:
@@ -210,7 +242,7 @@ def __init__(
210242
decode_gaussian_kernel: int = 3,
211243
decode_keypoint_thr: float = 0.1,
212244
decode_tag_thr: float = 1.0,
213-
decode_topk: int = 20,
245+
decode_topk: int = 30,
214246
decode_max_instances: Optional[int] = None,
215247
) -> None:
216248
super().__init__()
@@ -336,6 +368,12 @@ def _get_batch_topk(self, batch_heatmaps: Tensor, batch_tags: Tensor,
336368
B, K, H, W = batch_heatmaps.shape
337369
L = batch_tags.shape[1] // K
338370

371+
# Heatmap NMS
372+
dump(batch_heatmaps.cpu().numpy(), 'heatmaps.pkl')
373+
batch_heatmaps = batch_heatmap_nms(batch_heatmaps,
374+
self.decode_nms_kernel)
375+
dump(batch_heatmaps.cpu().numpy(), 'heatmaps_nms.pkl')
376+
339377
# shape of topk_val, top_indices: (B, K, TopK)
340378
topk_vals, topk_indices = batch_heatmaps.flatten(-2, -1).topk(
341379
k, dim=-1)
@@ -433,9 +471,8 @@ def _fill_missing_keypoints(self, keypoints: np.ndarray,
433471
cost_map = np.round(dist_map) * 100 - heatmaps[k] # H, W
434472
y, x = np.unravel_index(np.argmin(cost_map), shape=(H, W))
435473
keypoints[n, k] = [x, y]
436-
keypoint_scores[n, k] = heatmaps[k, y, x]
437474

438-
return keypoints, keypoint_scores
475+
return keypoints
439476

440477
def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
441478
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
@@ -457,15 +494,12 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
457494
batch, each is in shape (N, K). It usually represents the
458495
confidence of the keypoint prediction
459496
"""
497+
460498
B, _, H, W = batch_heatmaps.shape
461499
assert batch_tags.shape[0] == B and batch_tags.shape[2:4] == (H, W), (
462500
f'Mismatched shapes of heatmap ({batch_heatmaps.shape}) and '
463501
f'tagging map ({batch_tags.shape})')
464502

465-
# Heatmap NMS
466-
batch_heatmaps = batch_heatmap_nms(batch_heatmaps,
467-
self.decode_nms_kernel)
468-
469503
# Get top-k in each heatmap and convert to numpy
470504
batch_topk_vals, batch_topk_tags, batch_topk_locs = to_numpy(
471505
self._get_batch_topk(
@@ -489,7 +523,7 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
489523

490524
if keypoints.size > 0:
491525
# identify missing keypoints
492-
keypoints, scores = self._fill_missing_keypoints(
526+
keypoints = self._fill_missing_keypoints(
493527
keypoints, scores, heatmaps, tags)
494528

495529
# refine keypoint coordinates according to heatmap distribution
@@ -500,6 +534,8 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
500534
blur_kernel_size=self.decode_gaussian_kernel)
501535
else:
502536
keypoints = refine_keypoints(keypoints, heatmaps)
537+
# keypoints += 0.75
538+
keypoints += 0.5
503539

504540
batch_keypoints.append(keypoints)
505541
batch_keypoint_scores.append(scores)

mmpose/datasets/transforms/bottomup_transforms.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -478,6 +478,7 @@ def transform(self, results: Dict) -> Optional[dict]:
478478
output_size=actual_input_size)
479479
else:
480480
center = np.array([img_w / 2, img_h / 2], dtype=np.float32)
481+
center = np.round(center)
481482
scale = np.array([
482483
img_w * padded_input_size[0] / actual_input_size[0],
483484
img_h * padded_input_size[1] / actual_input_size[1]
@@ -489,11 +490,18 @@ def transform(self, results: Dict) -> Optional[dict]:
489490
rot=0,
490491
output_size=padded_input_size)
491492

492-
_img = cv2.warpAffine(
493-
img, warp_mat, padded_input_size, flags=cv2.INTER_LINEAR)
493+
_img = cv2.warpAffine(img, warp_mat, padded_input_size)
494494

495495
imgs.append(_img)
496496

497+
# print('#' * 20)
498+
# print('w,h: ', img_w, img_h, 'center: ', center, 'scale: ',
499+
# scale,
500+
# 'actual_input_size: ', actual_input_size,
501+
# 'padded_input_size: ', padded_input_size)
502+
# print(warp_mat)
503+
# print('#' * 20)
504+
497505
# Store the transform information w.r.t. the main input size
498506
if i == 0:
499507
results['img_shape'] = padded_input_size[::-1]

0 commit comments

Comments
 (0)