diff --git a/assets/matbench-phonons-structures-2d.svg b/assets/matbench-phonons-structures-2d.svg
index c78a8a64..62c9a57d 100644
--- a/assets/matbench-phonons-structures-2d.svg
+++ b/assets/matbench-phonons-structures-2d.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/assets/struct-2d-mp-12712-Hf9Zr9Pd24-disordered.svg b/assets/struct-2d-mp-12712-Hf9Zr9Pd24-disordered.svg
index a27896ac..c16b9bc5 100644
--- a/assets/struct-2d-mp-12712-Hf9Zr9Pd24-disordered.svg
+++ b/assets/struct-2d-mp-12712-Hf9Zr9Pd24-disordered.svg
@@ -1 +1,7 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/assets/struct-2d-mp-19017-Li4Mn0.8Fe1.6P4C1.6O16-disordered.svg b/assets/struct-2d-mp-19017-Li4Mn0.8Fe1.6P4C1.6O16-disordered.svg
index 31bae044..d566689d 100644
--- a/assets/struct-2d-mp-19017-Li4Mn0.8Fe1.6P4C1.6O16-disordered.svg
+++ b/assets/struct-2d-mp-19017-Li4Mn0.8Fe1.6P4C1.6O16-disordered.svg
@@ -1 +1,7 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/examples/_generate_assets.py b/examples/_generate_assets.py
index 3470e9cd..1dddb553 100644
--- a/examples/_generate_assets.py
+++ b/examples/_generate_assets.py
@@ -11,8 +11,9 @@
from matminer.datasets import load_dataset
from monty.io import zopen
from monty.json import MontyDecoder
+from mp_api.client import MPRester
+from pymatgen.core import Structure
from pymatgen.core.periodic_table import Element
-from pymatgen.ext.matproj import MPRester
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine as PhononBands
from pymatgen.phonon.dos import PhononDos
from tqdm import tqdm
@@ -61,6 +62,8 @@
px.defaults.template = "pymatviz_white"
pio.templates.default = "pymatviz_white"
+struct: Structure # for type hinting
+
# Random classification data
np.random.seed(42)
rand_clf_size = 100
@@ -338,8 +341,11 @@
title = f"{len(axs.flat)} Matbench phonon structures"
fig.suptitle(title, fontweight="bold", fontsize=20)
-for row, ax in zip(df_phonons.itertuples(), axs.flat):
- idx, struct, *_, spg_num = row
+for idx, (row, ax) in enumerate(zip(df_phonons.itertuples(), axs.flat), start=1):
+ struct = row.structure
+ spg_num = struct.get_space_group_info()[1]
+ struct.add_oxidation_state_by_guess()
+
plot_structure_2d(
struct,
ax=ax,
diff --git a/pymatviz/structure_viz.py b/pymatviz/structure_viz.py
index fbc47fdb..0a33f984 100644
--- a/pymatviz/structure_viz.py
+++ b/pymatviz/structure_viz.py
@@ -1,39 +1,34 @@
-"""2D plots of pymatgen structures with matplotlib."""
+"""2D plots of pymatgen structures with matplotlib.
+
+plot_structure_2d() and its helpers get_rot_matrix() and unit_cell_to_lines() were
+inspired by ASE https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib.
+"""
from __future__ import annotations
import math
import warnings
from itertools import product
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import PathPatch, Wedge
from matplotlib.path import Path
from pymatgen.analysis.local_env import CrystalNN, NearNeighbors
+from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
-from pymatviz.utils import covalent_radii, jmol_colors
+from pymatviz.utils import ExperimentalWarning, covalent_radii, jmol_colors
if TYPE_CHECKING:
from collections.abc import Sequence
+ from typing import Any, Literal
from numpy.typing import ArrayLike
from pymatgen.core import Structure
-class ExperimentalWarning(Warning):
- """Used for experimental show_bonds feature."""
-
-
-warnings.simplefilter("once", ExperimentalWarning)
-
-
-# plot_structure_2d() and its helpers get_rot_matrix() and unit_cell_to_lines() were
-# inspired by ASE https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib
-
-
def _angles_to_rotation_matrix(
angles: str, rotation: ArrayLike | None = None
) -> ArrayLike:
@@ -52,8 +47,10 @@ def _angles_to_rotation_matrix(
"""
if rotation is None:
rotation = np.eye(3)
+
+ # Return initial rotation matrix if no angles
if not angles:
- return rotation.copy() # return initial rotation matrix if no angles
+ return rotation.copy()
for angle in angles.split(","):
radians = math.radians(float(angle[:-1]))
@@ -82,28 +79,27 @@ def unit_cell_to_lines(cell: ArrayLike) -> tuple[ArrayLike, ArrayLike, ArrayLike
- z-indices that sort plot elements into out-of-plane layers
- lines used to plot the unit cell
"""
- n_lines = 0
+ n_lines = n1 = 0
segments = []
- for c in range(3):
- norm = math.sqrt(sum(cell[c] ** 2))
+ for idx in range(3):
+ norm = math.sqrt(sum(cell[idx] ** 2))
segment = max(2, int(norm / 0.3))
segments.append(segment)
n_lines += 4 * segment
lines = np.empty((n_lines, 3))
- z_indices = np.empty(n_lines, int)
+ z_indices = np.empty(n_lines, dtype=int)
unit_cell_lines = np.zeros((3, 3))
- n1 = 0
- for c in range(3):
- segment = segments[c]
- dd = cell[c] / (4 * segment - 2)
- unit_cell_lines[c] = dd
+ for idx in range(3):
+ segment = segments[idx]
+ dd = cell[idx] / (4 * segment - 2)
+ unit_cell_lines[idx] = dd
P = np.arange(1, 4 * segment + 1, 4)[:, None] * dd
- z_indices[n1:] = c
+ z_indices[n1:] = idx
for i, j in [(0, 0), (0, 1), (1, 0), (1, 1)]:
n2 = n1 + segment
- lines[n1:n2] = P + i * cell[c - 2] + j * cell[c - 1]
+ lines[n1:n2] = P + i * cell[idx - 2] + j * cell[idx - 1]
n1 = n2
return lines, z_indices, unit_cell_lines
@@ -122,13 +118,12 @@ def plot_structure_2d(
| Literal["symbol", "species"]
| dict[str, str | float]
| Sequence[str | float] = True,
- site_labels_bbox: dict[str, Any] | None = None,
label_kwargs: dict[str, Any] | None = None,
bond_kwargs: dict[str, Any] | None = None,
standardize_struct: bool | None = None,
axis: bool | str = "off",
) -> plt.Axes:
- """Plot pymatgen structures in 2d with matplotlib.
+ """Plot pymatgen structures in 2D with matplotlib.
Inspired by ASE's ase.visualize.plot.plot_atoms()
https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib
@@ -137,7 +132,7 @@ def plot_structure_2d(
For example, these two snippets should give very similar output:
- ```py
+ ```python
from pymatgen.ext.matproj import MPRester
mp_19017 = MPRester().get_structure_by_material_id("mp-19017")
@@ -182,10 +177,10 @@ def plot_structure_2d(
colors, either a named color (str) or rgb(a) values like (0.2, 0.3, 0.6).
Defaults to JMol colors (https://jmol.sourceforge.net/jscolors).
scale (float, optional): Scaling of the plotted atoms and lines. Defaults to 1.
- show_unit_cell (bool, optional): Whether to draw unit cell. Defaults to True.
- show_bonds (bool | NearNeighbors, optional): Whether to draw bonds. If True, use
+ show_unit_cell (bool, optional): Whether to plot unit cell. Defaults to True.
+ show_bonds (bool | NearNeighbors, optional): Whether to plot bonds. If True, use
pymatgen.analysis.local_env.CrystalNN to infer the structure's connectivity.
- If False, don't draw bonds. If a subclass of
+ If False, don't plot bonds. If a subclass of
pymatgen.analysis.local_env.NearNeighbors, use that to determine
connectivity. Options include VoronoiNN, MinimumDistanceNN, OpenBabelNN,
CovalentBondNN, dtc. Defaults to True.
@@ -197,12 +192,10 @@ def plot_structure_2d(
number of sites in the crystal. If a string, must be "symbol" or
"species". "symbol" hides the oxidation state, "species" shows it
(equivalent to True). Defaults to True.
- site_labels_bbox (dict, optional): Keyword arguments for matplotlib.text.Text
- bbox like {"facecolor": "white", "alpha": 0.5}. Defaults to None.
label_kwargs (dict, optional): Keyword arguments for matplotlib.text.Text like
{"fontsize": 14}. Defaults to None.
bond_kwargs (dict, optional): Keyword arguments for the matplotlib.path.Path
- class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
+ class used to plot chemical bonds. Allowed are edgecolor, facecolor, color,
linewidth, linestyle, antialiased, hatch, fill, capstyle, joinstyle.
Defaults to None.
standardize_struct (bool, optional): Whether to standardize the structure using
@@ -228,7 +221,7 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
f" the number of sites in the crystal ({len(struct)=})"
)
- # default behavior in case of no user input is to standardize if any fractional
+ # Default behavior in case of no user input: standardize if any fractional
# coordinates are negative
has_sites_outside_unit_cell = any(any(site.frac_coords < 0) for site in struct)
if standardize_struct is False and has_sites_outside_unit_cell:
@@ -240,9 +233,9 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
elif standardize_struct is None:
standardize_struct = has_sites_outside_unit_cell
if standardize_struct:
- from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
-
struct = SpacegroupAnalyzer(struct).get_conventional_standard_structure()
+
+ # Get default colors
if colors is None:
colors = jmol_colors
@@ -257,15 +250,14 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
else:
# atomic_radii is assumed to be a map from element symbols to atomic radii
# make sure all present elements are assigned a radius
- missing = set(elements_at_sites) - set(atomic_radii)
- if missing:
+ if missing := set(elements_at_sites) - set(atomic_radii):
raise ValueError(f"atomic_radii is missing keys: {missing}")
radii_at_sites = np.array(
[atomic_radii[el] for el in elements_at_sites] # type: ignore[index]
)
- n_atoms = len(struct)
+ # Generate lines for unit cell
rotation_matrix = _angles_to_rotation_matrix(rotation)
unit_cell = struct.lattice.matrix
@@ -280,6 +272,8 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
unit_cell_lines = None
cell_vertices = None
+ # Zip atoms and unit cell lines together
+ n_atoms = len(struct)
n_lines = len(lines)
positions = np.empty((n_atoms + n_lines, 3))
@@ -287,21 +281,26 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
positions[:n_atoms] = site_coords
positions[n_atoms:] = lines
- # determine which lines should be hidden behind other objects
+ # Determine which unit cell line should be hidden behind other objects
for idx in range(n_lines):
this_layer = unit_cell_lines[z_indices[idx]]
+
occluded_top = ((site_coords - lines[idx] + this_layer) ** 2).sum(
1
) < radii_at_sites**2
+
occluded_bottom = ((site_coords - lines[idx] - this_layer) ** 2).sum(
1
) < radii_at_sites**2
+
if any(occluded_top & occluded_bottom):
z_indices[idx] = -1
+ # Apply rotation matrix
positions = np.dot(positions, rotation_matrix)
rotated_site_coords = positions[:n_atoms]
+ # Normalize wedge positions
min_coords = (rotated_site_coords - radii_at_sites[:, None]).min(0)
max_coords = (rotated_site_coords + radii_at_sites[:, None]).max(0)
@@ -316,19 +315,23 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
positions *= scale
positions -= offset
+ # Rotate and scale unit cell lines
if n_lines > 0:
unit_cell_lines = np.dot(unit_cell_lines, rotation_matrix)[:, :2] * scale
special_site_labels = ("symbol", "species")
- # sort positions by 3rd dim so we draw from back to front in z-axis (out-of-plane)
+ # Sort positions by 3rd dim to plot from back to front along z-axis (out-of-plane)
for idx in positions[:, 2].argsort():
xy = positions[idx, :2]
start = 0
+ zorder = positions[idx][2]
+
if idx < n_atoms:
- # loop over all species on a site (usually just 1 for ordered sites)
- for specie, occupancy in struct[idx].species.items():
- # strip oxidation state from element symbol (e.g. Ta5+ to Ta)
- elem_symbol = specie.symbol
+ # Loop over all species on a site (usually just 1 for ordered sites)
+ for species, occupancy in struct[idx].species.items():
+ # Strip oxidation state from element symbol (e.g. Ta5+ to Ta)
+ elem_symbol = species.symbol
+
radius = atomic_radii[elem_symbol] * scale # type: ignore[index]
face_color = colors[elem_symbol]
wedge = Wedge(
@@ -338,23 +341,25 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
360 * (start + occupancy),
facecolor=face_color,
edgecolor="black",
+ zorder=zorder,
)
ax.add_patch(wedge)
+ # Generate labels
if site_labels == "symbol":
txt = elem_symbol
elif site_labels in ("species", True):
- txt = specie
+ txt = species
elif site_labels is False:
txt = ""
elif isinstance(site_labels, dict):
- # try element incl. oxidation state as dict key first (e.g. Na+),
+ # Try element incl. oxidation state as dict key first (e.g. Na+),
# then just element as fallback
txt = site_labels.get(
- repr(specie), site_labels.get(elem_symbol, "")
+ repr(species), site_labels.get(elem_symbol, "")
)
if txt in special_site_labels:
- txt = specie if txt == "species" else elem_symbol
+ txt = species if txt == "species" else elem_symbol
elif isinstance(site_labels, (list, tuple)):
txt = site_labels[idx] # idx runs from 0 to n_atoms
else:
@@ -363,8 +368,9 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
f"{', '.join(special_site_labels)}, dict, list)"
)
+ # Add labels
if site_labels:
- # place element symbol half way along outer wedge edge for
+ # Place element symbol half way along outer wedge edge for
# disordered sites
half_way = 2 * np.pi * (start + occupancy / 2)
direction = np.array([math.cos(half_way), math.sin(half_way)])
@@ -372,20 +378,23 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
(0.5 * radius) * direction if occupancy < 1 else (0, 0)
)
- bbox = dict(facecolor="none", edgecolor="none", pad=1)
- bbox.update(site_labels_bbox or {})
txt_kwds = dict(
- ha="center", va="center", bbox=bbox, **(label_kwargs or {})
+ ha="center",
+ va="center",
+ zorder=zorder,
+ **(label_kwargs or {}),
)
ax.text(*(xy + text_offset), txt, **txt_kwds)
start += occupancy
- else: # draw unit cell
+
+ # Plot unit cell
+ else:
cell_idx = idx - n_atoms
- # only draw line if not obstructed by an atom
+ # Only plot lines not obstructed by an atom
if z_indices[cell_idx] != -1:
hxy = unit_cell_lines[z_indices[cell_idx]]
- path = PathPatch(Path((xy + hxy, xy - hxy)))
+ path = PathPatch(Path((xy + hxy, xy - hxy)), zorder=zorder)
ax.add_patch(path)
if show_bonds:
@@ -404,8 +413,8 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
)
# If structure doesn't have any oxidation states yet, guess them from chemical
- # composition. Helps CrystalNN and other strategies to estimate better bond
- # connectivity. Uses getattr on site.specie since it's often a pymatgen Element
+ # composition. Use CrystalNN and other strategies to better estimate bond
+ # connectivity. Use getattr on site.specie since it's often a pymatgen Element
# which has no oxi_state
if not any(
hasattr(getattr(site, "specie", None), "oxi_state") for site in struct
@@ -413,7 +422,8 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
try:
struct.add_oxidation_state_by_guess()
except ValueError: # fails for disordered structures
- "Charge balance analysis requires integer values in Composition"
+ # Charge balance analysis requires integer values in Composition
+ pass
structure_graph = neighbor_strategy_cls().get_bonded_structure(struct)
diff --git a/pymatviz/utils.py b/pymatviz/utils.py
index 30a8c822..0358e627 100644
--- a/pymatviz/utils.py
+++ b/pymatviz/utils.py
@@ -3,6 +3,7 @@
from __future__ import annotations
import ast
+import warnings
from contextlib import contextmanager
from functools import partial, wraps
from os.path import dirname
@@ -79,6 +80,13 @@
element_symbols[Z] = symbol
+class ExperimentalWarning(Warning):
+ """Warning for experimental features."""
+
+
+warnings.simplefilter("once", ExperimentalWarning)
+
+
def pretty_label(key: str, backend: Backend) -> str:
"""Map metric keys to their pretty labels."""
if backend not in VALID_BACKENDS:
diff --git a/pyproject.toml b/pyproject.toml
index f02c95f8..47851e36 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,7 +48,7 @@ test = [
"pytest-cov",
"weasyprint",
]
-data-src = ["matminer"]
+data-src = ["matminer", "mp_api",]
export-figs = ["kaleido"]
gh-pages = ["jupyter", "lazydocs", "nbconvert"]
# needed for pandas Stylers, see https://github.com/pandas-dev/pandas/blob/-/pyproject.toml
@@ -105,7 +105,7 @@ lint.ignore = [
"PTH",
"RUF001", # ambiguous-unicode-character-string
"S311",
- "SIM105", # Use contextlib.suppress(FileNotFoundError) instead of try-except-pass
+ "SIM105", # Use contextlib.suppress() instead of try-except-pass
"TD",
"TRY003",
]
diff --git a/tests/test_structure_viz.py b/tests/test_structure_viz.py
index 668de50e..9bb30d0b 100644
--- a/tests/test_structure_viz.py
+++ b/tests/test_structure_viz.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
import matplotlib.pyplot as plt
import pandas as pd
@@ -69,14 +69,10 @@ def test_plot_structure_2d_axis(axis: str | bool) -> None:
"site_labels",
[True, False, "symbol", "species", {"Fe": "Iron"}, {"Fe": 1.0}, ["Fe", "O"]],
)
-@pytest.mark.parametrize("site_labels_bbox", [None, {}, {"boxstyle": "round"}])
def test_plot_structure_2d_site_labels(
site_labels: bool | str | dict[str, str | float] | Sequence[str],
- site_labels_bbox: dict[str, Any] | None,
) -> None:
- ax = plot_structure_2d(
- disordered_struct, site_labels=site_labels, site_labels_bbox=site_labels_bbox
- )
+ ax = plot_structure_2d(disordered_struct, site_labels=site_labels)
if site_labels is False:
assert not ax.axes.texts
else: