Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 196d1f4

Browse files
tlbylanking520
authored andcommitted
[MXNET-1399] multiclass-mcc metric enhancements (#14874)
* multiclass-mcc metric enhancements * Rename metric from "PCC" to "mMCC" because though the math is derived from Pearson CC, it's utility is as a multiclass extension of Mathews CC. * Harden mx.metric.mMCC.update to more variations of input format, similar to mx.metric.Accuracy.update. * Harden mx.metric.PCC.update to more variations of input format, similar to mx.metric.Accuracy.update. * Enhance testcases for mx.metric.PCC.
1 parent 2d86c70 commit 196d1f4

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

python/mxnet/metric.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1572,7 +1572,11 @@ def update(self, labels, preds):
15721572
# update the confusion matrix
15731573
for label, pred in zip(labels, preds):
15741574
label = label.astype('int32', copy=False).asnumpy()
1575-
pred = pred.asnumpy().argmax(axis=1)
1575+
pred = pred.asnumpy()
1576+
if pred.shape != label.shape:
1577+
pred = pred.argmax(axis=1)
1578+
else:
1579+
pred = pred.astype('int32', copy=False)
15761580
n = max(pred.max(), label.max())
15771581
if n >= self.k:
15781582
self._grow(n + 1 - self.k)

tests/python/unittest/test_metric.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,13 @@ def test_pcc():
304304
_, pear = met_pear.get()
305305
np.testing.assert_almost_equal(pcc, pear)
306306

307+
# pcc should also accept pred as scalar rather than softmax vector
308+
# like acc does
309+
met_pcc.reset()
310+
met_pcc.update(labels, [p.argmax(axis=1) for p in preds])
311+
_, chk = met_pcc.get()
312+
np.testing.assert_almost_equal(pcc, chk)
313+
307314
# check multiclass case against reference implementation
308315
CM = [
309316
[ 23, 13, 3 ],

0 commit comments

Comments
 (0)