Skip to content

Commit a43912e

Browse files
committed
break into submodules
1 parent 01a6d83 commit a43912e

File tree

14 files changed

+898
-822
lines changed

14 files changed

+898
-822
lines changed

assets/debug.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from matminer.datasets import load_dataset
2+
3+
from pymatviz.enums import Key
4+
from pymatviz.io import save_and_compress_svg
5+
from pymatviz.ptable import ptable_heatmap_ratio
6+
7+
8+
# Debug script: plot the per-element ratio between two Matbench datasets on a
# periodic table heatmap and write the figure to an SVG named "debug".
df_expt_gap = load_dataset("matbench_expt_gap")
df_steels = load_dataset("matbench_steels")

# Composition columns of both datasets are the two inputs to the ratio heatmap.
expt_gap_compositions = df_expt_gap[Key.composition]
steels_compositions = df_steels[Key.composition]

fig = ptable_heatmap_ratio(
    expt_gap_compositions,
    steels_compositions,
    log=True,
    values_fmt=".4g",
)
fig.suptitle(
    "Element ratios in Matbench Experimental Band Gap vs Matbench Steel",
    y=0.96,
    fontsize=20,
    fontweight="bold",
)
save_and_compress_svg(fig, "debug")

examples/dataset_exploration/matpes/eda.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from pymatviz.histograms import spacegroup_hist
1717
from pymatviz.io import save_fig
1818
from pymatviz.powerups import add_identity_line
19-
from pymatviz.ptable import count_elements, ptable_heatmap, ptable_heatmap_splits
19+
from pymatviz.process_data import count_elements
20+
from pymatviz.ptable import ptable_heatmap, ptable_heatmap_splits
2021
from pymatviz.sunburst import spacegroup_sunburst
2122

2223

