|
| 1 | +"""Various periodic table heatmaps with matplotlib and plotly.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import itertools |
| 6 | +from collections.abc import Sequence |
| 7 | +from typing import Union |
| 8 | + |
| 9 | +import pandas as pd |
| 10 | +from pandas.api.types import is_numeric_dtype, is_string_dtype |
| 11 | +from pymatgen.core import Composition |
| 12 | + |
| 13 | +from pymatviz.enums import ElemCountMode, Key |
| 14 | +from pymatviz.utils import df_ptable |
| 15 | + |
| 16 | + |
# Input types accepted by count_elements: a dict keyed by str or int (int keys are
# treated as atomic numbers there) with float values, a pd.Series, or a sequence
# of strings (formulas).
ElemValues = Union[dict[Union[str, int], float], pd.Series, Sequence[str]]
| 18 | + |
| 19 | + |
def count_elements(
    values: ElemValues,
    count_mode: ElemCountMode = ElemCountMode.composition,
    exclude_elements: Sequence[str] = (),
    fill_value: float | None = 0,
) -> pd.Series:
    """Count element occurrence in list of formula strings or dict-like compositions.

    If passed values are already a map from element symbol (or atomic number) to
    counts, ensure the data is a pd.Series in which missing element symbols are
    filled with fill_value (zero by default).

    Provided as standalone function for external use or to cache long computations.
    Caching long element counts is done by refactoring
        ptable_heatmap(long_list_of_formulas) # slow
    to
        elem_counts = count_elements(long_list_of_formulas) # slow
        ptable_heatmap(elem_counts) # fast, only rerun this line to update the plot

    Args:
        values (dict[str | int, int | float] | pd.Series | list[str]): Iterable of
            composition strings/objects or map from element symbols (or atomic
            numbers) to heatmap values.
        count_mode (ElemCountMode): Only used when values is a list of composition
            strings/objects. One of:
            - composition (default): Count elements in each composition as is,
                i.e. without reduction or normalization.
            - fractional_composition: Convert to normalized compositions in which the
                amounts of each species sum to 1 before counting.
                Example: Fe2 O3 -> Fe0.4 O0.6
            - reduced_composition: Convert to reduced compositions (i.e. amounts
                normalized by greatest common denominator) before counting.
                Example: Fe4 P4 O16 -> Fe P O4.
            - occurrence: Count the number of times each element occurs in a list of
                formulas irrespective of compositions. E.g. [Fe2 O3, Fe O, Fe4 P4 O16]
                counts to {Fe: 3, O: 3, P: 1}.
        exclude_elements (Sequence[str]): Elements to exclude from the count. A
            single symbol passed as a plain string is also accepted. Defaults to ().
        fill_value (float | None): Value to fill in for missing elements. Defaults to 0.

    Returns:
        pd.Series: Map element symbols to heatmap values, named "count" and indexed
            like df_ptable minus any excluded elements.

    Raises:
        ValueError: If count_mode is not a valid ElemCountMode, if values is neither
            numeric nor a collection of formula strings/Composition objects, if
            integer keys fall outside [1, 118], or if exclude_elements contains
            symbols not present in the result.
    """
    valid_count_modes = list(ElemCountMode.key_val_dict())
    if count_mode not in valid_count_modes:
        raise ValueError(f"Invalid {count_mode=} must be one of {valid_count_modes}")
    # Ensure values is Series if we got dict/list/tuple
    srs = pd.Series(values)

    if is_numeric_dtype(srs):
        # Already a map from element (symbol or atomic number) to value
        pass
    elif is_string_dtype(srs) or {*map(type, srs)} <= {str, Composition}:
        # all items are formula strings or Composition objects
        if count_mode == "occurrence":
            # Iterating a Composition yields its elements, so each formula
            # contributes every element it contains exactly once
            srs = pd.Series(
                itertools.chain.from_iterable(
                    map(str, Composition(comp, allow_negative=True)) for comp in srs
                )
            ).value_counts()
        else:
            # The plain "composition" mode maps to Composition.element_composition;
            # the other modes share their name with the Composition attribute.
            # Compare against ElemCountMode (the enum count_mode was validated
            # against) rather than an unrelated enum.
            attr = (
                "element_composition"
                if count_mode == ElemCountMode.composition
                else count_mode
            )
            srs = pd.DataFrame(
                getattr(Composition(formula, allow_negative=True), attr).as_dict()
                for formula in srs
            ).sum()  # sum up element occurrences
    else:
        raise ValueError(
            "Expected values to be map from element symbols to heatmap values or "
            f"list of compositions (strings or Pymatgen objects), got {values}"
        )

    try:
        # If index consists entirely of strings representing integers, convert to ints
        srs.index = srs.index.astype(int)
    except (ValueError, TypeError):
        # Deliberate best-effort: index holds element symbols, leave it unchanged
        pass

    if pd.api.types.is_integer_dtype(srs.index):
        # If index is all integers, assume they represent atomic
        # numbers and map them to element symbols (H: 1, He: 2, ...)
        idx_min, idx_max = srs.index.min(), srs.index.max()
        if idx_max > 118 or idx_min < 1:
            raise ValueError(
                "element value keys were found to be integers and assumed to represent "
                f"atomic numbers, but values range from {idx_min} to {idx_max}, "
                "expected range [1, 118]."
            )
        map_atomic_num_to_elem_symbol = (
            df_ptable.reset_index().set_index("atomic_number").symbol
        )
        srs.index = srs.index.map(map_atomic_num_to_elem_symbol)

    # Ensure all elements are present in returned Series (with fill_value if they
    # weren't in values before)
    srs = srs.reindex(df_ptable.index, fill_value=fill_value).rename("count")

    if len(exclude_elements) > 0:
        # Normalize to a list: a bare string is a single symbol, and a tuple passed
        # to Series.drop would be interpreted as one (multi-level) label
        if isinstance(exclude_elements, str):
            exclude_elements = [exclude_elements]
        else:
            exclude_elements = list(exclude_elements)
        try:
            srs = srs.drop(exclude_elements)
        except KeyError as exc:
            # str() keeps the message buildable even for non-string entries
            bad_symbols = ", ".join(
                str(symbol) for symbol in exclude_elements if symbol not in srs.index
            )
            raise ValueError(
                f"Unexpected symbol(s) {bad_symbols} in {exclude_elements=}"
            ) from exc

    return srs
0 commit comments