2
2
3
3
from __future__ import annotations
4
4
5
+ import functools
5
6
import math
6
7
import os
7
8
import subprocess
21
22
import wandb
22
23
from mpl_toolkits .axes_grid1 .anchored_artists import AnchoredSizeBar
23
24
from pandas .io .formats .style import Styler
25
+ from plotly .validators .scatter .line import DashValidator
26
+ from plotly .validators .scatter .marker import SymbolValidator
24
27
from tqdm import tqdm
25
28
26
29
from matbench_discovery import STABILITY_THRESHOLD
31
34
32
35
Backend = Literal ["matplotlib" , "plotly" ]
33
36
37
+ plotly_markers = SymbolValidator ().values [2 ::3 ] # noqa: PD011
38
+ plotly_line_styles = DashValidator ().values [:- 1 ] # noqa: PD011
39
+ # repeat line styles as many as times as needed to match number of markers
40
+ plotly_line_styles *= len (plotly_markers ) // len (plotly_line_styles )
34
41
35
- def unit (text : str ) -> str :
42
+
43
+ def plotly_unit (text : str ) -> str :
36
44
"""Wrap text in a span with decreased font size and weight to display units in
37
45
plotly labels.
38
46
"""
39
47
return f"<span style='font-size: 0.8em; font-weight: lighter;'>({ text } )</span>"
40
48
41
49
42
- ev_per_atom = unit ("eV/atom" )
50
+ ev_per_atom = plotly_unit ("eV/atom" )
43
51
44
52
# --- start global plot settings
45
53
quantity_labels = dict (
@@ -51,11 +59,11 @@ def unit(text: str) -> str:
51
59
n_sites = "Lattice site count" ,
52
60
energy_per_atom = f"Energy { ev_per_atom } " ,
53
61
e_form = f"DFT E<sub>form</sub> { ev_per_atom } " ,
54
- e_above_hull = f"E<sub>above hull</sub> { ev_per_atom } " ,
55
- e_above_hull_mp2020_corrected_ppd_mp = f"DFT E<sub>above hull</sub> { ev_per_atom } " ,
56
- e_above_hull_pred = f"Predicted E<sub>above hull</sub> { ev_per_atom } " ,
62
+ e_above_hull = f"E<sub>hull dist </sub> { ev_per_atom } " ,
63
+ e_above_hull_mp2020_corrected_ppd_mp = f"DFT E<sub>hull dist </sub> { ev_per_atom } " ,
64
+ e_above_hull_pred = f"Predicted E<sub>hull dist </sub> { ev_per_atom } " ,
57
65
e_above_hull_mp = f"E<sub>above MP hull</sub> { ev_per_atom } " ,
58
- e_above_hull_error = f"Error in E<sub>above hull</sub> { ev_per_atom } " ,
66
+ e_above_hull_error = f"Error in E<sub>hull dist </sub> { ev_per_atom } " ,
59
67
vol_diff = "Volume difference (A^3)" ,
60
68
e_form_per_atom_mp2020_corrected = f"DFT E<sub>form</sub> { ev_per_atom } " ,
61
69
e_form_per_atom_pred = f"Predicted E<sub>form</sub> { ev_per_atom } " ,
@@ -547,7 +555,7 @@ def rolling_mae_vs_hull_dist(
547
555
scatter_kwds = dict (
548
556
fill = "toself" , opacity = 0.2 , hoverinfo = "skip" , showlegend = False
549
557
)
550
- triangle_anno = "MAE > |E<sub>above hull</sub>|"
558
+ triangle_anno = "MAE > |E<sub>hull dist </sub>|"
551
559
fig .add_scatter (
552
560
x = (- 1 , - dft_acc , dft_acc , 1 ) if show_dft_acc else (- 1 , 0 , 1 ),
553
561
y = (1 , dft_acc , dft_acc , 1 ) if show_dft_acc else (1 , 0 , 1 ),
@@ -632,6 +640,7 @@ def cumulative_metrics(
632
640
optimal_recall : str | None = "Optimal Recall" ,
633
641
show_n_stable : bool = True ,
634
642
backend : Backend = "plotly" ,
643
+ n_points : int = 50 ,
635
644
** kwargs : Any ,
636
645
) -> tuple [plt .Figure | go .Figure , pd .DataFrame ]:
637
646
"""Create 2 subplots side-by-side with cumulative precision and recall curves for
@@ -661,18 +670,24 @@ def cumulative_metrics(
661
670
number of stable materials. Defaults to True.
662
671
backend ('matplotlib' | 'plotly'], optional): Which plotting engine to use.
663
672
Changes the return type. Defaults to 'plotly'.
673
+ n_points (int, optional): Number of points to use for interpolation of the
674
+ metric curves. Defaults to 80.
664
675
**kwargs: Keyword arguments passed to df.plot().
665
676
666
677
Returns:
667
678
tuple[plt.Figure | go.Figure, pd.DataFrame]: The matplotlib/plotly figure and
668
679
dataframe of cumulative metrics for each model.
669
680
"""
670
- factory = lambda : pd .DataFrame (index = range (len (e_above_hull_true )))
671
- dfs : dict [str , pd .DataFrame ] = defaultdict (factory )
672
- metrics_no_case = [* map (str .casefold , metrics )]
681
+ dfs : dict [str , pd .DataFrame ] = defaultdict (pd .DataFrame )
682
+
683
+ # largest number of materials predicted stable by any model, determines x-axis range
684
+ n_max_pred_stable = (df_preds < stability_threshold ).sum ().max ()
685
+ longest_xs = np .linspace (0 , n_max_pred_stable - 1 , n_points )
686
+ for metric in metrics :
687
+ dfs [metric ].index = longest_xs
673
688
674
- valid_metrics = {"precision " , "recall " , "f1 " , "mae " , "rmse " }
675
- if invalid_metrics := set (metrics_no_case ) - valid_metrics :
689
+ valid_metrics = {"Precision " , "Recall " , "F1 " , "MAE " , "RMSE " }
690
+ if invalid_metrics := set (metrics ) - valid_metrics :
676
691
raise ValueError (
677
692
f"{ invalid_metrics = } , should be case-insensitive subset of { valid_metrics = } "
678
693
)
@@ -691,35 +706,36 @@ def cumulative_metrics(
691
706
precision_cum = true_pos_cum / (true_pos_cum + false_pos_cum )
692
707
recall_cum = true_pos_cum / n_total_pos # aka true_pos_rate aka sensitivity
693
708
694
- end = int (np .argmax (recall_cum ))
695
- xs = np .arange (end )
709
+ n_pred_stable = sum (each_pred <= stability_threshold )
710
+ model_range = np .arange (n_pred_stable ) # xs for interpolation
711
+ xs_model = longest_xs [longest_xs < n_pred_stable - 1 ] # xs for plotting
696
712
697
- if "precision" in metrics_no_case :
698
- prec_interp = scipy . interpolate . interp1d (
699
- xs , precision_cum [: end ], kind = "cubic"
700
- )
701
- dfs ["Precision" ][model_name ] = pd . Series ( prec_interp (xs ))
702
- if "recall" in metrics_no_case :
703
- recall_interp = scipy . interpolate . interp1d (
704
- xs , recall_cum [:end ], kind = "cubic"
705
- )
706
- dfs [ "Recall" ][ model_name ] = pd . Series ( recall_interp ( xs ))
707
- if "f1 " in metrics_no_case :
713
+ cubic_interpolate = functools . partial ( scipy . interpolate . interp1d , kind = "cubic" )
714
+
715
+ if "Precision" in metrics :
716
+ prec_interp = cubic_interpolate ( model_range , precision_cum [: n_pred_stable ] )
717
+ dfs ["Precision" ][model_name ] = dict ( zip ( xs_model , prec_interp (xs_model ) ))
718
+
719
+ if "Recall" in metrics :
720
+ recall_interp = cubic_interpolate ( model_range , recall_cum [:n_pred_stable ])
721
+ dfs [ "Recall" ][ model_name ] = dict ( zip ( xs_model , recall_interp ( xs_model )) )
722
+
723
+ if "F1 " in metrics :
708
724
f1_cum = 2 * (precision_cum * recall_cum ) / (precision_cum + recall_cum )
709
- f1_interp = scipy . interpolate . interp1d ( xs , f1_cum [:end ], kind = "cubic" )
710
- dfs ["F1" ][model_name ] = pd . Series ( f1_interp (xs ))
725
+ f1_interp = cubic_interpolate ( model_range , f1_cum [:n_pred_stable ] )
726
+ dfs ["F1" ][model_name ] = dict ( zip ( xs_model , f1_interp (xs_model ) ))
711
727
712
- if "mae " in metrics_no_case :
728
+ if "MAE " in metrics :
713
729
cum_errors = (each_true - each_pred ).abs ().cumsum ()
714
730
cum_counts = np .arange (1 , len (each_true ) + 1 )
715
731
mae_cum = cum_errors / cum_counts
716
- mae_interp = scipy . interpolate . interp1d ( xs , mae_cum [:end ], kind = "cubic" )
717
- dfs ["MAE" ][model_name ] = pd . Series ( mae_interp (xs ))
732
+ mae_interp = cubic_interpolate ( model_range , mae_cum [:n_pred_stable ] )
733
+ dfs ["MAE" ][model_name ] = dict ( zip ( xs_model , mae_interp (xs_model ) ))
718
734
719
- if "rmse " in metrics_no_case :
735
+ if "RMSE " in metrics :
720
736
rmse_cum = (((each_true - each_pred ) ** 2 ).cumsum () / cum_counts ) ** 0.5
721
- rmse_interp = scipy . interpolate . interp1d ( xs , rmse_cum [:end ], kind = "cubic" )
722
- dfs ["RMSE" ][model_name ] = pd . Series ( rmse_interp (xs ))
737
+ rmse_interp = cubic_interpolate ( model_range , rmse_cum [:n_pred_stable ] )
738
+ dfs ["RMSE" ][model_name ] = dict ( zip ( xs_model , rmse_interp (xs_model ) ))
723
739
724
740
for key in dfs :
725
741
# drop all-NaN rows so plotly plot x-axis only extends to largest number of
@@ -730,7 +746,6 @@ def cumulative_metrics(
730
746
731
747
df_cum = pd .concat (dfs .values ())
732
748
# subselect rows for speed, plot has sufficient precision with 1k rows
733
- df_cum = df_cum .iloc [:: len (df_cum ) // 1000 or 1 ]
734
749
n_stable = sum (e_above_hull_true <= STABILITY_THRESHOLD )
735
750
736
751
if backend == "matplotlib" :
@@ -751,7 +766,7 @@ def cumulative_metrics(
751
766
# plotting speed and reduced file size
752
767
# falls back on every row if df has less than 1000 rows
753
768
df = dfs [metric ]
754
- df .iloc [:: len ( df ) // 1000 or 1 ]. plot (
769
+ df .plot (
755
770
ax = ax ,
756
771
legend = False ,
757
772
backend = backend ,
0 commit comments