@@ -489,8 +489,8 @@ def cumulative_precision_recall(
489
489
tuple[plt.Figure | go.Figure, pd.DataFrame]: The matplotlib/plotly figure and
490
490
dataframe of cumulative metrics for each model.
491
491
"""
492
- fact = lambda : pd .DataFrame (index = range (len (e_above_hull_true )))
493
- dfs = dict (Precision = fact (), Recall = fact ())
492
+ factory = lambda : pd .DataFrame (index = range (len (e_above_hull_true )))
493
+ dfs = dict (Precision = factory (), Recall = factory (), F1 = factory ())
494
494
495
495
for model_name in df_preds :
496
496
model_preds = df_preds [model_name ].sort_values ()
@@ -502,36 +502,43 @@ def cumulative_precision_recall(
502
502
503
503
true_pos_cumsum = true_pos .cumsum ()
504
504
# precision aka positive predictive value (PPV)
505
- precision = true_pos_cumsum / (true_pos_cumsum + false_pos .cumsum ())
505
+ precision_cum = true_pos_cumsum / (true_pos_cumsum + false_pos .cumsum ())
506
506
n_total_pos = sum (true_pos ) + sum (false_neg )
507
- recall = true_pos_cumsum / n_total_pos # aka true_pos_rate aka sensitivity
507
+ recall_cum = true_pos_cumsum / n_total_pos # aka true_pos_rate aka sensitivity
508
508
509
- end = int (np .argmax (recall ))
509
+ end = int (np .argmax (recall_cum ))
510
510
xs = np .arange (end )
511
511
512
- prec_interp = scipy .interpolate .interp1d (xs , precision [:end ], kind = "cubic" )
513
- recall_interp = scipy .interpolate .interp1d (xs , recall [:end ], kind = "cubic" )
512
+ # cumulative F1 score
513
+ f1_cum = 2 * (precision_cum * recall_cum ) / (precision_cum + recall_cum )
514
+
515
+ prec_interp = scipy .interpolate .interp1d (xs , precision_cum [:end ], kind = "cubic" )
516
+ recall_interp = scipy .interpolate .interp1d (xs , recall_cum [:end ], kind = "cubic" )
517
+ f1_interp = scipy .interpolate .interp1d (xs , f1_cum [:end ], kind = "cubic" )
514
518
dfs ["Precision" ][model_name ] = pd .Series (prec_interp (xs ))
515
519
dfs ["Recall" ][model_name ] = pd .Series (recall_interp (xs ))
520
+ dfs ["F1" ][model_name ] = pd .Series (f1_interp (xs ))
516
521
517
522
for key , df in dfs .items ():
518
523
# drop all-NaN rows so plotly plot x-axis only extends to largest number of
519
524
# predicted materials by any model
520
525
df .dropna (how = "all" , inplace = True )
526
+ # will be used as facet_col in plotly to split different metrics into subplots
521
527
df ["metric" ] = key
522
528
523
529
df_cum = pd .concat (dfs .values ())
530
+ # subselect rows for speed, plot has sufficient precision with 1k rows
531
+ df_cum = df_cum .iloc [:: len (df_cum ) // 1000 or 1 ]
524
532
525
533
if backend == "matplotlib" :
526
- fig , axs = plt .subplots (1 , 2 , figsize = (15 , 7 ), sharey = True )
534
+ fig , axs = plt .subplots (1 , len ( dfs ) , figsize = (15 , 7 ), sharey = True )
527
535
line_kwargs = dict (
528
536
linewidth = 3 , markevery = [- 1 ], marker = "x" , markersize = 14 , markeredgewidth = 2.5
529
537
)
530
538
for (key , df ), ax in zip (dfs .items (), axs ):
531
539
# select every n-th row of df so that 1000 rows are left for increased
532
540
# plotting speed and reduced file size
533
541
# falls back on every row if df has less than 1000 rows
534
-
535
542
df .iloc [:: len (df ) // 1000 or 1 ].plot (
536
543
ax = ax , legend = False , backend = backend , ** line_kwargs | kwargs , ylabel = key
537
544
)
@@ -541,9 +548,12 @@ def cumulative_precision_recall(
541
548
bbox = dict (facecolor = "white" , alpha = 0.5 , edgecolor = "none" )
542
549
assert len (axs ) == len (dfs ), f"{ len (axs )} != { len (dfs )} "
543
550
544
- for ax , df in zip (axs , dfs .values ()):
551
+ for ax , ( key , df ) in zip (axs . flat , dfs .items ()):
545
552
ax .set (ylim = (0 , 1 ), xlim = (0 , None ), ylabel = key )
546
553
for model in df_preds :
554
+ # TODO is this if really necessary?
555
+ if len (df [model ].dropna ()) == 0 :
556
+ continue
547
557
x_end = df [model ].dropna ().index [- 1 ]
548
558
y_end = df [model ].dropna ().iloc [- 1 ]
549
559
# place model name at the end of every line
@@ -556,11 +566,12 @@ def cumulative_precision_recall(
556
566
# optimal recall line finds all stable materials without any false positives
557
567
# can be included to confirm all models start out of with near optimal recall
558
568
# and to see how much each model overshoots total n_stable
559
- n_below_hull = sum (e_above_hull_true < 0 )
560
569
if show_optimal :
570
+ ax = next (filter (lambda ax : ax .get_ylabel () == "Recall" , axs .flat ))
571
+ n_below_hull = sum (e_above_hull_true < 0 )
561
572
opt_label = "Optimal Recall"
562
- axs [ 1 ] .plot ([0 , n_below_hull ], [0 , 1 ], color = "green" , linestyle = "--" )
563
- axs [ 1 ] .text (
573
+ ax .plot ([0 , n_below_hull ], [0 , 1 ], color = "green" , linestyle = "--" )
574
+ ax .text (
564
575
* [n_below_hull , 0.81 ],
565
576
opt_label ,
566
577
color = "green" ,
@@ -571,16 +582,29 @@ def cumulative_precision_recall(
571
582
)
572
583
573
584
elif backend == "plotly" :
574
- fig = df_cum .iloc [:: len (df_cum ) // 1000 or 1 ].plot (
575
- backend = backend , facet_col = "metric" , ** kwargs
585
+ fig = df_cum .plot (
586
+ backend = backend ,
587
+ facet_col = "metric" ,
588
+ facet_col_wrap = 3 ,
589
+ facet_col_spacing = 0.03 ,
590
+ # pivot df in case we want to show all 3 metrics in each plot's hover
591
+ # requires fixing index mismatch due to df subsampling above
592
+ # customdata=dict(
593
+ # df_cum.reset_index()
594
+ # .pivot(index="index", columns="metric")["Voronoi RF above hull pred"]
595
+ # .items()
596
+ # ),
597
+ ** kwargs ,
576
598
)
577
599
fig .update_traces (line = dict (width = 4 ))
578
- for idx in range (1 , 3 ):
579
- fig .update_xaxes (
580
- title_text = "Number of materials predicted stable" , row = 1 , col = idx
600
+ for idx , metric in enumerate (df_cum .metric .unique (), 1 ):
601
+ x_axis_label = "Number of materials predicted stable" if idx == 2 else ""
602
+ fig .update_xaxes (title = x_axis_label , col = idx )
603
+ fig .update_yaxes (title = dict (text = metric , standoff = 0 ), col = idx )
604
+ fig .update_traces (
605
+ hovertemplate = f"Index = %{{x:d}}<br>{ metric } = %{{y:.2f}}" ,
606
+ col = idx , # model = %{customdata[0]}<br>
581
607
)
582
- fig .update_yaxes (title = "Precision" , col = 1 )
583
- fig .update_yaxes (title = "Recall" , col = 2 )
584
608
fig .for_each_annotation (lambda a : a .update (text = "" ))
585
609
fig .update_layout (legend = dict (title = "" ))
586
610
fig .update_layout (showlegend = False )
0 commit comments