 from pymatgen.entries.computed_entries import ComputedStructureEntry
 from tqdm import tqdm

-data_files = {
-    "summary": "2022-10-19-wbm-summary.csv",
-    "initial-structures": "2022-10-19-wbm-init-structs.json.bz2",
-    "computed-structure-entries": "2022-10-19-wbm-cses.json.bz2",
+DATA_FILENAMES = {
+    "wbm-summary": "wbm/2022-10-19-wbm-summary.csv",
+    "wbm-initial-structures": "wbm/2022-10-19-wbm-init-structs.json.bz2",
+    "wbm-computed-structure-entries": "wbm/2022-10-19-wbm-cses.json.bz2",
+    "mp-energies": "mp/2022-08-13-mp-energies.json.gz",
+    "mp-computed-structure-entries": "mp/2022-09-16-mp-computed-structure-entries.json.gz",
+    "mp-patched-phase-diagram": "mp/2022-09-18-ppd-mp.pkl.gz",
+    "mp-elemental-ref-energies": "mp/2022-09-19-mp-elemental-ref-energies.json",
 }

-base_url = "https://raw.githubusercontent.com/janosh/matbench-discovery/main/data/wbm"
-default_cache_loc = os.path.expanduser("~/.cache/matbench-discovery")
+RAW_REPO_URL = "https://raw.githubusercontent.com/janosh/matbench-discovery"
+default_cache_dir = os.path.expanduser("~/.cache/matbench-discovery")


 def chunks(xs: Sequence[Any], n: int) -> Generator[Sequence[Any], None, None]:
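Not part of the commit: the body of chunks() falls outside this hunk. A minimal sketch consistent with the signature above (an assumption, not code from this diff) would be:

    from collections.abc import Generator, Sequence
    from typing import Any

    def chunks(xs: Sequence[Any], n: int) -> Generator[Sequence[Any], None, None]:
        # yield successive n-sized slices of xs (the last slice may be shorter)
        return (xs[i : i + n] for i in range(0, len(xs), n))
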
@@ -35,21 +39,26 @@ def as_dict_handler(obj: Any) -> dict[str, Any] | None:
     # removes e.g. non-serializable AseAtoms from M3GNet relaxation trajectories


-def load_wbm(
-    parts: Sequence[str] = ("summary",),
+def load_train_test(
+    parts: str | Sequence[str] = ("wbm-summary",),
     version: int = 1,
-    cache_dir: str | None = default_cache_loc,
+    cache_dir: str | None = default_cache_dir,
     hydrate: bool = False,
 ) -> pd.DataFrame | dict[str, pd.DataFrame]:
-    """_summary_
+    """Download the MP training data and WBM test data in parts or in full as pandas
+    DataFrames. The full training and test sets are each about ~500 MB as compressed
+    JSON and will be cached locally for faster re-loading unless cache_dir is set to
+    None.
+
+    Hint: Import DATA_FILENAMES from the same module as this function and
+    print(list(DATA_FILENAMES)) to see permissible data names.

     Args:
-        parts (str, optional): Which parts of the WBM dataset to load. Can be any subset
-            of {'summary', 'initial-structures', 'computed-structure-entries'}. Defaults
-            to ["summary"], a dataframe with columns for material properties like VASP
-            energy, formation energy, energy above the convex hull (3 columns with old,
-            new and no Materials Project energy corrections applied for each), volume,
-            band gap, number of sites per unit cell, and more.
+        parts (str | list[str], optional): Which parts of the MP/WBM dataset to load.
+            Can be any subset of list(DATA_FILENAMES). Defaults to ["wbm-summary"], a
+            dataframe with columns for material properties like VASP energy, formation
+            energy, energy above the convex hull (3 columns with old, new and no
+            Materials Project energy corrections applied for each), volume, band gap,
+            number of sites per unit cell, and more.
         version (int, optional): Which version of the dataset to load. Defaults to 1
             (currently the only available option).
         cache_dir (str, optional): Where to cache data files on local drive. Defaults to
@@ -60,31 +69,36 @@ def load_wbm(
             False as it noticeably increases load time.

     Raises:
-        ValueError: On bad version or bad keys for which data parts to load.
+        ValueError: On bad version number or bad part names.

     Returns:
         pd.DataFrame | dict[str, pd.DataFrame]: Single dataframe or dictionary of
             dataframes if multiple data parts were requested.
     """
+    if parts == "all":
+        parts = list(DATA_FILENAMES)
+    elif isinstance(parts, str):
+        parts = [parts]
+
     if version != 1:
         raise ValueError(f"Only version 1 currently available, got {version=}")
-    if missing := set(parts) - set(data_files):
-        raise ValueError(f"{missing} must be subset of {set(data_files)}")
+    if missing := set(parts) - set(DATA_FILENAMES):
+        raise ValueError(f"{missing} must be subset of {set(DATA_FILENAMES)}")

     dfs = {}
     for key in parts:
-        file = data_files[key]
+        file = DATA_FILENAMES[key]
         reader = pd.read_csv if file.endswith(".csv") else pd.read_json

         cache_path = f"{cache_dir}/{file}"
         if os.path.isfile(cache_path):
             df = reader(cache_path)
         else:
-            url = f"{base_url}/{file}"
-            print(f"Downloading {url=}")
+            url = f"{RAW_REPO_URL}/{version}.0.0/data/{file}"
+            print(f"Downloading {key} from {url}")
             df = reader(url)
             if cache_dir and not os.path.isfile(cache_path):
-                os.makedirs(cache_dir, exist_ok=True)
+                os.makedirs(os.path.dirname(cache_path), exist_ok=True)
                 if ".csv" in file:
                     df.to_csv(cache_path)
                 elif ".json" in file: