Skip to content

Commit 2b04381

Browse files
committed
relocate scikit learn import
1 parent 1161bf0 commit 2b04381

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

pymatviz/relevance.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import TYPE_CHECKING
66

77
import matplotlib.pyplot as plt
8-
import sklearn.metrics as skm
98

109
from pymatviz.utils import df_to_arrays
1110

@@ -37,6 +36,8 @@ def roc_curve(
3736
ax = ax or plt.gca()
3837

3938
# get the metrics
39+
import sklearn.metrics as skm
40+
4041
false_pos_rate, true_pos_rate, _ = skm.roc_curve(targets, proba_pos)
4142
roc_auc = skm.roc_auc_score(targets, proba_pos)
4243

@@ -71,6 +72,8 @@ def precision_recall_curve(
7172
ax = ax or plt.gca()
7273

7374
# get the metrics
75+
import sklearn.metrics as skm
76+
7477
precision, recall, _ = skm.precision_recall_curve(targets, proba_pos)
7578

7679
# proba_pos.round() converts class probabilities to integer class labels

0 commit comments

Comments
 (0)