Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 1914c02

Browse files
committed
multiclass-mcc metric enhancements
* Rename metric from "PCC" to "MMCC" because though the math is derived from the Pearson CC, its utility is as a multiclass extension of the Matthews CC. * Harden mx.metric.MMCC.update to accept more variations of input format, similar to mx.metric.Accuracy.update.
1 parent d09f68a commit 1914c02

File tree

2 files changed

+33
-29
lines changed

2 files changed

+33
-29
lines changed

python/mxnet/metric.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,7 @@ class MCC(EvalMetric):
860860
861861
.. note::
862862
863-
This version of MCC only supports binary classification. See PCC.
863+
This version of MCC only supports binary classification. See MMCC.
864864
865865
Parameters
866866
----------
@@ -1477,18 +1477,18 @@ def update(self, labels, preds):
14771477

14781478

14791479
@register
1480-
class PCC(EvalMetric):
1481-
"""PCC is a multiclass equivalent for the Matthews correlation coefficient derived
1480+
class MMCC(EvalMetric):
1481+
"""MMCC is a multiclass equivalent for the Matthews correlation coefficient derived
14821482
from a discrete solution to the Pearson correlation coefficient.
14831483
14841484
.. math::
1485-
\\text{PCC} = \\frac {\\sum _{k}\\sum _{l}\\sum _{m}C_{kk}C_{lm}-C_{kl}C_{mk}}
1485+
\\text{MMCC} = \\frac {\\sum _{k}\\sum _{l}\\sum _{m}C_{kk}C_{lm}-C_{kl}C_{mk}}
14861486
{{\\sqrt {\\sum _{k}(\\sum _{l}C_{kl})(\\sum _{k'|k'\\neq k}\\sum _{l'}C_{k'l'})}}
14871487
{\\sqrt {\\sum _{k}(\\sum _{l}C_{lk})(\\sum _{k'|k'\\neq k}\\sum _{l'}C_{l'k'})}}}
14881488
14891489
defined in terms of a K x K confusion matrix C.
14901490
1491-
When there are more than two labels the PCC will no longer range between -1 and +1.
1491+
When there are more than two labels the MMCC will no longer range between -1 and +1.
14921492
Instead the minimum value will be between -1 and 0 depending on the true distribution.
14931493
The maximum value is always +1.
14941494
@@ -1522,18 +1522,18 @@ class PCC(EvalMetric):
15221522
)]
15231523
>>> f1 = mx.metric.F1()
15241524
>>> f1.update(preds = predicts, labels = labels)
1525-
>>> pcc = mx.metric.PCC()
1526-
>>> pcc.update(preds = predicts, labels = labels)
1525+
>>> mmcc = mx.metric.MMCC()
1526+
>>> mmcc.update(preds = predicts, labels = labels)
15271527
>>> print f1.get()
15281528
('f1', 0.95233560306652054)
1529-
>>> print pcc.get()
1530-
('pcc', 0.01917751877733392)
1529+
>>> print mmcc.get()
1530+
('mmcc', 0.01917751877733392)
15311531
"""
1532-
def __init__(self, name='pcc',
1532+
def __init__(self, name='mmcc',
15331533
output_names=None, label_names=None,
15341534
has_global_stats=True):
15351535
self.k = 2
1536-
super(PCC, self).__init__(
1536+
super(MMCC, self).__init__(
15371537
name=name, output_names=output_names, label_names=label_names,
15381538
has_global_stats=has_global_stats)
15391539

@@ -1572,7 +1572,11 @@ def update(self, labels, preds):
15721572
# update the confusion matrix
15731573
for label, pred in zip(labels, preds):
15741574
label = label.astype('int32', copy=False).asnumpy()
1575-
pred = pred.asnumpy().argmax(axis=1)
1575+
pred = pred.asnumpy()
1576+
if pred.shape != label.shape:
1577+
pred = pred.argmax(axis=1)
1578+
else:
1579+
pred = pred.astype('int32', copy=False)
15761580
n = max(pred.max(), label.max())
15771581
if n >= self.k:
15781582
self._grow(n + 1 - self.k)

tests/python/unittest/test_metric.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_metrics():
3434
check_metric('mcc')
3535
check_metric('perplexity', -1)
3636
check_metric('pearsonr')
37-
check_metric('pcc')
37+
check_metric('mmcc')
3838
check_metric('nll_loss')
3939
check_metric('loss')
4040
composite = mx.metric.create(['acc', 'f1'])
@@ -90,7 +90,7 @@ def test_global_metric():
9090
_check_global_metric('mcc', shape=(10,2), average='micro')
9191
_check_global_metric('perplexity', -1)
9292
_check_global_metric('pearsonr', use_same_shape=True)
93-
_check_global_metric('pcc', shape=(10,2))
93+
_check_global_metric('mmcc', shape=(10,2))
9494
_check_global_metric('nll_loss')
9595
_check_global_metric('loss')
9696
_check_global_metric('ce')
@@ -267,26 +267,26 @@ def cm_batch(cm):
267267
preds += [ ident[j] ] * cm[i][j]
268268
return ([ mx.nd.array(labels, dtype='int32') ], [ mx.nd.array(preds) ])
269269

270-
def test_pcc():
270+
def test_mmcc():
271271
labels, preds = cm_batch([
272272
[ 7, 3 ],
273273
[ 2, 5 ],
274274
])
275-
met_pcc = mx.metric.create('pcc')
276-
met_pcc.update(labels, preds)
277-
_, pcc = met_pcc.get()
275+
met_mmcc = mx.metric.create('mmcc')
276+
met_mmcc.update(labels, preds)
277+
_, mmcc = met_mmcc.get()
278278

279-
# pcc should agree with mcc for binary classification
279+
# mmcc should agree with mcc for binary classification
280280
met_mcc = mx.metric.create('mcc')
281281
met_mcc.update(labels, preds)
282282
_, mcc = met_mcc.get()
283-
np.testing.assert_almost_equal(pcc, mcc)
283+
np.testing.assert_almost_equal(mmcc, mcc)
284284

285-
# pcc should agree with Pearson for binary classification
285+
# mmcc should agree with Pearson for binary classification
286286
met_pear = mx.metric.create('pearsonr')
287287
met_pear.update(labels, [p.argmax(axis=1) for p in preds])
288288
_, pear = met_pear.get()
289-
np.testing.assert_almost_equal(pcc, pear)
289+
np.testing.assert_almost_equal(mmcc, pear)
290290

291291
# check multiclass case against reference implementation
292292
CM = [
@@ -316,10 +316,10 @@ def test_pcc():
316316
for k in range(K)
317317
)) ** 0.5
318318
labels, preds = cm_batch(CM)
319-
met_pcc.reset()
320-
met_pcc.update(labels, preds)
321-
_, pcc = met_pcc.get()
322-
np.testing.assert_almost_equal(pcc, ref)
319+
met_mmcc.reset()
320+
met_mmcc.update(labels, preds)
321+
_, mmcc = met_mmcc.get()
322+
np.testing.assert_almost_equal(mmcc, ref)
323323

324324
# things that should not change metric score:
325325
# * order
@@ -330,10 +330,10 @@ def test_pcc():
330330
preds = [ [ i.reshape((1, -1)) ] for i in preds[0] ]
331331
preds.reverse()
332332

333-
met_pcc.reset()
333+
met_mmcc.reset()
334334
for l, p in zip(labels, preds):
335-
met_pcc.update(l, p)
336-
assert pcc == met_pcc.get()[1]
335+
met_mmcc.update(l, p)
336+
assert mmcc == met_mmcc.get()[1]
337337

338338
def test_single_array_input():
339339
pred = mx.nd.array([[1,2,3,4]])

0 commit comments

Comments
 (0)