@@ -190,10 +190,10 @@ def hist_classified_stable_vs_hull_dist(
190
190
)[which_energy ]
191
191
192
192
if stability_threshold is not None :
193
- for ax in [fig ] if isinstance (fig , plt .Axes ) else fig .flat :
194
- ax .set (xlabel = xlabel , ylabel = y_label , xlim = x_lim )
193
+ for ax_i in [fig ] if isinstance (fig , plt .Axes ) else fig .flat :
194
+ ax_i .set (xlabel = xlabel , ylabel = y_label , xlim = x_lim )
195
195
label = "Stability Threshold"
196
- ax .axvline (
196
+ ax_i .axvline (
197
197
stability_threshold , color = "black" , linestyle = "--" , label = label
198
198
)
199
199
@@ -228,8 +228,8 @@ def hist_classified_stable_vs_hull_dist(
228
228
)
229
229
230
230
if backend == MATPLOTLIB :
231
- for ax in fig .flat if isinstance (fig , np .ndarray ) else [fig ]:
232
- ax_acc = ax .twinx ()
231
+ for ax_i in fig .flat if isinstance (fig , np .ndarray ) else [fig ]:
232
+ ax_acc = ax_i .twinx ()
233
233
ax_acc .set_ylabel ("Rolling Accuracy" , color = "darkblue" )
234
234
ax_acc .tick_params (labelcolor = "darkblue" )
235
235
ax_acc .set (ylim = (0 , 1.1 ))
@@ -681,14 +681,14 @@ def cumulative_metrics(
681
681
rmse_interp = cubic_interpolate (model_range , rmse_cum [:n_pred_stable ])
682
682
dfs ["RMSE" ][model_name ] = dict (zip (xs_model , rmse_interp (xs_model )))
683
683
684
- for key in dfs :
684
+ for key , df_i in dfs .items ():
685
+ # will be used as facet_col in plotly to split different metrics into subplots
686
+ df_i ["metric" ] = key
685
687
# drop all-NaN rows so plotly plot x-axis only extends to largest number of
686
688
# predicted materials by any model
687
- dfs [key ] = dfs [key ].dropna (how = "all" )
688
- # will be used as facet_col in plotly to split different metrics into subplots
689
- dfs [key ]["metric" ] = key
689
+ dfs [key ] = df_i .dropna (how = "all" )
690
690
691
- df_cum = pd .concat (dfs .values ())
691
+ df_cumu_metrics = pd .concat (dfs .values ())
692
692
# subselect rows for speed, plot has sufficient precision with 1k rows
693
693
n_stable = sum (e_above_hull_true <= STABILITY_THRESHOLD )
694
694
@@ -752,7 +752,7 @@ def cumulative_metrics(
752
752
elif backend == PLOTLY :
753
753
n_cols = kwargs .pop ("facet_col_wrap" , 2 )
754
754
kwargs .setdefault ("facet_col_spacing" , 0.03 )
755
- fig = df_cum .plot (
755
+ fig = df_cumu_metrics .plot (
756
756
backend = backend ,
757
757
facet_col = "metric" ,
758
758
facet_col_wrap = n_cols ,
@@ -802,7 +802,7 @@ def cumulative_metrics(
802
802
else :
803
803
raise ValueError (f"Unknown { backend = } " )
804
804
805
- return fig , df_cum
805
+ return fig , df_cumu_metrics
806
806
807
807
808
808
def wandb_scatter (table : wandb .Table , fields : dict [str , str ], ** kwargs : Any ) -> None :
0 commit comments