|
1 | 1 | # %%
|
2 | 2 | import pandas as pd
|
3 |
| -from mp_api.client import MPRester |
4 |
| -from mp_api.client.core import MPRestError |
| 3 | +from matminer.datasets import load_dataset |
5 | 4 |
|
6 | 5 | import pymatviz as pmv
|
7 | 6 | from pymatviz.enums import Key
|
|
11 | 10 |
|
12 | 11 |
|
13 | 12 | # %% Sankey diagram of crystal systems and space groups
|
14 |
| -try: |
15 |
| - with MPRester(use_document_model=False) as mpr: |
16 |
| - fields = [Key.mat_id, "symmetry.crystal_system", "symmetry.symbol"] |
17 |
| - docs = mpr.materials.summary.search( |
18 |
| - num_elements=(1, 3), fields=fields, num_chunks=30, chunk_size=1000 |
19 |
| - ) |
20 |
| -except MPRestError: |
21 |
| - raise SystemExit(0) from None |
| 13 | +data_name = "matbench_phonons" |
| 14 | +df_phonons = load_dataset(data_name) |
22 | 15 |
|
| 16 | +df_sym = pd.DataFrame( |
| 17 | + struct.get_symmetry_dataset(backend="moyopy", return_raw_dataset=True).as_dict() |
| 18 | + for struct in df_phonons[Key.structure] |
| 19 | +).rename(columns={"number": Key.spg_num}) |
| 20 | +df_sym[Key.crystal_system] = df_sym[Key.spg_num].map(pmv.utils.spg_to_crystal_sys) |
23 | 21 |
|
24 |
| -# %% |
25 |
| -df_mp = pd.json_normalize(docs).set_index(Key.mat_id) |
26 |
| -df_mp.columns = [Key.crystal_system, Key.spg_symbol] |
27 | 22 |
|
28 |
| -frequent_symbols = df_mp[Key.spg_symbol].value_counts().nlargest(20).index |
| 23 | +# %% |
| 24 | +frequent_symbols = df_sym[Key.spg_num].value_counts().nlargest(20).index |
29 | 25 |
|
30 |
| -df_spg = df_mp.query(f"{Key.spg_symbol} in @frequent_symbols") |
| 26 | +df_spg = df_sym.query(f"{Key.spg_num} in @frequent_symbols") |
31 | 27 |
|
32 | 28 |
|
33 | 29 | # %%
|
34 | 30 | fig = pmv.sankey_from_2_df_cols(
|
35 |
| - df_spg, [Key.crystal_system, Key.spg_symbol], labels_with_counts="percent" |
| 31 | + df_spg, [Key.crystal_system, Key.spg_num], labels_with_counts="percent" |
36 | 32 | )
|
37 |
| -title = "Common Space Groups in Materials Project" |
| 33 | +title = f"Common Space Groups in {data_name}" |
38 | 34 | fig.layout.title = dict(text=title, x=0.5, y=0.95)
|
39 | 35 | fig.layout.margin.t = 50
|
40 | 36 | fig.show()
|
41 |
| -pmv.io.save_and_compress_svg(fig, "sankey-crystal-sys-to-spg-symbol") |
| 37 | +pmv.io.save_and_compress_svg(fig, f"sankey-{data_name}") |
0 commit comments