|
| 1 | +"""This module calculates and plots pairwise radial distribution functions (RDFs) for |
| 2 | +pymatgen structures using plotly. |
| 3 | +
|
| 4 | +The main function, pairwise_rdfs, generates a plotly figure with facets for each |
| 5 | +pair of elements in the given structure. It supports customization of cutoff distance, |
| 6 | +bin size, specific element pairs to plot, reference line. |
| 7 | +
|
| 8 | +Example usage: |
| 9 | + structure = Structure(...) # Create or load a pymatgen Structure |
| 10 | + fig = pairwise_rdfs(structure, bin_size=0.1) |
| 11 | + fig.show() |
| 12 | +""" |
| 13 | + |
| 14 | +from typing import Any |
| 15 | + |
| 16 | +import numpy as np |
| 17 | +import plotly.graph_objects as go |
| 18 | +from plotly.subplots import make_subplots |
| 19 | +from pymatgen.core import Structure |
| 20 | +from scipy.signal import find_peaks |
| 21 | + |
| 22 | + |
| 23 | +def calculate_rdf( |
| 24 | + structure: Structure, |
| 25 | + center_species: str, |
| 26 | + neighbor_species: str, |
| 27 | + cutoff: float, |
| 28 | + n_bins: int, |
| 29 | +) -> tuple[np.ndarray, np.ndarray]: |
| 30 | + """Calculate the radial distribution function (RDF) for a given pair of species. |
| 31 | +
|
| 32 | + The RDF is normalized by the number of pairs and the shell volume density, which |
| 33 | + makes the RDF approach 1 for large separations in a homogeneous system. |
| 34 | +
|
| 35 | + Args: |
| 36 | + structure (Structure): A pymatgen Structure object. |
| 37 | + center_species (str): Symbol of the central species. |
| 38 | + neighbor_species (str): Symbol of the neighbor species. |
| 39 | + cutoff (float): Maximum distance for RDF calculation. |
| 40 | + n_bins (int): Number of bins for RDF calculation. |
| 41 | +
|
| 42 | + Returns: |
| 43 | + tuple[np.ndarray, np.ndarray]: Arrays of (radii, g(r)) values. |
| 44 | + """ |
| 45 | + bin_size = cutoff / n_bins |
| 46 | + radii = np.linspace(0, cutoff, n_bins + 1)[1:] |
| 47 | + rdf = np.zeros(n_bins) |
| 48 | + |
| 49 | + center_indices = [ |
| 50 | + i for i, site in enumerate(structure) if site.specie.symbol == center_species |
| 51 | + ] |
| 52 | + neighbor_indices = [ |
| 53 | + i for i, site in enumerate(structure) if site.specie.symbol == neighbor_species |
| 54 | + ] |
| 55 | + |
| 56 | + for center_idx in center_indices: |
| 57 | + for neighbor_idx in neighbor_indices: |
| 58 | + if center_idx != neighbor_idx: |
| 59 | + distance = structure.get_distance(center_idx, neighbor_idx) |
| 60 | + if distance < cutoff: |
| 61 | + rdf[int(distance / bin_size)] += 1 |
| 62 | + |
| 63 | + # Normalize RDF by the number of center-neighbor pairs and shell volumes |
| 64 | + rdf = rdf / (len(center_indices) * len(neighbor_indices)) |
| 65 | + shell_volumes = 4 * np.pi * radii**2 * bin_size |
| 66 | + rdf = rdf / (shell_volumes / structure.volume) |
| 67 | + |
| 68 | + return radii, rdf |
| 69 | + |
| 70 | + |
| 71 | +def find_last_significant_peak( |
| 72 | + radii: np.ndarray, rdf: np.ndarray, prominence: float = 0.1 |
| 73 | +) -> float: |
| 74 | + """Find the position of the last significant peak in the RDF.""" |
| 75 | + peaks, properties = find_peaks(rdf, prominence=prominence, distance=5) |
| 76 | + if peaks.size > 0: |
| 77 | + # Sort peaks by prominence and select the last significant one |
| 78 | + sorted_peaks = peaks[np.argsort(properties["prominences"])] |
| 79 | + return radii[sorted_peaks[-1]] |
| 80 | + return radii[-1] |
| 81 | + |
| 82 | + |
| 83 | +def element_pair_rdfs( |
| 84 | + structure: Structure, |
| 85 | + cutoff: float = 15, |
| 86 | + n_bins: int = 75, |
| 87 | + bin_size: float | None = None, |
| 88 | + element_pairs: list[tuple[str, str]] | None = None, |
| 89 | + reference_line: dict[str, Any] | None = None, |
| 90 | +) -> go.Figure: |
| 91 | + """Generate a plotly figure of pairwise radial distribution functions (RDFs) for |
| 92 | + all (or a subset of) element pairs in a structure. |
| 93 | +
|
| 94 | + The RDF is the probability of finding a neighbor at a distance r from a central |
| 95 | + atom. Basically a histogram of pair-wise particle distances. |
| 96 | +
|
| 97 | + Args: |
| 98 | + structure (Structure): pymatgen Structure. |
| 99 | + cutoff (float, optional): Maximum distance for RDF calculation. Default is 15 Å. |
| 100 | + n_bins (int, optional): Number of bins for RDF calculation. Default is 75. |
| 101 | + bin_size (float, optional): Size of bins for RDF calculation. If specified, it |
| 102 | + overrides n_bins. Default is None. |
| 103 | + element_pairs (list[tuple[str, str]], optional): Element pairs to plot. |
| 104 | + If None, all pairs are plotted. |
| 105 | + reference_line (dict, optional): Keywords for reference line at g(r)=1 drawn |
| 106 | + with Figure.add_hline(). If None (default), no reference line is drawn. |
| 107 | +
|
| 108 | + Returns: |
| 109 | + go.Figure: A plotly figure with facets for each pairwise RDF. |
| 110 | +
|
| 111 | + Raises: |
| 112 | + ValueError: If the structure contains no sites, if invalid element pairs are |
| 113 | + provided, or if both n_bins and bin_size are specified. |
| 114 | + """ |
| 115 | + if not structure.sites: |
| 116 | + raise ValueError("input structure contains no sites") |
| 117 | + |
| 118 | + if n_bins != 75 and bin_size is not None: |
| 119 | + raise ValueError( |
| 120 | + f"Cannot specify both {n_bins=} and {bin_size=}. Pick one or the other." |
| 121 | + ) |
| 122 | + |
| 123 | + uniq_elements = sorted({site.specie.symbol for site in structure}) |
| 124 | + element_pairs = element_pairs or [ |
| 125 | + (e1, e2) for e1 in uniq_elements for e2 in uniq_elements if e1 <= e2 |
| 126 | + ] |
| 127 | + element_pairs = sorted(element_pairs) |
| 128 | + |
| 129 | + if extra_elems := {e1 for e1, _e2 in element_pairs} - set(uniq_elements): |
| 130 | + raise ValueError( |
| 131 | + f"Elements {extra_elems} in element_pairs are not present in the structure" |
| 132 | + ) |
| 133 | + |
| 134 | + # Calculate pairwise RDFs |
| 135 | + if bin_size is not None: |
| 136 | + n_bins = int(cutoff / bin_size) |
| 137 | + elem_pair_rdfs = { |
| 138 | + pair: calculate_rdf(structure, *pair, cutoff, n_bins) for pair in element_pairs |
| 139 | + } |
| 140 | + |
| 141 | + # Determine subplot layout |
| 142 | + n_pairs = len(element_pairs) |
| 143 | + n_cols = min(3, n_pairs) |
| 144 | + n_rows = (n_pairs + n_cols - 1) // n_cols |
| 145 | + |
| 146 | + # Create the plotly figure with facets |
| 147 | + fig = make_subplots( |
| 148 | + rows=n_rows, |
| 149 | + cols=n_cols, |
| 150 | + subplot_titles=[f"{e1}-{e2}" for e1, e2 in element_pairs], |
| 151 | + vertical_spacing=0.25 / n_rows, |
| 152 | + horizontal_spacing=0.15 / n_cols, |
| 153 | + ) |
| 154 | + |
| 155 | + # Add RDF traces to the figure |
| 156 | + for idx, (pair, (radii, rdf)) in enumerate(elem_pair_rdfs.items()): |
| 157 | + row, col = divmod(idx, n_cols) |
| 158 | + row += 1 |
| 159 | + col += 1 |
| 160 | + |
| 161 | + fig.add_scatter( |
| 162 | + x=radii, |
| 163 | + y=rdf, |
| 164 | + mode="lines", |
| 165 | + name=f"{pair[0]}-{pair[1]}", |
| 166 | + line=dict(color="royalblue"), |
| 167 | + showlegend=False, |
| 168 | + row=row, |
| 169 | + col=col, |
| 170 | + hovertemplate="r = %{x:.2f} Å<br>g(r) = %{y:.2f}<extra></extra>", |
| 171 | + ) |
| 172 | + |
| 173 | + # if one of the last n_col subplots, add x-axis label |
| 174 | + if idx >= n_pairs - n_cols: |
| 175 | + fig.update_xaxes(title_text="r (Å)", row=row, col=col) |
| 176 | + |
| 177 | + # Add reference line if specified |
| 178 | + if reference_line is not None: |
| 179 | + defaults = dict(line_dash="dash", line_color="red") |
| 180 | + fig.add_hline(y=1, row=row, col=col, **defaults | reference_line) |
| 181 | + |
| 182 | + # set subplot height/width and x/y axis labels |
| 183 | + fig.update_layout(height=200 * n_rows, width=350 * n_cols) |
| 184 | + fig.update_yaxes(title=dict(text="g(r)", standoff=0.1), col=1) |
| 185 | + |
| 186 | + return fig |
0 commit comments