Skip to content

Commit 379e9de

Browse files
authored
element_pair_rdfs plots radial distribution functions (RDFs) for element pairs in a structure (#203)
* improve set_plotly_template auto-complete with Literal type * add element_pair_rdfs(structure) -> go.Figure in new pymatviz/rdf.py module * add tests/test_rdf.py * remove ase.Atoms conversion to avoid new pkg dep * show element_pair_rdfs examples in readme * should have used save_and_compress_svg
1 parent 8670f22 commit 379e9de

File tree

8 files changed

+358
-2
lines changed

8 files changed

+358
-2
lines changed

assets/element-pair-rdfs-Na8Nb8O24.svg

+1
Loading

assets/element-pair-rdfs-Si16O32.svg

+1
Loading

examples/make_assets/rdf.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from matminer.datasets import load_dataset
2+
3+
import pymatviz as pmv
4+
from pymatviz.enums import Key
5+
6+
7+
pmv.set_plotly_template("pymatviz_white")
8+
9+
df_phonons = load_dataset("matbench_phonons")
10+
11+
12+
# get the 2 largest structures
13+
df_phonons[Key.n_sites] = df_phonons[Key.structure].apply(len)
14+
15+
# plot element-pair RDFs for each structure
16+
for struct in df_phonons.nlargest(2, Key.n_sites)[Key.structure]:
17+
fig = pmv.element_pair_rdfs(struct, n_bins=100, cutoff=10)
18+
formula = struct.formula
19+
fig.layout.title.update(text=f"Pairwise RDFs - {formula}", x=0.5, y=0.98)
20+
fig.layout.margin = dict(l=40, r=0, t=50, b=0)
21+
22+
fig.show()
23+
pmv.io.save_and_compress_svg(fig, f"element-pair-rdfs-{formula.replace(' ', '')}")

pymatviz/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
import builtins
1415
from importlib.metadata import PackageNotFoundError, version
1516

1617
import matplotlib.pyplot as plt
@@ -30,6 +31,7 @@
3031
powerups,
3132
process_data,
3233
ptable,
34+
rdf,
3335
relevance,
3436
sankey,
3537
scatter,
@@ -56,6 +58,7 @@
5658
ptable_lines,
5759
ptable_scatters,
5860
)
61+
from pymatviz.rdf import element_pair_rdfs
5962
from pymatviz.relevance import precision_recall_curve, roc_curve
6063
from pymatviz.sankey import sankey_from_2_df_cols
6164
from pymatviz.scatter import (
@@ -94,7 +97,7 @@
9497
pass # package not installed
9598

9699

97-
IS_IPYTHON = hasattr(__builtins__, "__IPYTHON__")
100+
IS_IPYTHON = hasattr(builtins, "__IPYTHON__")
98101

99102
# define a sensible order for crystal systems across plots
100103
crystal_sys_order = (

pymatviz/rdf.py

+186
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
"""This module calculates and plots pairwise radial distribution functions (RDFs) for
2+
pymatgen structures using plotly.
3+
4+
The main function, pairwise_rdfs, generates a plotly figure with facets for each
5+
pair of elements in the given structure. It supports customization of cutoff distance,
6+
bin size, specific element pairs to plot, reference line.
7+
8+
Example usage:
9+
structure = Structure(...) # Create or load a pymatgen Structure
10+
fig = pairwise_rdfs(structure, bin_size=0.1)
11+
fig.show()
12+
"""
13+
14+
from typing import Any
15+
16+
import numpy as np
17+
import plotly.graph_objects as go
18+
from plotly.subplots import make_subplots
19+
from pymatgen.core import Structure
20+
from scipy.signal import find_peaks
21+
22+
23+
def calculate_rdf(
24+
structure: Structure,
25+
center_species: str,
26+
neighbor_species: str,
27+
cutoff: float,
28+
n_bins: int,
29+
) -> tuple[np.ndarray, np.ndarray]:
30+
"""Calculate the radial distribution function (RDF) for a given pair of species.
31+
32+
The RDF is normalized by the number of pairs and the shell volume density, which
33+
makes the RDF approach 1 for large separations in a homogeneous system.
34+
35+
Args:
36+
structure (Structure): A pymatgen Structure object.
37+
center_species (str): Symbol of the central species.
38+
neighbor_species (str): Symbol of the neighbor species.
39+
cutoff (float): Maximum distance for RDF calculation.
40+
n_bins (int): Number of bins for RDF calculation.
41+
42+
Returns:
43+
tuple[np.ndarray, np.ndarray]: Arrays of (radii, g(r)) values.
44+
"""
45+
bin_size = cutoff / n_bins
46+
radii = np.linspace(0, cutoff, n_bins + 1)[1:]
47+
rdf = np.zeros(n_bins)
48+
49+
center_indices = [
50+
i for i, site in enumerate(structure) if site.specie.symbol == center_species
51+
]
52+
neighbor_indices = [
53+
i for i, site in enumerate(structure) if site.specie.symbol == neighbor_species
54+
]
55+
56+
for center_idx in center_indices:
57+
for neighbor_idx in neighbor_indices:
58+
if center_idx != neighbor_idx:
59+
distance = structure.get_distance(center_idx, neighbor_idx)
60+
if distance < cutoff:
61+
rdf[int(distance / bin_size)] += 1
62+
63+
# Normalize RDF by the number of center-neighbor pairs and shell volumes
64+
rdf = rdf / (len(center_indices) * len(neighbor_indices))
65+
shell_volumes = 4 * np.pi * radii**2 * bin_size
66+
rdf = rdf / (shell_volumes / structure.volume)
67+
68+
return radii, rdf
69+
70+
71+
def find_last_significant_peak(
72+
radii: np.ndarray, rdf: np.ndarray, prominence: float = 0.1
73+
) -> float:
74+
"""Find the position of the last significant peak in the RDF."""
75+
peaks, properties = find_peaks(rdf, prominence=prominence, distance=5)
76+
if peaks.size > 0:
77+
# Sort peaks by prominence and select the last significant one
78+
sorted_peaks = peaks[np.argsort(properties["prominences"])]
79+
return radii[sorted_peaks[-1]]
80+
return radii[-1]
81+
82+
83+
def element_pair_rdfs(
84+
structure: Structure,
85+
cutoff: float = 15,
86+
n_bins: int = 75,
87+
bin_size: float | None = None,
88+
element_pairs: list[tuple[str, str]] | None = None,
89+
reference_line: dict[str, Any] | None = None,
90+
) -> go.Figure:
91+
"""Generate a plotly figure of pairwise radial distribution functions (RDFs) for
92+
all (or a subset of) element pairs in a structure.
93+
94+
The RDF is the probability of finding a neighbor at a distance r from a central
95+
atom. Basically a histogram of pair-wise particle distances.
96+
97+
Args:
98+
structure (Structure): pymatgen Structure.
99+
cutoff (float, optional): Maximum distance for RDF calculation. Default is 15 Å.
100+
n_bins (int, optional): Number of bins for RDF calculation. Default is 75.
101+
bin_size (float, optional): Size of bins for RDF calculation. If specified, it
102+
overrides n_bins. Default is None.
103+
element_pairs (list[tuple[str, str]], optional): Element pairs to plot.
104+
If None, all pairs are plotted.
105+
reference_line (dict, optional): Keywords for reference line at g(r)=1 drawn
106+
with Figure.add_hline(). If None (default), no reference line is drawn.
107+
108+
Returns:
109+
go.Figure: A plotly figure with facets for each pairwise RDF.
110+
111+
Raises:
112+
ValueError: If the structure contains no sites, if invalid element pairs are
113+
provided, or if both n_bins and bin_size are specified.
114+
"""
115+
if not structure.sites:
116+
raise ValueError("input structure contains no sites")
117+
118+
if n_bins != 75 and bin_size is not None:
119+
raise ValueError(
120+
f"Cannot specify both {n_bins=} and {bin_size=}. Pick one or the other."
121+
)
122+
123+
uniq_elements = sorted({site.specie.symbol for site in structure})
124+
element_pairs = element_pairs or [
125+
(e1, e2) for e1 in uniq_elements for e2 in uniq_elements if e1 <= e2
126+
]
127+
element_pairs = sorted(element_pairs)
128+
129+
if extra_elems := {e1 for e1, _e2 in element_pairs} - set(uniq_elements):
130+
raise ValueError(
131+
f"Elements {extra_elems} in element_pairs are not present in the structure"
132+
)
133+
134+
# Calculate pairwise RDFs
135+
if bin_size is not None:
136+
n_bins = int(cutoff / bin_size)
137+
elem_pair_rdfs = {
138+
pair: calculate_rdf(structure, *pair, cutoff, n_bins) for pair in element_pairs
139+
}
140+
141+
# Determine subplot layout
142+
n_pairs = len(element_pairs)
143+
n_cols = min(3, n_pairs)
144+
n_rows = (n_pairs + n_cols - 1) // n_cols
145+
146+
# Create the plotly figure with facets
147+
fig = make_subplots(
148+
rows=n_rows,
149+
cols=n_cols,
150+
subplot_titles=[f"{e1}-{e2}" for e1, e2 in element_pairs],
151+
vertical_spacing=0.25 / n_rows,
152+
horizontal_spacing=0.15 / n_cols,
153+
)
154+
155+
# Add RDF traces to the figure
156+
for idx, (pair, (radii, rdf)) in enumerate(elem_pair_rdfs.items()):
157+
row, col = divmod(idx, n_cols)
158+
row += 1
159+
col += 1
160+
161+
fig.add_scatter(
162+
x=radii,
163+
y=rdf,
164+
mode="lines",
165+
name=f"{pair[0]}-{pair[1]}",
166+
line=dict(color="royalblue"),
167+
showlegend=False,
168+
row=row,
169+
col=col,
170+
hovertemplate="r = %{x:.2f} Å<br>g(r) = %{y:.2f}<extra></extra>",
171+
)
172+
173+
# if one of the last n_col subplots, add x-axis label
174+
if idx >= n_pairs - n_cols:
175+
fig.update_xaxes(title_text="r (Å)", row=row, col=col)
176+
177+
# Add reference line if specified
178+
if reference_line is not None:
179+
defaults = dict(line_dash="dash", line_color="red")
180+
fig.add_hline(y=1, row=row, col=col, **defaults | reference_line)
181+
182+
# set subplot height/width and x/y axis labels
183+
fig.update_layout(height=200 * n_rows, width=350 * n_cols)
184+
fig.update_yaxes(title=dict(text="g(r)", standoff=0.1), col=1)
185+
186+
return fig

pymatviz/templates.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
from importlib.metadata import PackageNotFoundError, version
6+
from typing import Literal
67

78
import matplotlib.pyplot as plt
89
import plotly.express as px
@@ -62,7 +63,9 @@
6263
)
6364

6465

65-
def set_plotly_template(template: str | go.layout.Template) -> None:
66+
def set_plotly_template(
67+
template: Literal["pymatviz_white", "pymatviz_dark"] | str | go.layout.Template, # noqa: PYI051
68+
) -> None:
6669
"""Set the default plotly express and graph objects template.
6770
6871
Args:

readme.md

+11
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,17 @@ See [`pymatviz/xrd.py`](pymatviz/xrd.py).
168168
[xrd-pattern]: https://github.com/janosh/pymatviz/raw/main/assets/xrd-pattern.svg
169169
[xrd-pattern-multiple]: https://github.com/janosh/pymatviz/raw/main/assets/xrd-pattern-multiple.svg
170170

171+
## Radial Distribution Functions
172+
173+
See [`pymatviz/rdf.py`](pymatviz/rdf.py).
174+
175+
| [`rdf_plot(rdf)`](pymatviz/rdf.py) | [`rdf_plot(rdf, rdf2)`](pymatviz/rdf.py) |
176+
| :--------------------------------: | :--------------------------------------: |
177+
| ![element-pair-rdfs-Si16O32] | ![element-pair-rdfs-Na8Nb8O24] |
178+
179+
[element-pair-rdfs-Si16O32]: examples/make_assets/element-pair-rdfs-Si16O32.svg
180+
[element-pair-rdfs-Na8Nb8O24]: examples/make_assets/element-pair-rdfs-Na8Nb8O24.svg
181+
171182
## Uncertainty
172183

173184
See [`pymatviz/uncertainty.py`](pymatviz/uncertainty.py).

0 commit comments

Comments
 (0)