diff --git a/assets/matbench-phonons-structures-2d.svg b/assets/matbench-phonons-structures-2d.svg index c78a8a64..62c9a57d 100644 --- a/assets/matbench-phonons-structures-2d.svg +++ b/assets/matbench-phonons-structures-2d.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/assets/struct-2d-mp-12712-Hf9Zr9Pd24-disordered.svg b/assets/struct-2d-mp-12712-Hf9Zr9Pd24-disordered.svg index a27896ac..c16b9bc5 100644 --- a/assets/struct-2d-mp-12712-Hf9Zr9Pd24-disordered.svg +++ b/assets/struct-2d-mp-12712-Hf9Zr9Pd24-disordered.svg @@ -1 +1,7 @@ - \ No newline at end of file + + + + + + + \ No newline at end of file diff --git a/assets/struct-2d-mp-19017-Li4Mn0.8Fe1.6P4C1.6O16-disordered.svg b/assets/struct-2d-mp-19017-Li4Mn0.8Fe1.6P4C1.6O16-disordered.svg index 31bae044..d566689d 100644 --- a/assets/struct-2d-mp-19017-Li4Mn0.8Fe1.6P4C1.6O16-disordered.svg +++ b/assets/struct-2d-mp-19017-Li4Mn0.8Fe1.6P4C1.6O16-disordered.svg @@ -1 +1,7 @@ - \ No newline at end of file + + + + + + + \ No newline at end of file diff --git a/examples/_generate_assets.py b/examples/_generate_assets.py index 3470e9cd..1dddb553 100644 --- a/examples/_generate_assets.py +++ b/examples/_generate_assets.py @@ -11,8 +11,9 @@ from matminer.datasets import load_dataset from monty.io import zopen from monty.json import MontyDecoder +from mp_api.client import MPRester +from pymatgen.core import Structure from pymatgen.core.periodic_table import Element -from pymatgen.ext.matproj import MPRester from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine as PhononBands from pymatgen.phonon.dos import PhononDos from tqdm import tqdm @@ -61,6 +62,8 @@ px.defaults.template = "pymatviz_white" pio.templates.default = "pymatviz_white" +struct: Structure # for type hinting + # Random classification data np.random.seed(42) rand_clf_size = 100 @@ -338,8 +341,11 @@ title = f"{len(axs.flat)} Matbench phonon structures" fig.suptitle(title, fontweight="bold", fontsize=20) -for row, ax in zip(df_phonons.itertuples(), axs.flat): - idx, struct, *_, spg_num = row +for idx, (row, ax) in enumerate(zip(df_phonons.itertuples(), axs.flat), start=1): + struct = row.structure + spg_num = struct.get_space_group_info()[1] + struct.add_oxidation_state_by_guess() + plot_structure_2d( struct, ax=ax, diff --git a/pymatviz/structure_viz.py b/pymatviz/structure_viz.py index fbc47fdb..0a33f984 100644 --- a/pymatviz/structure_viz.py +++ b/pymatviz/structure_viz.py @@ -1,39 +1,34 @@ -"""2D plots of pymatgen structures with matplotlib.""" +"""2D plots of pymatgen structures with matplotlib. + +plot_structure_2d() and its helpers get_rot_matrix() and unit_cell_to_lines() were +inspired by ASE https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib. +""" from __future__ import annotations import math import warnings from itertools import product -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING import matplotlib.pyplot as plt import numpy as np from matplotlib.patches import PathPatch, Wedge from matplotlib.path import Path from pymatgen.analysis.local_env import CrystalNN, NearNeighbors +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer -from pymatviz.utils import covalent_radii, jmol_colors +from pymatviz.utils import ExperimentalWarning, covalent_radii, jmol_colors if TYPE_CHECKING: from collections.abc import Sequence + from typing import Any, Literal from numpy.typing import ArrayLike from pymatgen.core import Structure -class ExperimentalWarning(Warning): - """Used for experimental show_bonds feature.""" - - -warnings.simplefilter("once", ExperimentalWarning) - - -# plot_structure_2d() and its helpers get_rot_matrix() and unit_cell_to_lines() were -# inspired by ASE https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib - - def _angles_to_rotation_matrix( angles: str, rotation: ArrayLike | None = None ) -> ArrayLike: @@ -52,8 +47,10 @@ def _angles_to_rotation_matrix( """ if rotation is None: rotation = np.eye(3) + + # Return initial rotation matrix if no angles if not angles: - return rotation.copy() # return initial rotation matrix if no angles + return rotation.copy() for angle in angles.split(","): radians = math.radians(float(angle[:-1])) @@ -82,28 +79,27 @@ def unit_cell_to_lines(cell: ArrayLike) -> tuple[ArrayLike, ArrayLike, ArrayLike - z-indices that sort plot elements into out-of-plane layers - lines used to plot the unit cell """ - n_lines = 0 + n_lines = n1 = 0 segments = [] - for c in range(3): - norm = math.sqrt(sum(cell[c] ** 2)) + for idx in range(3): + norm = math.sqrt(sum(cell[idx] ** 2)) segment = max(2, int(norm / 0.3)) segments.append(segment) n_lines += 4 * segment lines = np.empty((n_lines, 3)) - z_indices = np.empty(n_lines, int) + z_indices = np.empty(n_lines, dtype=int) unit_cell_lines = np.zeros((3, 3)) - n1 = 0 - for c in range(3): - segment = segments[c] - dd = cell[c] / (4 * segment - 2) - unit_cell_lines[c] = dd + for idx in range(3): + segment = segments[idx] + dd = cell[idx] / (4 * segment - 2) + unit_cell_lines[idx] = dd P = np.arange(1, 4 * segment + 1, 4)[:, None] * dd - z_indices[n1:] = c + z_indices[n1:] = idx for i, j in [(0, 0), (0, 1), (1, 0), (1, 1)]: n2 = n1 + segment - lines[n1:n2] = P + i * cell[c - 2] + j * cell[c - 1] + lines[n1:n2] = P + i * cell[idx - 2] + j * cell[idx - 1] n1 = n2 return lines, z_indices, unit_cell_lines @@ -122,13 +118,12 @@ def plot_structure_2d( | Literal["symbol", "species"] | dict[str, str | float] | Sequence[str | float] = True, - site_labels_bbox: dict[str, Any] | None = None, label_kwargs: dict[str, Any] | None = None, bond_kwargs: dict[str, Any] | None = None, standardize_struct: bool | None = None, axis: bool | str = "off", ) -> plt.Axes: - """Plot pymatgen structures in 2d with matplotlib. + """Plot pymatgen structures in 2D with matplotlib. Inspired by ASE's ase.visualize.plot.plot_atoms() https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib @@ -137,7 +132,7 @@ def plot_structure_2d( For example, these two snippets should give very similar output: - ```py + ```python from pymatgen.ext.matproj import MPRester mp_19017 = MPRester().get_structure_by_material_id("mp-19017") @@ -182,10 +177,10 @@ def plot_structure_2d( colors, either a named color (str) or rgb(a) values like (0.2, 0.3, 0.6). Defaults to JMol colors (https://jmol.sourceforge.net/jscolors). scale (float, optional): Scaling of the plotted atoms and lines. Defaults to 1. - show_unit_cell (bool, optional): Whether to draw unit cell. Defaults to True. - show_bonds (bool | NearNeighbors, optional): Whether to draw bonds. If True, use + show_unit_cell (bool, optional): Whether to plot unit cell. Defaults to True. + show_bonds (bool | NearNeighbors, optional): Whether to plot bonds. If True, use pymatgen.analysis.local_env.CrystalNN to infer the structure's connectivity. - If False, don't draw bonds. If a subclass of + If False, don't plot bonds. If a subclass of pymatgen.analysis.local_env.NearNeighbors, use that to determine connectivity. Options include VoronoiNN, MinimumDistanceNN, OpenBabelNN, CovalentBondNN, dtc. Defaults to True. @@ -197,12 +192,10 @@ def plot_structure_2d( number of sites in the crystal. If a string, must be "symbol" or "species". "symbol" hides the oxidation state, "species" shows it (equivalent to True). Defaults to True. - site_labels_bbox (dict, optional): Keyword arguments for matplotlib.text.Text - bbox like {"facecolor": "white", "alpha": 0.5}. Defaults to None. label_kwargs (dict, optional): Keyword arguments for matplotlib.text.Text like {"fontsize": 14}. Defaults to None. bond_kwargs (dict, optional): Keyword arguments for the matplotlib.path.Path - class used to draw chemical bonds. Allowed are edgecolor, facecolor, color, + class used to plot chemical bonds. Allowed are edgecolor, facecolor, color, linewidth, linestyle, antialiased, hatch, fill, capstyle, joinstyle. Defaults to None. standardize_struct (bool, optional): Whether to standardize the structure using @@ -228,7 +221,7 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color, f" the number of sites in the crystal ({len(struct)=})" ) - # default behavior in case of no user input is to standardize if any fractional + # Default behavior in case of no user input: standardize if any fractional # coordinates are negative has_sites_outside_unit_cell = any(any(site.frac_coords < 0) for site in struct) if standardize_struct is False and has_sites_outside_unit_cell: @@ -240,9 +233,9 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color, elif standardize_struct is None: standardize_struct = has_sites_outside_unit_cell if standardize_struct: - from pymatgen.symmetry.analyzer import SpacegroupAnalyzer - struct = SpacegroupAnalyzer(struct).get_conventional_standard_structure() + + # Get default colors if colors is None: colors = jmol_colors @@ -257,15 +250,14 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color, else: # atomic_radii is assumed to be a map from element symbols to atomic radii # make sure all present elements are assigned a radius - missing = set(elements_at_sites) - set(atomic_radii) - if missing: + if missing := set(elements_at_sites) - set(atomic_radii): raise ValueError(f"atomic_radii is missing keys: {missing}") radii_at_sites = np.array( [atomic_radii[el] for el in elements_at_sites] # type: ignore[index] ) - n_atoms = len(struct) + # Generate lines for unit cell rotation_matrix = _angles_to_rotation_matrix(rotation) unit_cell = struct.lattice.matrix @@ -280,6 +272,8 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color, unit_cell_lines = None cell_vertices = None + # Zip atoms and unit cell lines together + n_atoms = len(struct) n_lines = len(lines) positions = np.empty((n_atoms + n_lines, 3)) @@ -287,21 +281,26 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color, positions[:n_atoms] = site_coords positions[n_atoms:] = lines - # determine which lines should be hidden behind other objects + # Determine which unit cell line should be hidden behind other objects for idx in range(n_lines): this_layer = unit_cell_lines[z_indices[idx]] + occluded_top = ((site_coords - lines[idx] + this_layer) ** 2).sum( 1 ) < radii_at_sites**2 + occluded_bottom = ((site_coords - lines[idx] - this_layer) ** 2).sum( 1 ) < radii_at_sites**2 + if any(occluded_top & occluded_bottom): z_indices[idx] = -1 + # Apply rotation matrix positions = np.dot(positions, rotation_matrix) rotated_site_coords = positions[:n_atoms] + # Normalize wedge positions min_coords = (rotated_site_coords - radii_at_sites[:, None]).min(0) max_coords = (rotated_site_coords + radii_at_sites[:, None]).max(0) @@ -316,19 +315,23 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color, positions *= scale positions -= offset + # Rotate and scale unit cell lines if n_lines > 0: unit_cell_lines = np.dot(unit_cell_lines, rotation_matrix)[:, :2] * scale special_site_labels = ("symbol", "species") - # sort positions by 3rd dim so we draw from back to front in z-axis (out-of-plane) + # Sort positions by 3rd dim to plot from back to front along z-axis (out-of-plane) for idx in positions[:, 2].argsort(): xy = positions[idx, :2] start = 0 + zorder = positions[idx][2] + if idx < n_atoms: - # loop over all species on a site (usually just 1 for ordered sites) - for specie, occupancy in struct[idx].species.items(): - # strip oxidation state from element symbol (e.g. Ta5+ to Ta) - elem_symbol = specie.symbol + # Loop over all species on a site (usually just 1 for ordered sites) + for species, occupancy in struct[idx].species.items(): + # Strip oxidation state from element symbol (e.g. Ta5+ to Ta) + elem_symbol = species.symbol + radius = atomic_radii[elem_symbol] * scale # type: ignore[index] face_color = colors[elem_symbol] wedge = Wedge( @@ -338,23 +341,25 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color, 360 * (start + occupancy), facecolor=face_color, edgecolor="black", + zorder=zorder, ) ax.add_patch(wedge) + # Generate labels if site_labels == "symbol": txt = elem_symbol elif site_labels in ("species", True): - txt = specie + txt = species elif site_labels is False: txt = "" elif isinstance(site_labels, dict): - # try element incl. oxidation state as dict key first (e.g. Na+), + # Try element incl. oxidation state as dict key first (e.g. Na+), # then just element as fallback txt = site_labels.get( - repr(specie), site_labels.get(elem_symbol, "") + repr(species), site_labels.get(elem_symbol, "") ) if txt in special_site_labels: - txt = specie if txt == "species" else elem_symbol + txt = species if txt == "species" else elem_symbol elif isinstance(site_labels, (list, tuple)): txt = site_labels[idx] # idx runs from 0 to n_atoms else: @@ -363,8 +368,9 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color, f"{', '.join(special_site_labels)}, dict, list)" ) + # Add labels if site_labels: - # place element symbol half way along outer wedge edge for + # Place element symbol half way along outer wedge edge for # disordered sites half_way = 2 * np.pi * (start + occupancy / 2) direction = np.array([math.cos(half_way), math.sin(half_way)]) @@ -372,20 +378,23 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color, (0.5 * radius) * direction if occupancy < 1 else (0, 0) ) - bbox = dict(facecolor="none", edgecolor="none", pad=1) - bbox.update(site_labels_bbox or {}) txt_kwds = dict( - ha="center", va="center", bbox=bbox, **(label_kwargs or {}) + ha="center", + va="center", + zorder=zorder, + **(label_kwargs or {}), ) ax.text(*(xy + text_offset), txt, **txt_kwds) start += occupancy - else: # draw unit cell + + # Plot unit cell + else: cell_idx = idx - n_atoms - # only draw line if not obstructed by an atom + # Only plot lines not obstructed by an atom if z_indices[cell_idx] != -1: hxy = unit_cell_lines[z_indices[cell_idx]] - path = PathPatch(Path((xy + hxy, xy - hxy))) + path = PathPatch(Path((xy + hxy, xy - hxy)), zorder=zorder) ax.add_patch(path) if show_bonds: @@ -404,8 +413,8 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color, ) # If structure doesn't have any oxidation states yet, guess them from chemical - # composition. Helps CrystalNN and other strategies to estimate better bond - # connectivity. Uses getattr on site.specie since it's often a pymatgen Element + # composition. Use CrystalNN and other strategies to better estimate bond + # connectivity. Use getattr on site.specie since it's often a pymatgen Element # which has no oxi_state if not any( hasattr(getattr(site, "specie", None), "oxi_state") for site in struct @@ -413,7 +422,8 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color, try: struct.add_oxidation_state_by_guess() except ValueError: # fails for disordered structures - "Charge balance analysis requires integer values in Composition" + # Charge balance analysis requires integer values in Composition + pass structure_graph = neighbor_strategy_cls().get_bonded_structure(struct) diff --git a/pymatviz/utils.py b/pymatviz/utils.py index 30a8c822..0358e627 100644 --- a/pymatviz/utils.py +++ b/pymatviz/utils.py @@ -3,6 +3,7 @@ from __future__ import annotations import ast +import warnings from contextlib import contextmanager from functools import partial, wraps from os.path import dirname @@ -79,6 +80,13 @@ element_symbols[Z] = symbol +class ExperimentalWarning(Warning): + """Warning for experimental features.""" + + +warnings.simplefilter("once", ExperimentalWarning) + + def pretty_label(key: str, backend: Backend) -> str: """Map metric keys to their pretty labels.""" if backend not in VALID_BACKENDS: diff --git a/pyproject.toml b/pyproject.toml index f02c95f8..47851e36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ test = [ "pytest-cov", "weasyprint", ] -data-src = ["matminer"] +data-src = ["matminer", "mp_api",] export-figs = ["kaleido"] gh-pages = ["jupyter", "lazydocs", "nbconvert"] # needed for pandas Stylers, see https://github.com/pandas-dev/pandas/blob/-/pyproject.toml @@ -105,7 +105,7 @@ lint.ignore = [ "PTH", "RUF001", # ambiguous-unicode-character-string "S311", - "SIM105", # Use contextlib.suppress(FileNotFoundError) instead of try-except-pass + "SIM105", # Use contextlib.suppress() instead of try-except-pass "TD", "TRY003", ] diff --git a/tests/test_structure_viz.py b/tests/test_structure_viz.py index 668de50e..9bb30d0b 100644 --- a/tests/test_structure_viz.py +++ b/tests/test_structure_viz.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import matplotlib.pyplot as plt import pandas as pd @@ -69,14 +69,10 @@ def test_plot_structure_2d_axis(axis: str | bool) -> None: "site_labels", [True, False, "symbol", "species", {"Fe": "Iron"}, {"Fe": 1.0}, ["Fe", "O"]], ) -@pytest.mark.parametrize("site_labels_bbox", [None, {}, {"boxstyle": "round"}]) def test_plot_structure_2d_site_labels( site_labels: bool | str | dict[str, str | float] | Sequence[str], - site_labels_bbox: dict[str, Any] | None, ) -> None: - ax = plot_structure_2d( - disordered_struct, site_labels=site_labels, site_labels_bbox=site_labels_bbox - ) + ax = plot_structure_2d(disordered_struct, site_labels=site_labels) if site_labels is False: assert not ax.axes.texts else: