[Feat] Add KeypointAUC metric #91
Merged

Commits (changes shown from 7 of 13 commits):
3b8a414  [Feat] Add KeypointAUC metric (LareinaM)
f9ba5d5  Update docs and examples (LareinaM)
6be1c5a  Update example (LareinaM)
ad87906  Add testcase (LareinaM)
9d204b7  Update support_matrix doc (LareinaM)
25a9fc2  Restore support_matrix doc (LareinaM)
5aa1680  Modify typos in comments (LareinaM)
11c312f  Update comment (LareinaM)
1b00988  Update example according to comment (LareinaM)
f7ac6eb  Update comment (LareinaM)
fd5e149  Changes according to comments (LareinaM)
9c8aff9  Rename parameters (LareinaM)
6af59b7  Merge branch 'main' into auc (zhouzaida)
Docs support matrix:

```diff
@@ -48,3 +48,4 @@ Metrics
 ConnectivityError
 DOTAMeanAP
 ROUGE
+KeypointAUC
```
A second documentation file receives the same addition:

```diff
@@ -48,3 +48,4 @@ Metrics
 ConnectivityError
 DOTAMeanAP
 ROUGE
+KeypointAUC
```
New file (+150 lines):

```python
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import numpy as np
from collections import OrderedDict
from typing import Dict, List

from mmeval.core.base_metric import BaseMetric
from .pck_accuracy import keypoint_pck_accuracy

logger = logging.getLogger(__name__)


def keypoint_auc_accuracy(pred: np.ndarray,
                          gt: np.ndarray,
                          mask: np.ndarray,
                          norm_factor: float,
                          num_thrs: int = 20) -> float:
    """Calculate the Area under curve (AUC) of keypoint PCK accuracy.

    Note:
        - instance number: N
        - keypoint number: K

    Args:
        pred (np.ndarray[N, K, 2]): Predicted keypoint location.
        gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
        mask (np.ndarray[N, K]): Visibility of the target. False for invisible
            joints, and True for visible. Invisible joints will be ignored for
            accuracy calculation.
        norm_factor (float): Normalization factor.
        num_thrs (int): Number of thresholds used to calculate the AUC.

    Returns:
        float: Area under curve (AUC) of keypoint PCK accuracy.
    """
    nor = np.tile(np.array([[norm_factor, norm_factor]]), (pred.shape[0], 1))
    thrs = [1.0 * i / num_thrs for i in range(num_thrs)]
    avg_accs = []
    for thr in thrs:
        _, avg_acc, _ = keypoint_pck_accuracy(pred, gt, mask, thr, nor)
        avg_accs.append(avg_acc)

    auc = 0
    for i in range(num_thrs):
        auc += 1.0 / num_thrs * avg_accs[i]
    return auc


class KeypointAUC(BaseMetric):
    """AUC evaluation metric.

    Calculate the Area Under Curve (AUC) of keypoint PCK accuracy.

    By altering the threshold percentage in the calculation of PCK accuracy,
    AUC can be generated to further evaluate the pose estimation algorithms.

    Note:
        - length of dataset: N
        - num_keypoints: K
        - number of keypoint dimensions: D (typically D = 2)

    Args:
        norm_factor (float): AUC normalization factor. Default: 30 (pixels).
        num_thrs (int): Number of thresholds used to calculate the AUC.
            Default: 20.
        **kwargs: Keyword parameters passed to :class:`mmeval.BaseMetric`. Must
            include ``dataset_meta`` in order to compute the metric.

    Examples:

        >>> from mmeval import KeypointAUC
        >>> import numpy as np
        >>> output = np.array([[[10., 4.],
        ...                     [10., 18.],
        ...                     [ 0., 0.],
        ...                     [40., 40.],
        ...                     [20., 10.]]])
        >>> target = np.array([[[10., 0.],
        ...                     [10., 10.],
        ...                     [ 0., -1.],
        ...                     [30., 30.],
        ...                     [ 0., 10.]]])
        >>> keypoints_visible = np.array([[True, True, False, True, True]])
        >>> prediction = {'coords': output}
        >>> groundtruth = {'coords': target, 'mask': keypoints_visible}
        >>> predictions = [prediction]
        >>> groundtruths = [groundtruth]
        >>> auc_metric = KeypointAUC(norm_factor=20, num_thrs=4)
        >>> auc_metric(predictions, groundtruths)
        OrderedDict([('AUC@4', 0.375)])
    """

    def __init__(self,
                 norm_factor: float = 30,
                 num_thrs: int = 20,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.norm_factor = norm_factor
        self.num_thrs = num_thrs

    def add(self, predictions: List[Dict], groundtruths: List[Dict]) -> None:  # type: ignore # yapf: disable # noqa: E501
        """Process one batch of predictions and groundtruths and add the
        intermediate results to `self._results`.

        Args:
            predictions (Sequence[dict]): Predictions from the model.
                Each prediction dict has the following keys:

                - coords (np.ndarray, [1, K, D]): predicted keypoints
                  coordinates

            groundtruths (Sequence[dict]): The ground truth labels.
                Each groundtruth dict has the following keys:

                - coords (np.ndarray, [1, K, D]): ground truth keypoints
                  coordinates
                - mask (np.ndarray, [1, K]): ground truth keypoints_visible
        """
        for prediction, groundtruth in zip(predictions, groundtruths):
            self._results.append((prediction, groundtruth))

    def compute_metric(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
            the metrics, and the values are corresponding results.
        """
        # split gt and prediction list
        preds, gts = zip(*results)

        # pred_coords: [N, K, D]
        pred_coords = np.concatenate([pred['coords'] for pred in preds])
        # gt_coords: [N, K, D]
        gt_coords = np.concatenate([gt['coords'] for gt in gts])
        # mask: [N, K]
        mask = np.concatenate([gt['mask'] for gt in gts])

        logger.info(f'Evaluating {self.__class__.__name__}...')

        auc = keypoint_auc_accuracy(pred_coords, gt_coords, mask,
                                    self.norm_factor, self.num_thrs)

        metric_results: OrderedDict = OrderedDict()
        metric_results[f'AUC@{self.num_thrs}'] = auc

        return metric_results
```
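The `AUC@4 == 0.375` value in the docstring example can be checked by hand with a few lines of NumPy. This is an independent sketch, not the library's code: it assumes `keypoint_pck_accuracy` counts a keypoint as correct when its normalized Euclidean distance is below the threshold, and it uses the same left-endpoint threshold grid `[0, 1/4, 2/4, 3/4]` as `keypoint_auc_accuracy` above.

```python
import numpy as np

# The four visible keypoints from the docstring example
# (the masked third keypoint is dropped).
pred = np.array([[10., 4.], [10., 18.], [40., 40.], [20., 10.]])
gt = np.array([[10., 0.], [10., 10.], [30., 30.], [0., 10.]])
norm_factor = 20.0
num_thrs = 4

# Normalized Euclidean distances between predictions and ground truth.
dists = np.linalg.norm(pred - gt, axis=-1) / norm_factor
# dists is approximately [0.2, 0.4, 0.707, 1.0]

# PCK at each threshold, then the mean over thresholds (a left Riemann sum
# approximating the area under the PCK-vs-threshold curve).
thrs = [i / num_thrs for i in range(num_thrs)]      # [0.0, 0.25, 0.5, 0.75]
pcks = [(dists < thr).mean() for thr in thrs]       # [0.0, 0.25, 0.5, 0.75]
auc = sum(pcks) / num_thrs
print(auc)  # 0.375
```

The mean over thresholds matches the loop at the end of `keypoint_auc_accuracy`, which accumulates `1.0 / num_thrs * avg_accs[i]` for each threshold.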
New file (+67 lines):

```python
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from unittest import TestCase

from mmeval.metrics import KeypointAUC


class TestKeypointAUC(TestCase):

    def setUp(self):
        """Setup some variables which are used in every test method.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> cleanUp()
        """
        self.output = np.zeros((1, 5, 2))
        self.target = np.zeros((1, 5, 2))
        # first channel
        self.output[0, 0] = [10, 4]
        self.target[0, 0] = [10, 0]
        # second channel
        self.output[0, 1] = [10, 18]
        self.target[0, 1] = [10, 10]
        # third channel
        self.output[0, 2] = [0, 0]
        self.target[0, 2] = [0, -1]
        # fourth channel
        self.output[0, 3] = [40, 40]
        self.target[0, 3] = [30, 30]
        # fifth channel
        self.output[0, 4] = [20, 10]
        self.target[0, 4] = [0, 10]

        self.keypoints_visible = np.array([[True, True, False, True, True]])

    def test_auc_evaluate(self):
        """test AUC evaluation metric."""
        # case 1: norm_factor=20, num_thrs=4
        auc_metric = KeypointAUC(norm_factor=20, num_thrs=4)
        target = {'AUC@4': 0.375}

        prediction = {'coords': self.output}
        groundtruth = {'coords': self.target, 'mask': self.keypoints_visible}
        predictions = [prediction]
        groundtruths = [groundtruth]

        auc_results = auc_metric(predictions, groundtruths)
        self.assertDictEqual(auc_results, target)

        # case 2: use ``add`` multiple times then ``compute``
        auc_metric._results = []
        preds1 = [{'coords': self.output[:3]}]
        preds2 = [{'coords': self.output[3:]}]
        gts1 = [{
            'coords': self.target[:3],
            'mask': self.keypoints_visible[:3]
        }]
        gts2 = [{
            'coords': self.target[3:],
            'mask': self.keypoints_visible[3:]
        }]

        auc_metric.add(preds1, gts1)
        auc_metric.add(preds2, gts2)

        auc_results = auc_metric.compute_metric(auc_metric._results)
        self.assertDictEqual(auc_results, target)
```