Skip to content

Commit da42037

Browse files
committed
set isort known_third_party = wandb
1 parent 3d42214 commit da42037

13 files changed: 18 additions (+18) and 29 deletions (-29).

mb_discovery/__init__.py

-1
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
import os
44
from typing import Any, Generator, Sequence
55

6-
76
PKG_DIR = os.path.dirname(__file__)
87
ROOT = os.path.dirname(PKG_DIR)
98

mb_discovery/m3gnet/eda_wbm_pre_vs_post_m3gnet_relaxation.py

-1
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111

1212
from mb_discovery import ROOT
1313

14-
1514
__author__ = "Janosh Riebesell"
1615
__date__ = "2022-06-18"
1716

mb_discovery/m3gnet/join_and_plot_m3gnet_relax_results.py

-1
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@
1818
hist_classified_stable_as_func_of_hull_dist,
1919
)
2020

21-
2221
today = f"{datetime.now():%Y-%m-%d}"
2322

2423

mb_discovery/m3gnet/slurm_array_m3gnet_relax_wbm.py

-1
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414

1515
from mb_discovery import ROOT, as_dict_handler
1616

17-
1817
"""
1918
To slurm submit this file, use
2019

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist.py

-1
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
hist_classified_stable_as_func_of_hull_dist,
1212
)
1313

14-
1514
__author__ = "Rhys Goodall, Janosh Riebesell"
1615
__date__ = "2022-06-18"
1716

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist_batches.py

-1
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
hist_classified_stable_as_func_of_hull_dist,
1212
)
1313

14-
1514
__author__ = "Rhys Goodall, Janosh Riebesell"
1615
__date__ = "2022-08-25"
1716

mb_discovery/plot_scripts/plot_funcs.py

+4 -2
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010

1111
from mb_discovery.plot_scripts import plt
1212

13-
1413
__author__ = "Janosh Riebesell"
1514
__date__ = "2022-08-05"
1615

@@ -365,7 +364,10 @@ def precision_recall_vs_calc_count(
365364
# previous call
366365
return ax
367366

368-
ax.set(xlabel="Number of Calculations", ylabel="Precision and Recall (%)")
367+
ax.set(
368+
xlabel="Number of compounds sorted by model-predicted stability",
369+
ylabel="Precision and Recall (%)",
370+
)
369371

370372
ax.set(ylim=(0, 100))
371373

mb_discovery/plot_scripts/precision_recall_vs_calc_count.py

+3 -4
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
precision_recall_vs_calc_count,
1111
)
1212

13-
1413
__author__ = "Rhys Goodall, Janosh Riebesell"
1514
__date__ = "2022-06-18"
1615

@@ -28,9 +27,9 @@
2827
).set_index("material_id")
2928
dfs[model_name] = df
3029

31-
# dfs["M3GNet"] = pd.read_json(
32-
# f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
33-
# ).set_index("material_id")
30+
dfs["M3GNet"] = pd.read_json(
31+
f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
32+
).set_index("material_id")
3433

3534
dfs["Wrenformer"] = pd.read_csv(
3635
f"{ROOT}/data/2022-08-16-wrenformer-preds.csv.bz2"

mb_discovery/plot_scripts/rolling_mae_vs_hull_dist.py

-1
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
from mb_discovery.plot_scripts import plt
88
from mb_discovery.plot_scripts.plot_funcs import rolling_mae_vs_hull_dist
99

10-
1110
__author__ = "Rhys Goodall, Janosh Riebesell"
1211
__date__ = "2022-06-18"
1312

mb_discovery/plot_scripts/rolling_mae_vs_hull_dist_wbm_batches.py

-1
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
from mb_discovery.plot_scripts import plt
88
from mb_discovery.plot_scripts.plot_funcs import rolling_mae_vs_hull_dist
99

10-
1110
__author__ = "Rhys Goodall, Janosh Riebesell"
1211
__date__ = "2022-06-18"
1312

mb_discovery/wrenformer/mp/get_mp_energies.py

+4 -7
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,8 @@
77
from aviary.wren.utils import get_aflow_label_from_spglib
88
from mp_api.client import MPRester
99

10-
1110
"""
12-
Download all MP formation and above hull energies on 2022-08-13 for training a
13-
Wrenformer ensemble.
11+
Download all MP formation and above hull energies on 2022-08-13.
1412
1513
Related EDA of MP formation energies:
1614
https://github.com/janosh/pymatviz/blob/main/examples/mp_bimodal_e_form.ipynb
@@ -33,17 +31,16 @@
3331
"symmetry",
3432
"energy_above_hull",
3533
]
36-
with MPRester() as mpr:
34+
with MPRester(use_document_model=False) as mpr:
3735
docs = mpr.summary.search(fields=fields)
3836

3937
print(f"{today}: {len(docs) = :,}")
4038
# 2022-08-13: len(docs) = 146,323
4139

4240

4341
# %%
44-
df = pd.DataFrame(
45-
[{key: getattr(doc, key, None) for key in fields} for doc in docs]
46-
).set_index("material_id")
42+
df = pd.DataFrame(docs).set_index("material_id")
43+
df.pop("_id")
4744

4845
df["spacegroup_number"] = df.pop("symmetry").map(lambda x: x.number)
4946

mb_discovery/wrenformer/mp/use_trained_wrenformer_ensemble.py → mb_discovery/wrenformer/mp/use_ensemble.py

+7 -7
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
import wandb
99
from aviary.wrenformer.deploy import deploy_wandb_checkpoints
1010

11-
1211
__author__ = "Janosh Riebesell"
1312
__date__ = "2022-08-15"
1413

@@ -18,11 +17,11 @@
1817
stores predictions to CSV.
1918
"""
2019

21-
MODULE_DIR = os.path.dirname(__file__)
20+
module_dir = os.path.dirname(__file__)
21+
today = f"{datetime.now():%Y-%m-%d}"
2222

2323

2424
# %%
25-
today = f"{datetime.now():%Y-%m-%d}"
2625
# download wbm-steps-summary.csv (23.31 MB)
2726
data_path = "https://figshare.com/files/36714216?private_link=ff0ad14505f9624f0c05"
2827
df = pd.read_csv(data_path).set_index("material_id")
@@ -33,12 +32,13 @@
3332

3433
wandb.login()
3534
wandb_api = wandb.Api()
36-
runs = wandb_api.runs(
37-
"aviary/mp", filters={"tags": {"$in": ["wrenformer-e_form-ensemble-1"]}}
38-
)
35+
ensemble_id = "wrenformer-e_form-ensemble-1"
36+
runs = wandb_api.runs("aviary/mp", filters={"tags": {"$in": [ensemble_id]}})
37+
38+
assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {ensemble_id=}"
3939

4040
df, ensemble_metrics = deploy_wandb_checkpoints(
4141
runs, df, input_col="wyckoff", target_col=target_col
4242
)
4343

44-
df.to_csv(f"{MODULE_DIR}/{today}-wrenformer-preds-{target_col}.csv")
44+
df.round(6).to_csv(f"{module_dir}/{today}-{ensemble_id}-preds-{target_col}.csv")

tests/test_plot_funcs.py

-1
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414
rolling_mae_vs_hull_dist,
1515
)
1616

17-
1817
DATA_DIR = f"{ROOT}/data/2022-06-11-from-rhys"
1918

2019
df_hull = pd.read_csv(f"{DATA_DIR}/wbm-e-above-mp-hull.csv").set_index("material_id")

0 commit comments

Comments
 (0)