Skip to content

Commit 071174a

Browse files
committed
remove df_to_pdf and normalize_and_crop_pdf from matbench_discovery/plots.py
now imported from pymatviz
1 parent 6156c27 commit 071174a

20 files changed

+34
-208
lines changed

data/wbm/eda.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
ptable_heatmap_plotly,
1212
spacegroup_sunburst,
1313
)
14-
from pymatviz.utils import save_fig
14+
from pymatviz.io import save_fig
1515

1616
from matbench_discovery import PDF_FIGS, ROOT, SITE_FIGS, STABILITY_THRESHOLD
1717
from matbench_discovery import plots as plots

data/wbm/fetch_process_wbm_dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616
from pymatgen.entries.computed_entries import ComputedStructureEntry
1717
from pymatviz import density_scatter
18-
from pymatviz.utils import save_fig
18+
from pymatviz.io import save_fig
1919
from tqdm import tqdm
2020

2121
from matbench_discovery import SITE_FIGS, today

matbench_discovery/plots.py

-83
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44

55
import functools
66
import math
7-
import os
8-
import subprocess
97
from collections import defaultdict
108
from collections.abc import Sequence
119
from pathlib import Path
@@ -923,84 +921,3 @@ def df_to_svelte_table(
923921
styled_table = html_table.replace("</style>", f"{styles}</style>")
924922
with open(file_path, "w") as file:
925923
file.write(styled_table)
926-
927-
928-
def df_to_pdf(
929-
styler: Styler, file_path: str | Path, crop: bool = True, **kwargs: Any
930-
) -> None:
931-
"""Export a pandas Styler to PDF with WeasyPrint.
932-
933-
Args:
934-
styler (Styler): Styler object to export.
935-
file_path (str | Path): Path to save the PDF to. Requires WeasyPrint.
936-
crop (bool): Whether to crop the PDF margins. Requires pdfCropMargins.
937-
Defaults to True.
938-
**kwargs: Keyword arguments passed to Styler.to_html().
939-
"""
940-
try:
941-
from weasyprint import HTML
942-
except ImportError as exc:
943-
msg = "weasyprint not installed\nrun pip install weasyprint"
944-
raise ImportError(msg) from exc
945-
946-
html_str = styler.to_html(**kwargs)
947-
948-
# CSS to adjust layout and margins
949-
html_str = f"""
950-
<style>
951-
@page {{ size: landscape; margin: 1cm; }}
952-
body {{ margin: 0; padding: 1em; }}
953-
</style>
954-
{html_str}
955-
"""
956-
957-
html = HTML(string=html_str)
958-
959-
html.write_pdf(file_path)
960-
961-
if crop:
962-
normalize_and_crop_pdf(file_path)
963-
964-
965-
def normalize_and_crop_pdf(file_path: str | Path) -> None:
966-
"""Normalize a PDF using Ghostscript and then crop it.
967-
Without gs normalization, pdfCropMargins sometimes corrupts the PDF.
968-
969-
Args:
970-
file_path (str | Path): Path to the PDF file.
971-
"""
972-
try:
973-
normalized_file_path = f"{file_path}_normalized.pdf"
974-
from pdfCropMargins import crop
975-
976-
# Normalize the PDF with Ghostscript
977-
subprocess.run(
978-
[
979-
*"gs -sDEVICE=pdfwrite -dCompatibilityLevel=1.4".split(),
980-
*"-dPDFSETTINGS=/default -dNOPAUSE -dQUIET -dBATCH".split(),
981-
f"-sOutputFile={normalized_file_path}",
982-
str(file_path),
983-
],
984-
check=True,
985-
)
986-
987-
# Crop the normalized PDF
988-
cropped_file_path, exit_code, stdout, stderr = crop(
989-
["--percentRetain", "0", normalized_file_path]
990-
)
991-
992-
if stderr:
993-
print(f"pdfCropMargins {stderr=}")
994-
# something went wrong, remove the cropped PDF
995-
os.remove(cropped_file_path)
996-
else:
997-
# replace the original PDF with the cropped one
998-
os.replace(cropped_file_path, str(file_path))
999-
1000-
os.remove(normalized_file_path)
1001-
1002-
except ImportError as exc:
1003-
msg = "pdfCropMargins not installed\nrun pip install pdfCropMargins"
1004-
raise ImportError(msg) from exc
1005-
except Exception as exc:
1006-
raise RuntimeError("Error cropping PDF margins") from exc

models/chgnet/analyze_chgnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pandas as pd
99
from pymatgen.core import Structure
1010
from pymatviz import density_scatter, plot_structure_2d, ptable_heatmap_plotly
11-
from pymatviz.utils import save_fig
11+
from pymatviz.io import save_fig
1212

1313
from matbench_discovery import PDF_FIGS
1414
from matbench_discovery import plots as plots

models/mace/analyze_mace.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import pandas as pd
88
from pymatviz import density_scatter, ptable_heatmap_plotly, spacegroup_sunburst
9-
from pymatviz.utils import save_fig
9+
from pymatviz.io import save_fig
1010

1111
from matbench_discovery import plots as plots
1212
from matbench_discovery.data import df_wbm

models/wrenformer/analyze_wrenformer.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
import pandas as pd
77
from aviary.wren.utils import get_isopointal_proto_from_aflow
88
from pymatviz import spacegroup_hist, spacegroup_sunburst
9+
from pymatviz.io import df_to_pdf
910
from pymatviz.ptable import ptable_heatmap_plotly
1011
from pymatviz.utils import add_identity_line, bin_df_cols, save_fig
1112

1213
from matbench_discovery import PDF_FIGS, SITE_FIGS
1314
from matbench_discovery.data import DATA_FILES, df_wbm
14-
from matbench_discovery.plots import df_to_pdf, df_to_svelte_table
15+
from matbench_discovery.plots import df_to_svelte_table
1516
from matbench_discovery.preds import df_each_pred, df_preds, each_true_col
1617

1718
__author__ = "Janosh Riebesell"

pyproject.toml

+3-7
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,10 @@ requires-python = ">=3.9"
3434
dependencies = [
3535
"matplotlib",
3636
"numpy",
37-
# tmp: shouldn't be needed, used to be included in output_formatting
38-
"jinja2",
39-
# output_formatting needed for pandas Stylers
40-
# see https://github.com/pandas-dev/pandas/blob/-/pyproject.toml
41-
"pandas[output_formatting]>=2.0.0",
37+
"pandas>=2.0.0",
4238
"plotly",
4339
"pymatgen",
44-
"pymatviz[export-figs]",
40+
"pymatviz[export-figs,df-pdf-export]",
4541
"scikit-learn",
4642
"scipy",
4743
"tqdm",
@@ -69,8 +65,8 @@ running-models = [
6965
"megnet",
7066
]
7167
3d-structures = ["crystaltoolkit"]
68+
df-to-pdf = ["jinja2"]
7269
fetch-data = ["gdown"]
73-
df-pdf-export = ["pdfCropMargins", "weasyprint"]
7470

7571
[tool.setuptools.packages.find]
7672
include = ["matbench_discovery*"]

scripts/analyze_model_failure_cases.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import plotly.graph_objs as go
1414
from pymatgen.core import Composition, Structure
1515
from pymatviz import count_elements, plot_structure_2d, ptable_heatmap_plotly
16-
from pymatviz.utils import save_fig
16+
from pymatviz.io import save_fig
1717
from tqdm import tqdm
1818

1919
from matbench_discovery import PDF_FIGS, ROOT, SITE_FIGS

scripts/hist_classified_stable_vs_hull_dist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# %%
1010
from typing import Final
1111

12-
from pymatviz.utils import save_fig
12+
from pymatviz.io import save_fig
1313

1414
from matbench_discovery import PDF_FIGS
1515
from matbench_discovery.data import df_wbm

scripts/hist_classified_stable_vs_hull_dist_batches.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import Final
1111

1212
import pandas as pd
13-
from pymatviz.utils import save_fig
13+
from pymatviz.io import save_fig
1414

1515
from matbench_discovery import PDF_FIGS
1616
from matbench_discovery.plots import hist_classified_stable_vs_hull_dist

scripts/model_figs/cumulative_metrics.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
# %%
1212
import pandas as pd
13-
from pymatviz.utils import save_fig
13+
from pymatviz.io import save_fig
1414

1515
from matbench_discovery import PDF_FIGS, SITE_FIGS
1616
from matbench_discovery.plots import cumulative_metrics

scripts/model_figs/hist_classified_stable_vs_hull_dist_models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import math
99
from typing import Final
1010

11-
from pymatviz.utils import save_fig
11+
from pymatviz.io import save_fig
1212

1313
from matbench_discovery import PDF_FIGS, SITE_FIGS, today
1414
from matbench_discovery.plots import hist_classified_stable_vs_hull_dist, plt

scripts/model_figs/make_metrics_tables.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88

99
import numpy as np
1010
import pandas as pd
11+
from pymatviz.io import df_to_pdf
1112
from sklearn.dummy import DummyClassifier
1213

1314
from matbench_discovery import PDF_FIGS, SITE_FIGS
1415
from matbench_discovery.data import DATA_FILES, df_wbm
1516
from matbench_discovery.metrics import stable_metrics
1617
from matbench_discovery.models import MODEL_METADATA
17-
from matbench_discovery.plots import df_to_pdf, df_to_svelte_table
18+
from matbench_discovery.plots import df_to_svelte_table
1819
from matbench_discovery.preds import df_metrics, df_metrics_10k, each_true_col
1920

2021
__author__ = "Janosh Riebesell"
@@ -25,7 +26,7 @@
2526
"M3GNet→MEGNet": "M3GNet",
2627
"CHGNet→MEGNet": "CHGNet",
2728
}
28-
train_size_col = "training size"
29+
train_size_col = "Training Size"
2930
df_metrics.loc[train_size_col] = df_metrics_10k.loc[train_size_col] = ""
3031
for model in df_metrics:
3132
model_name = name_map.get(model, model)
@@ -62,7 +63,7 @@
6263
df_metrics_10k["Dummy"] = dummy_metrics
6364

6465

65-
# %% for each model this ontology dict specifies (training type, test type, model class)
66+
# %% for each model this ontology dict specifies (training type, test type, model type)
6667
ontology = {
6768
"ALIGNN": ("RS2RE", "IS2RE", "GNN"),
6869
# "ALIGNN Pretrained": ("RS2RE", "IS2RE", "GNN"),
@@ -80,7 +81,7 @@
8081
"CHGNet→MEGNet": ("S2EFSM", "IS2RE-SR", "UIP-GNN"),
8182
"Dummy": ("", "", ""),
8283
}
83-
ontology_cols = ["Trained", "Deployed", "Model Class"]
84+
ontology_cols = ["Trained", "Deployed", model_type_col := "Model Type"]
8485
df_ont = pd.DataFrame(ontology, index=ontology_cols)
8586
# RS2RE = relaxed structure to relaxed energy
8687
# RP2RE = relaxed prototype to predicted energy
@@ -104,7 +105,7 @@
104105
make_uip_megnet_comparison = False
105106
show_cols = (
106107
f"F1,DAF,Precision,Accuracy,TPR,TNR,MAE,RMSE,{R2_col},"
107-
"training size,Model Class".split(",")
108+
f"{train_size_col},{model_type_col}".split(",")
108109
)
109110

110111
for label, df in (("-first-10k", df_metrics_10k), ("", df_metrics)):
@@ -160,7 +161,7 @@
160161
)
161162
try:
162163
df_to_pdf(styler, f"{PDF_FIGS}/metrics-table{label}.pdf")
163-
except ImportError as exc:
164+
except (ImportError, RuntimeError) as exc:
164165
print(f"df_to_pdf failed: {exc}")
165166

