10
10
from scipy .interpolate import interpn
11
11
from sklearn .metrics import r2_score
12
12
13
- from ml_matrics .utils import NumArray , add_identity , with_hist
13
+ from ml_matrics .utils import NumArray , with_hist
14
14
15
15
16
16
def hist_density (
@@ -19,8 +19,8 @@ def hist_density(
19
19
"""Return an approximate density of 2d points.
20
20
21
21
Args:
22
- xs (NumArray ): x-coordinates of points
23
- ys (NumArray ): y-coordinates of points
22
+ xs (array ): x-coordinates of points
23
+ ys (array ): y-coordinates of points
24
24
sort (bool, optional): Whether to sort points by density so that densest points
25
25
are plotted last. Defaults to True.
26
26
bins (int, optional): Number of bins (histogram resolution). Defaults to 100.
@@ -76,8 +76,8 @@ def density_scatter(
76
76
"""Scatter plot colored (and optionally sorted) by density.
77
77
78
78
Args:
79
- xs (NumArray ): x values.
80
- ys (NumArray ): y values.
79
+ xs (array ): x values.
80
+ ys (array ): y values.
81
81
ax (Axes, optional): plt.Axes object. Defaults to None.
82
82
color_map (str, optional): plt color map or valid string name. Defaults to "Blues".
83
83
sort (bool, optional): Whether to sort the data. Defaults to True.
@@ -102,8 +102,12 @@ def density_scatter(
102
102
norm = mpl .colors .LogNorm () if log else None
103
103
104
104
ax .scatter (xs , ys , c = cs , cmap = color_map , norm = norm , ** kwargs )
105
+
105
106
if identity :
106
- add_identity (ax , label = "ideal" )
107
+ ax .axline (
108
+ (0 , 0 ), (1 , 1 ), alpha = 0.5 , zorder = 0 , linestyle = "dashed" , color = "black"
109
+ )
110
+
107
111
if stats :
108
112
add_mae_r2_box (xs , ys , ax )
109
113
@@ -128,10 +132,10 @@ def scatter_with_err_bar(
128
132
i.e. if points farther from the parity line have larger uncertainty.
129
133
130
134
Args:
131
- xs (NumArray ): x-values
132
- ys (NumArray ): y-values
133
- xerr (NumArray , optional): Horizontal error bars. Defaults to None.
134
- yerr (NumArray , optional): Vertical error bars. Defaults to None.
135
+ xs (array ): x-values
136
+ ys (array ): y-values
137
+ xerr (array , optional): Horizontal error bars. Defaults to None.
138
+ yerr (array , optional): Vertical error bars. Defaults to None.
135
139
ax (Axes, optional): plt.Axes object. Defaults to None.
136
140
xlabel (str, optional): x-axis label. Defaults to "Actual".
137
141
ylabel (str, optional): y-axis label. Defaults to "Predicted".
@@ -145,7 +149,10 @@ def scatter_with_err_bar(
145
149
146
150
styles = dict (markersize = 6 , fmt = "o" , ecolor = "g" , capthick = 2 , elinewidth = 2 )
147
151
ax .errorbar (xs , ys , yerr = yerr , xerr = xerr , ** kwargs , ** styles )
148
- add_identity (ax )
152
+
153
+ # identity line
154
+ ax .axline ((0 , 0 ), (1 , 1 ), alpha = 0.5 , zorder = 0 , linestyle = "dashed" , color = "black" )
155
+
149
156
add_mae_r2_box (xs , ys , ax )
150
157
151
158
ax .set (xlabel = xlabel , ylabel = ylabel , title = title )
@@ -166,10 +173,10 @@ def density_hexbin(
166
173
dimension passed as weights.
167
174
168
175
Args:
169
- xs (NumArray ): x values
170
- yx (NumArray ): y values
176
+ xs (array ): x values
177
+ yx (array ): y values
171
178
ax (Axes, optional): plt.Axes object. Defaults to None.
172
- weights (NumArray , optional): If given, these values are accumulated in the bins.
179
+ weights (array , optional): If given, these values are accumulated in the bins.
173
180
Otherwise, every point has value 1. Must be of the same length as x and y.
174
181
Defaults to None.
175
182
xlabel (str, optional): x-axis label. Defaults to "Actual".
@@ -188,7 +195,9 @@ def density_hexbin(
188
195
plt .colorbar (hexbin , cax = cb_ax )
189
196
cb_ax .yaxis .set_ticks_position ("left" )
190
197
191
- add_identity (ax , label = "ideal" )
198
+ # identity line
199
+ ax .axline ((0 , 0 ), (1 , 1 ), alpha = 0.5 , zorder = 0 , linestyle = "dashed" , color = "black" )
200
+
192
201
add_mae_r2_box (xs , yx , ax , loc = "upper left" )
193
202
194
203
ax .set (xlabel = xlabel , ylabel = ylabel )
@@ -235,8 +244,8 @@ def residual_vs_actual(y_true: NumArray, y_pred: NumArray, ax: Axes = None) -> A
235
244
(y_err = y_true - y_pred) on the y-axis.
236
245
237
246
Args:
238
- y_true (NumArray ): Ground truth values
239
- y_pred (NumArray ): Model predictions
247
+ y_true (array ): Ground truth values
248
+ y_pred (array ): Model predictions
240
249
ax (Axes, optional): plt.Axes object. Defaults to None.
241
250
242
251
Returns:
@@ -248,11 +257,10 @@ def residual_vs_actual(y_true: NumArray, y_pred: NumArray, ax: Axes = None) -> A
248
257
249
258
y_err = y_true - y_pred
250
259
251
- xmin = np .min (y_true ) * 0.9
252
- xmax = np .max (y_true ) / 0.9
253
-
254
260
plt .plot (y_true , y_err , "o" , alpha = 0.5 , label = None , mew = 1.2 , ms = 5.2 )
255
- plt .plot ([xmin , xmax ], [0 , 0 ], "k--" , alpha = 0.5 , label = "ideal" )
261
+ plt .axline (
262
+ [1 , 0 ], [2 , 0 ], linestyle = "dashed" , color = "black" , alpha = 0.5 , label = "ideal"
263
+ )
256
264
257
265
plt .ylabel (r"Residual ($y_\mathrm{test} - y_\mathrm{pred}$)" )
258
266
plt .xlabel ("Actual value" )
0 commit comments