Skip to content

Remove text background and fix z-order in structure_viz #139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion assets/matbench-phonons-structures-2d.svg
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think this file needs regenerating. the oxidation states are gone

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I didn't notice this, yes as I mentioned here #139 (comment), the fetching behavior of oxidation states still inconsistent for some reason.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 7 additions & 1 deletion assets/struct-2d-mp-12712-Hf9Zr9Pd24-disordered.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion examples/_generate_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from matminer.datasets import load_dataset
from monty.io import zopen
from monty.json import MontyDecoder
from mp_api.client import MPRester
from pymatgen.core.periodic_table import Element
from pymatgen.ext.matproj import MPRester
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine as PhononBands
from pymatgen.phonon.dos import PhononDos
from tqdm import tqdm
Expand Down
105 changes: 65 additions & 40 deletions pymatviz/structure_viz.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,41 @@
"""2D plots of pymatgen structures with matplotlib."""
"""2D plots of pymatgen structures with matplotlib.

plot_structure_2d() and its helpers get_rot_matrix() and unit_cell_to_lines() were
inspired by ASE https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib.
"""

from __future__ import annotations

import math
import warnings
from itertools import product
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import PathPatch, Wedge
from matplotlib.path import Path
from pymatgen.analysis.local_env import CrystalNN, NearNeighbors
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

from pymatviz.utils import covalent_radii, jmol_colors


if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Any, Literal

from numpy.typing import ArrayLike
from pymatgen.core import Structure


class ExperimentalWarning(Warning):
"""Used for experimental show_bonds feature."""
"""Warning for experimental features."""


warnings.simplefilter("once", ExperimentalWarning)


# plot_structure_2d() and its helpers get_rot_matrix() and unit_cell_to_lines() were
# inspired by ASE https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib


def _angles_to_rotation_matrix(
angles: str, rotation: ArrayLike | None = None
) -> ArrayLike:
Expand All @@ -52,8 +54,10 @@ def _angles_to_rotation_matrix(
"""
if rotation is None:
rotation = np.eye(3)

# Return initial rotation matrix if no angles
if not angles:
return rotation.copy() # return initial rotation matrix if no angles
return rotation.copy()

for angle in angles.split(","):
radians = math.radians(float(angle[:-1]))
Expand Down Expand Up @@ -128,7 +132,7 @@ def plot_structure_2d(
standardize_struct: bool | None = None,
axis: bool | str = "off",
) -> plt.Axes:
"""Plot pymatgen structures in 2d with matplotlib.
"""Plot pymatgen structures in 2D with matplotlib.

Inspired by ASE's ase.visualize.plot.plot_atoms()
https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib
Expand All @@ -137,7 +141,7 @@ def plot_structure_2d(

For example, these two snippets should give very similar output:

```py
```python
from pymatgen.ext.matproj import MPRester

mp_19017 = MPRester().get_structure_by_material_id("mp-19017")
Expand Down Expand Up @@ -182,10 +186,10 @@ def plot_structure_2d(
colors, either a named color (str) or rgb(a) values like (0.2, 0.3, 0.6).
Defaults to JMol colors (https://jmol.sourceforge.net/jscolors).
scale (float, optional): Scaling of the plotted atoms and lines. Defaults to 1.
show_unit_cell (bool, optional): Whether to draw unit cell. Defaults to True.
show_bonds (bool | NearNeighbors, optional): Whether to draw bonds. If True, use
show_unit_cell (bool, optional): Whether to plot unit cell. Defaults to True.
show_bonds (bool | NearNeighbors, optional): Whether to plot bonds. If True, use
pymatgen.analysis.local_env.CrystalNN to infer the structure's connectivity.
If False, don't draw bonds. If a subclass of
If False, don't plot bonds. If a subclass of
pymatgen.analysis.local_env.NearNeighbors, use that to determine
connectivity. Options include VoronoiNN, MinimumDistanceNN, OpenBabelNN,
CovalentBondNN, dtc. Defaults to True.
Expand All @@ -202,7 +206,7 @@ def plot_structure_2d(
label_kwargs (dict, optional): Keyword arguments for matplotlib.text.Text like
{"fontsize": 14}. Defaults to None.
bond_kwargs (dict, optional): Keyword arguments for the matplotlib.path.Path
class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
class used to plot chemical bonds. Allowed are edgecolor, facecolor, color,
linewidth, linestyle, antialiased, hatch, fill, capstyle, joinstyle.
Defaults to None.
standardize_struct (bool, optional): Whether to standardize the structure using
Expand All @@ -228,7 +232,7 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
f" the number of sites in the crystal ({len(struct)=})"
)

# default behavior in case of no user input is to standardize if any fractional
# Default behavior in case of no user input: standardize if any fractional
# coordinates are negative
has_sites_outside_unit_cell = any(any(site.frac_coords < 0) for site in struct)
if standardize_struct is False and has_sites_outside_unit_cell:
Expand All @@ -240,9 +244,9 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
elif standardize_struct is None:
standardize_struct = has_sites_outside_unit_cell
if standardize_struct:
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

struct = SpacegroupAnalyzer(struct).get_conventional_standard_structure()

# Get default colors
if colors is None:
colors = jmol_colors

Expand All @@ -257,15 +261,14 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
else:
# atomic_radii is assumed to be a map from element symbols to atomic radii
# make sure all present elements are assigned a radius
missing = set(elements_at_sites) - set(atomic_radii)
if missing:
if missing := set(elements_at_sites) - set(atomic_radii):
raise ValueError(f"atomic_radii is missing keys: {missing}")

radii_at_sites = np.array(
[atomic_radii[el] for el in elements_at_sites] # type: ignore[index]
)

n_atoms = len(struct)
# Generate lines for unit cell
rotation_matrix = _angles_to_rotation_matrix(rotation)
unit_cell = struct.lattice.matrix

Expand All @@ -280,28 +283,35 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
unit_cell_lines = None
cell_vertices = None

# Zip atoms and unit cell lines together
n_atoms = len(struct)
n_lines = len(lines)

positions = np.empty((n_atoms + n_lines, 3))
site_coords = np.array([site.coords for site in struct])
positions[:n_atoms] = site_coords
positions[n_atoms:] = lines

# determine which lines should be hidden behind other objects
# Determine which unit cell line should be hidden behind other objects
for idx in range(n_lines):
this_layer = unit_cell_lines[z_indices[idx]]

occluded_top = ((site_coords - lines[idx] + this_layer) ** 2).sum(
1
) < radii_at_sites**2

occluded_bottom = ((site_coords - lines[idx] - this_layer) ** 2).sum(
1
) < radii_at_sites**2

if any(occluded_top & occluded_bottom):
z_indices[idx] = -1

# Apply rotation matrix
positions = np.dot(positions, rotation_matrix)
rotated_site_coords = positions[:n_atoms]

# Normalize wedge positions
min_coords = (rotated_site_coords - radii_at_sites[:, None]).min(0)
max_coords = (rotated_site_coords + radii_at_sites[:, None]).max(0)

Expand All @@ -316,19 +326,23 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
positions *= scale
positions -= offset

# Rotate and scale unit cell lines
if n_lines > 0:
unit_cell_lines = np.dot(unit_cell_lines, rotation_matrix)[:, :2] * scale

special_site_labels = ("symbol", "species")
# sort positions by 3rd dim so we draw from back to front in z-axis (out-of-plane)
# Sort positions by 3rd dim to plot from back to front along z-axis (out-of-plane)
for idx in positions[:, 2].argsort():
xy = positions[idx, :2]
start = 0
zorder = positions[idx][2]

if idx < n_atoms:
# loop over all species on a site (usually just 1 for ordered sites)
for specie, occupancy in struct[idx].species.items():
# strip oxidation state from element symbol (e.g. Ta5+ to Ta)
elem_symbol = specie.symbol
# Loop over all species on a site (usually just 1 for ordered sites)
for species, occupancy in struct[idx].species.items():
# Strip oxidation state from element symbol (e.g. Ta5+ to Ta)
elem_symbol = species.symbol

radius = atomic_radii[elem_symbol] * scale # type: ignore[index]
face_color = colors[elem_symbol]
wedge = Wedge(
Expand All @@ -338,23 +352,25 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
360 * (start + occupancy),
facecolor=face_color,
edgecolor="black",
zorder=zorder,
)
ax.add_patch(wedge)

# Generate labels
if site_labels == "symbol":
txt = elem_symbol
elif site_labels in ("species", True):
txt = specie
txt = species
elif site_labels is False:
txt = ""
elif isinstance(site_labels, dict):
# try element incl. oxidation state as dict key first (e.g. Na+),
# Try element incl. oxidation state as dict key first (e.g. Na+),
# then just element as fallback
txt = site_labels.get(
repr(specie), site_labels.get(elem_symbol, "")
repr(species), site_labels.get(elem_symbol, "")
)
if txt in special_site_labels:
txt = specie if txt == "species" else elem_symbol
txt = species if txt == "species" else elem_symbol
elif isinstance(site_labels, (list, tuple)):
txt = site_labels[idx] # idx runs from 0 to n_atoms
else:
Expand All @@ -363,29 +379,37 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
f"{', '.join(special_site_labels)}, dict, list)"
)

# Add labels
if site_labels:
# place element symbol half way along outer wedge edge for
# Place element symbol half way along outer wedge edge for
# disordered sites
half_way = 2 * np.pi * (start + occupancy / 2)
direction = np.array([math.cos(half_way), math.sin(half_way)])
text_offset = (
(0.5 * radius) * direction if occupancy < 1 else (0, 0)
)

bbox = dict(facecolor="none", edgecolor="none", pad=1)
bbox.update(site_labels_bbox or {})
bbox = dict(facecolor="none", edgecolor="none", pad=1, alpha=0)
bbox |= site_labels_bbox or {}

txt_kwds = dict(
ha="center", va="center", bbox=bbox, **(label_kwargs or {})
ha="center",
va="center",
zorder=zorder,
bbox=bbox,
**(label_kwargs or {}),
)
ax.text(*(xy + text_offset), txt, **txt_kwds)

start += occupancy
else: # draw unit cell

# Plot unit cell
else:
cell_idx = idx - n_atoms
# only draw line if not obstructed by an atom
# Only plot lines not obstructed by an atom
if z_indices[cell_idx] != -1:
hxy = unit_cell_lines[z_indices[cell_idx]]
path = PathPatch(Path((xy + hxy, xy - hxy)))
path = PathPatch(Path((xy + hxy, xy - hxy)), zorder=zorder)
ax.add_patch(path)

if show_bonds:
Expand All @@ -404,16 +428,17 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
)

# If structure doesn't have any oxidation states yet, guess them from chemical
# composition. Helps CrystalNN and other strategies to estimate better bond
# connectivity. Uses getattr on site.specie since it's often a pymatgen Element
# composition. Use CrystalNN and other strategies to better estimate bond
# connectivity. Use getattr on site.specie since it's often a pymatgen Element
# which has no oxi_state
if not any(
hasattr(getattr(site, "specie", None), "oxi_state") for site in struct
):
try:
struct.add_oxidation_state_by_guess()
except ValueError: # fails for disordered structures
"Charge balance analysis requires integer values in Composition"
# Charge balance analysis requires integer values in Composition
pass

structure_graph = neighbor_strategy_cls().get_bonded_structure(struct)

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ test = [
"pytest-cov",
"weasyprint",
]
data-src = ["matminer"]
data-src = ["matminer", "mp_api",]
export-figs = ["kaleido"]
gh-pages = ["jupyter", "lazydocs", "nbconvert"]
# needed for pandas Stylers, see https://github.com/pandas-dev/pandas/blob/-/pyproject.toml
Expand Down Expand Up @@ -105,7 +105,7 @@ lint.ignore = [
"PTH",
"RUF001", # ambiguous-unicode-character-string
"S311",
"SIM105", # Use contextlib.suppress(FileNotFoundError) instead of try-except-pass
"SIM105", # Use contextlib.suppress() instead of try-except-pass
"TD",
"TRY003",
]
Expand Down