166167

scripts/model_figs/model_run_times.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import requests
1616
import wandb
1717
import wandb.apis.public
18-
from pymatviz.utils import save_fig
18+
from pymatviz.io import save_fig
1919
from tqdm import tqdm
2020

2121
from matbench_discovery import PDF_FIGS, SITE_FIGS, SITE_MODELS, WANDB_PATH

scripts/model_figs/roc_prc_curves_models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import math
99

1010
import pandas as pd
11-
from pymatviz.utils import save_fig
11+
from pymatviz.io import save_fig
1212
from sklearn.metrics import auc, precision_recall_curve, roc_curve
1313
from tqdm import tqdm
1414

scripts/model_figs/rolling_mae_vs_hull_dist_models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numpy as np
88
import plotly.graph_objects as go
9-
from pymatviz.utils import save_fig
9+
from pymatviz.io import save_fig
1010

1111
from matbench_discovery import PDF_FIGS, SITE_FIGS
1212
from matbench_discovery.plots import rolling_mae_vs_hull_dist

scripts/rolling_mae_vs_hull_dist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
# %%
5-
from pymatviz.utils import save_fig
5+
from pymatviz.io import save_fig
66

77
from matbench_discovery import PDF_FIGS, today
88
from matbench_discovery.plots import rolling_mae_vs_hull_dist

scripts/rolling_mae_vs_hull_dist_wbm_batches.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55

66
# %%
7-
from pymatviz.utils import save_fig
7+
from pymatviz.io import save_fig
88

99
from matbench_discovery import PDF_FIGS, SITE_FIGS, today
1010
from matbench_discovery.plots import plt, rolling_mae_vs_hull_dist

0 commit comments

Comments
 (0)