examples/make_assets/ptable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
from pymatviz.enums import Key
99
from pymatviz.io import save_and_compress_svg
10+
from pymatviz.process_data import count_elements
1011
from pymatviz.ptable import (
11-
count_elements,
1212
ptable_heatmap,
1313
ptable_heatmap_plotly,
1414
ptable_heatmap_ratio,

pymatviz/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,8 @@
2828
plot_phonon_bands_and_dos,
2929
plot_phonon_dos,
3030
)
31+
from pymatviz.process_data import count_elements
3132
from pymatviz.ptable import (
32-
ChildPlotters,
33-
PTableProjector,
34-
count_elements,
3533
ptable_heatmap,
3634
ptable_heatmap_plotly,
3735
ptable_heatmap_ratio,

pymatviz/_preprocess_data.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

pymatviz/histograms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from pymatviz.enums import ElemCountMode, Key
1919
from pymatviz.powerups import annotate_bars
20-
from pymatviz.ptable import count_elements
20+
from pymatviz.process_data import count_elements
2121
from pymatviz.utils import (
2222
BACKENDS,
2323
MATPLOTLIB,
@@ -29,7 +29,7 @@
2929

3030

3131
if TYPE_CHECKING:
32-
from pymatviz.ptable import ElemValues
32+
from pymatviz.utils import ElemValues
3333

3434

3535
def spacegroup_hist(

pymatviz/process_data.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""Helpers for preprocessing element data (e.g. counting elements) before plotting."""
2+
3+
from __future__ import annotations
4+
5+
import itertools
6+
from collections.abc import Sequence
7+
from typing import Union
8+
9+
import pandas as pd
10+
from pandas.api.types import is_numeric_dtype, is_string_dtype
11+
from pymatgen.core import Composition
12+
13+
from pymatviz.enums import ElemCountMode, Key
14+
from pymatviz.utils import df_ptable
15+
16+
17+
# Accepted input types for element values: a dict keyed by str or int with float
# values, a pandas Series, or a sequence of strings (count_elements below treats
# integer keys as atomic numbers and strings as composition formulas).
ElemValues = Union[dict[Union[str, int], float], pd.Series, Sequence[str]]
18+
19+
20+
def count_elements(
    values: ElemValues,
    count_mode: ElemCountMode = ElemCountMode.composition,
    exclude_elements: Sequence[str] = (),
    fill_value: float | None = 0,
) -> pd.Series:
    """Turn formulas, compositions or an element->value map into per-element counts.

    Numeric input (e.g. a dict or Series that already maps elements to values) is
    not re-counted; it is only normalized: integer keys are translated from atomic
    numbers to element symbols and the result is reindexed to df_ptable.index so
    that missing elements are filled with fill_value.

    Useful on its own to cache an expensive count, e.g. replace
        ptable_heatmap(long_list_of_formulas)  # slow
    with
        elem_counts = count_elements(long_list_of_formulas)  # slow
        ptable_heatmap(elem_counts)  # fast, only rerun this line to update the plot

    Args:
        values (dict[str, int | float] | pd.Series | list[str]): Iterable of
            composition strings/objects or map from element symbols to heatmap
            values.
        count_mode ("composition" | "fractional_composition" |
            "reduced_composition" | "occurrence"): Only used when values is a
            list of composition strings/objects.
            - composition (default): Count elements in each composition as is,
                i.e. without reduction or normalization.
            - fractional_composition: Convert to normalized compositions in which
                the amounts of each species sum to 1 before counting.
                Example: Fe2 O3 -> Fe0.4 O0.6
            - reduced_composition: Convert to reduced compositions (i.e. amounts
                normalized by greatest common denominator) before counting.
                Example: Fe4 P4 O16 -> Fe P O4.
            - occurrence: Count the number of times each element occurs in a list
                of formulas irrespective of compositions. E.g.
                [Fe2 O3, Fe O, Fe4 P4 O16] counts to {Fe: 3, O: 3, P: 1}.
        exclude_elements (Sequence[str]): Elements to drop from the result.
            Defaults to ().
        fill_value (float | None): Value assigned to elements absent from values.
            Defaults to 0.

    Returns:
        pd.Series: Series named "count", reindexed to df_ptable.index with any
            excluded elements removed.

    Raises:
        ValueError: If count_mode is not a valid ElemCountMode, if values is
            neither numeric nor made up of formula strings/Composition objects, if
            integer keys fall outside [1, 118], or if exclude_elements contains a
            label that is not in the result.
    """
    allowed_modes = list(ElemCountMode.key_val_dict())
    if count_mode not in allowed_modes:
        raise ValueError(f"Invalid {count_mode=} must be one of {allowed_modes}")

    # dict/list/tuple input is normalized to a Series up front
    elem_counts = pd.Series(values)

    # Numeric data is taken as-is; anything else has to be formulas/compositions
    # that still need to be counted.
    if not is_numeric_dtype(elem_counts):
        only_formulas = is_string_dtype(elem_counts) or {
            *map(type, elem_counts)
        } <= {str, Composition}
        if not only_formulas:
            raise ValueError(
                "Expected values to be map from element symbols to heatmap values or "
                f"list of compositions (strings or Pymatgen objects), got {values}"
            )

        if count_mode == "occurrence":
            # one entry per element per formula; amounts are ignored
            symbols_per_formula = (
                map(str, Composition(comp, allow_negative=True))
                for comp in elem_counts
            )
            all_symbols = itertools.chain.from_iterable(symbols_per_formula)
            elem_counts = pd.Series(all_symbols).value_counts()
        else:
            if count_mode == Key.composition:
                attr = "element_composition"
            else:
                attr = count_mode
            amounts_per_formula = (
                getattr(Composition(formula, allow_negative=True), attr).as_dict()
                for formula in elem_counts
            )
            # column-wise sum adds up each element's amounts across formulas
            elem_counts = pd.DataFrame(amounts_per_formula).sum()

    try:
        # An index of integer-like strings is cast to real integers; any other
        # index fails the cast and is deliberately left unchanged.
        elem_counts.index = elem_counts.index.astype(int)
    except (ValueError, TypeError):
        pass

    if pd.api.types.is_integer_dtype(elem_counts.index):
        # Integer keys are assumed to be atomic numbers (H: 1, He: 2, ...)
        # and are mapped to element symbols.
        idx_min = elem_counts.index.min()
        idx_max = elem_counts.index.max()
        if idx_min < 1 or idx_max > 118:
            raise ValueError(
                "element value keys were found to be integers and assumed to represent "
                f"atomic numbers, but values range from {idx_min} to {idx_max}, "
                "expected range [1, 118]."
            )
        atomic_num_to_symbol = (
            df_ptable.reset_index().set_index("atomic_number").symbol
        )
        elem_counts.index = elem_counts.index.map(atomic_num_to_symbol)

    # Make sure every element of df_ptable appears in the result, using
    # fill_value for those that were absent from values.
    elem_counts = elem_counts.reindex(df_ptable.index, fill_value=fill_value)
    elem_counts = elem_counts.rename("count")

    if len(exclude_elements) == 0:
        return elem_counts

    # Series.drop would treat a tuple as one label, and a bare string must become
    # a one-item list, so both are converted to a list of labels.
    if isinstance(exclude_elements, str):
        exclude_elements = [exclude_elements]
    elif isinstance(exclude_elements, tuple):
        exclude_elements = list(exclude_elements)

    try:
        elem_counts = elem_counts.drop(exclude_elements)
    except KeyError as exc:
        bad_symbols = ", ".join(
            symbol for symbol in exclude_elements if symbol not in elem_counts
        )
        raise ValueError(
            f"Unexpected symbol(s) {bad_symbols} in {exclude_elements=}"
        ) from exc

    return elem_counts

pymatviz/ptable/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""matplotlib and plotly periodic table figures."""
2+
3+
from __future__ import annotations
4+
5+
from pymatviz.ptable.matplotlib import (
6+
ptable_heatmap,
7+
ptable_heatmap_ratio,
8+
ptable_heatmap_splits,
9+
ptable_hists,
10+
ptable_lines,
11+
ptable_scatters,
12+
)
13+
from pymatviz.ptable.plotly import ptable_heatmap_plotly

0 commit comments

Comments
 (0)