|
13 | 13 | from pymatviz.utils import save_fig
|
14 | 14 |
|
15 | 15 | from matbench_discovery import PDF_FIGS, SITE_FIGS
|
16 |
| -from matbench_discovery.plots import ( |
17 |
| - cumulative_metrics, |
18 |
| - plotly_line_styles, |
19 |
| - plotly_markers, |
20 |
| -) |
| 16 | +from matbench_discovery.plots import cumulative_metrics |
21 | 17 | from matbench_discovery.preds import (
|
22 | 18 | df_each_pred,
|
23 | 19 | df_preds,
|
24 | 20 | each_true_col,
|
| 21 | + model_styles, |
25 | 22 | models,
|
26 | 23 | )
|
27 | 24 |
|
|
30 | 27 |
|
31 | 28 |
|
32 | 29 | # %%
|
33 |
| -metrics = ("Precision", "Recall") |
34 |
| -# metrics = ("MAE", "RMSE") |
| 30 | +# metrics = ("Precision", "Recall") |
| 31 | +metrics = ("MAE", "RMSE") |
35 | 32 | range_y = {
|
36 |
| - ("MAE", "RMSE"): (0, 0.5), |
| 33 | + ("MAE", "RMSE"): (0, 0.7), |
37 | 34 | ("Precision", "Recall"): (0, 1),
|
38 | 35 | }[metrics]
|
39 | 36 | fig, df_metric = cumulative_metrics(
|
40 | 37 | e_above_hull_true=df_preds[each_true_col],
|
41 | 38 | df_preds=df_each_pred[models],
|
42 | 39 | project_end_point="xy",
|
43 | 40 | backend=(backend := "plotly"),
|
44 |
| - range_y=range_y, |
45 | 41 | metrics=metrics,
|
46 | 42 | # facet_col_wrap=2,
|
47 | 43 | # increase facet col gap
|
|
54 | 50 | # fig.suptitle(title)
|
55 | 51 | fig.text(0.5, -0.08, x_label, ha="center", fontdict={"size": 16})
|
56 | 52 | if backend == "plotly":
|
| 53 | + for key in filter(lambda key: key.startswith("yaxis"), fig.layout): |
| 54 | + fig.layout[key].range = range_y |
| 55 | + |
57 | 56 | fig.layout.margin.update(l=0, r=0, t=30, b=50)
|
58 | 57 | fig.add_annotation(
|
59 | 58 | x=0.5,
|
|
71 | 70 | # )
|
72 | 71 | # if "MAE" in metrics:
|
73 | 72 | # fig.layout.legend.update(traceorder="reversed")
|
| 73 | + assert len(metrics) * len(models) == len( |
| 74 | + fig.data |
| 75 | + ), f"expected one trace per model per metric, got {len(fig.data)}" |
74 | 76 |
|
75 |
| - for trace, ls, marker in zip(fig.data, plotly_line_styles, plotly_markers): |
76 |
| - trace.line.dash = ls |
77 |
| - trace.marker.symbol = marker |
| 77 | + for trace in fig.data: |
| 78 | + if line_style := model_styles.get(trace.name): |
| 79 | + ls, _marker, color = line_style |
| 80 | + trace.line = dict(color=color, dash=ls, width=2) |
78 | 81 |
|
79 | 82 | # show only the N best models by default
|
80 | 83 | # if trace.name in df_metrics.T.sort_values("F1").index[:-6]:
|
|
0 commit comments