Skip to content

Commit f372631

Browse files
committed
add doc strings to relevance.py
1 parent ea94b0d commit f372631

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

mlmatrics/relevance.py

+20
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,17 @@
55

66

77
def roc_curve(targets: Array, proba_pos: Array, ax: Axes = plt.gca()) -> float:
8+
"""Plot the receiver operating characteristic curve of a binary
9+
classifier given target labels and predicted probabilities for
10+
the positive class.
11+
12+
Args:
13+
targets (Array): Ground truth targets.
14+
proba_pos (Array): predicted probabilities for the positive class.
15+
16+
Returns:
17+
float: The classifier's ROC area under the curve.
18+
"""
819

920
# get the metrics
1021
fpr, tpr, _ = skm.roc_curve(targets, proba_pos)
@@ -26,6 +37,15 @@ def roc_curve(targets: Array, proba_pos: Array, ax: Axes = plt.gca()) -> float:
2637
def precision_recall_curve(
2738
targets: Array, proba_pos: Array, ax: Axes = plt.gca()
2839
) -> float:
40+
"""Plot the precision recall curve of a binary classifier.
41+
42+
Args:
43+
targets (Array): Ground truth targets.
44+
proba_pos (Array): predicted probabilities for the positive class.
45+
46+
Returns:
47+
float: The classifier's precision score.
48+
"""
2949

3050
# get the metrics
3151
precision, recall, _ = skm.precision_recall_curve(targets, proba_pos)

0 commit comments

Comments
 (0)