Skip to content

Commit 5ab3390

Browse files
committed
fix "COM812", # trailing comma missing
1 parent 476560b commit 5ab3390

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+2927
-2550
lines changed

examples/dataset_exploration/boltztrap_mp/explore_boltztrap_mp.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@
4242

4343
# %%
4444
fig = ptable_heatmap(
45-
count_elements(df_boltz[Key.formula]), log=True, return_type="figure"
45+
count_elements(df_boltz[Key.formula]),
46+
log=True,
47+
return_type="figure",
4648
)
4749
fig.suptitle("Elements in BoltzTraP MP dataset")
4850
pmv.save_fig(fig, "boltztrap_mp-ptable-heatmap.pdf")
@@ -65,6 +67,9 @@
6567

6668
# %%
6769
df_boltz.sort_values("pf_n", ascending=False).head(1000).hist(
68-
bins=50, log=True, layout=[2, 3], figsize=[18, 8]
70+
bins=50,
71+
log=True,
72+
layout=[2, 3],
73+
figsize=[18, 8],
6974
)
7075
plt.suptitle("BoltzTraP MP")

examples/dataset_exploration/matbench/dielectric/explore_dielectric.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
df_diel[Key.n_wyckoff] = df_diel.wyckoff.map(count_wyckoff_positions)
4141

4242
df_diel[Key.crystal_system] = df_diel[Key.spg_num].map(
43-
pmv.utils.crystal_sys_from_spg_num
43+
pmv.utils.crystal_sys_from_spg_num,
4444
)
4545

4646
df_diel[Key.volume] = [x.volume for x in df_diel[Key.structure]]
@@ -49,7 +49,9 @@
4949

5050
# %%
5151
fig = pmv.ptable_heatmap(
52-
pmv.count_elements(df_diel[Key.formula]), log=True, return_type="figure"
52+
pmv.count_elements(df_diel[Key.formula]),
53+
log=True,
54+
return_type="figure",
5355
)
5456
fig.suptitle("Elemental prevalence in the Matbench dielectric dataset")
5557
pmv.save_fig(fig, "dielectric-ptable-heatmap.pdf")
@@ -88,7 +90,8 @@
8890

8991
x_ticks = {} # custom x axis tick labels
9092
for cry_sys, df_group in sorted(
91-
df_diel.groupby(Key.crystal_system), key=lambda x: pmv.crystal_sys_order.index(x[0])
93+
df_diel.groupby(Key.crystal_system),
94+
key=lambda x: pmv.crystal_sys_order.index(x[0]),
9295
):
9396
x_ticks[cry_sys] = (
9497
f"<b>{cry_sys}</b><br>"
@@ -101,7 +104,8 @@
101104
fig.layout.margin = dict(b=10, l=10, r=10, t=50)
102105
fig.layout.showlegend = False
103106
fig.layout.xaxis = reusable_x_axis = dict(
104-
tickvals=list(range(len(pmv.crystal_sys_order))), ticktext=list(x_ticks.values())
107+
tickvals=list(range(len(pmv.crystal_sys_order))),
108+
ticktext=list(x_ticks.values()),
105109
)
106110

107111

@@ -130,7 +134,8 @@ def rgb_color(val: float, max_val: float) -> str:
130134

131135
x_ticks = {}
132136
for cry_sys, df_group in sorted(
133-
df_diel.groupby(Key.crystal_system), key=lambda x: pmv.crystal_sys_order.index(x[0])
137+
df_diel.groupby(Key.crystal_system),
138+
key=lambda x: pmv.crystal_sys_order.index(x[0]),
134139
):
135140
n_wyckoff = df_group[Key.n_wyckoff].mean()
136141
clr = rgb_color(n_wyckoff, 14)

examples/dataset_exploration/matbench/log_g+kvrh/explore_log_g+krvh.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,20 @@
4141
df_grvh[[Key.spg_symbol, Key.spg_num]] = [
4242
struct.get_space_group_info()
4343
for struct in tqdm(
44-
df_grvh[Key.structure], desc="Getting matbench_log_gvrh spacegroups"
44+
df_grvh[Key.structure],
45+
desc="Getting matbench_log_gvrh spacegroups",
4546
)
4647
]
4748
df_grvh[Key.crystal_system] = df_grvh[Key.spg_num].map(
48-
pmv.utils.crystal_sys_from_spg_num
49+
pmv.utils.crystal_sys_from_spg_num,
4950
)
5051

5152

5253
df_grvh[Key.wyckoff] = [
5354
get_protostructure_label_from_spglib(struct)
5455
for struct in tqdm(
55-
df_grvh[Key.structure], desc="Getting matbench_log_gvrh Wyckoff strings"
56+
df_grvh[Key.structure],
57+
desc="Getting matbench_log_gvrh Wyckoff strings",
5658
)
5759
]
5860
df_grvh[Key.n_wyckoff] = df_grvh.wyckoff.map(count_wyckoff_positions)
@@ -133,7 +135,9 @@ def has_isolated_atom(crystal: Structure, radius: float = 5) -> bool:
133135
df_grvh[Key.formula] = df_grvh[Key.structure].map(lambda struct: struct.formula)
134136

135137
fig = ptable_heatmap(
136-
count_elements(df_grvh[Key.formula]), log=True, return_type="figure"
138+
count_elements(df_grvh[Key.formula]),
139+
log=True,
140+
return_type="figure",
137141
)
138142
fig.suptitle("Elemental prevalence in the Matbench bulk/shear modulus datasets")
139143
pmv.save_fig(fig, "log_gvrh-ptable-heatmap.pdf")
@@ -172,7 +176,8 @@ def rgb_color(val: float, max_val: float) -> str:
172176

173177
x_ticks = {}
174178
for cry_sys, df_group in sorted(
175-
df_grvh.groupby(Key.crystal_system), key=lambda x: crystal_sys_order.index(x[0])
179+
df_grvh.groupby(Key.crystal_system),
180+
key=lambda x: crystal_sys_order.index(x[0]),
176181
):
177182
n_wyckoff_top = df_group[Key.n_wyckoff].mean()
178183
clr = rgb_color(n_wyckoff_top, 14)

examples/dataset_exploration/matbench/perovskites/explore_perovskites.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
df_perov[Key.formula] = df_perov[Key.structure].map(lambda cryst: cryst.formula)
3434

3535
df_perov[Key.crystal_system] = df_perov[Key.spg_num].map(
36-
pmv.utils.crystal_sys_from_spg_num
36+
pmv.utils.crystal_sys_from_spg_num,
3737
)
3838

3939

@@ -49,7 +49,9 @@
4949

5050
# %%
5151
fig = ptable_heatmap(
52-
count_elements(df_perov[Key.formula]), log=True, return_type="figure"
52+
count_elements(df_perov[Key.formula]),
53+
log=True,
54+
return_type="figure",
5355
)
5456
fig.suptitle("Elements in Matbench Perovskites dataset")
5557
pmv.save_fig(fig, "perovskites-ptable-heatmap.pdf")

examples/dataset_exploration/matbench/phonons/explore_phonons.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@
4343
df_phonon[Key.volume] = df_phonon[Key.structure].map(lambda cryst: cryst.volume)
4444

4545
fig = ptable_heatmap(
46-
count_elements(df_phonon[Key.formula]), log=True, return_type="figure"
46+
count_elements(df_phonon[Key.formula]),
47+
log=True,
48+
return_type="figure",
4749
)
4850
fig.suptitle("Elemental prevalence in the Matbench phonons dataset")
4951
pmv.save_fig(fig, "phonons-ptable-heatmap.pdf")

examples/dataset_exploration/matbench/steels/explore_steels.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626

2727
# %%
2828
fig = ptable_heatmap(
29-
count_elements(df_steels[Key.composition]), log=True, return_type="figure"
29+
count_elements(df_steels[Key.composition]),
30+
log=True,
31+
return_type="figure",
3032
)
3133
fig.suptitle("Elemental prevalence in the Matbench steels dataset")
3234
pmv.save_fig(fig, "steels-ptable-heatmap.pdf")

examples/dataset_exploration/matpes/eda.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
total_force_col = "Σ|force<sub>i</sub>| (eV/Å)"
8686
df_pbe[total_force_col] = df_pbe[Key.forces].map(lambda arr: np.abs(arr).sum(axis=1))
8787
df_r2scan[total_force_col] = df_r2scan[Key.forces].map(
88-
lambda arr: np.abs(arr).sum(axis=1)
88+
lambda arr: np.abs(arr).sum(axis=1),
8989
)
9090

9191
fig = go.Figure()
@@ -149,7 +149,8 @@
149149
}
150150

151151
fig = pmv.ptable_heatmap_splits(
152-
per_elem_cohesive_energy, cbar_title=f"{col_name.label} (eV)"
152+
per_elem_cohesive_energy,
153+
cbar_title=f"{col_name.label} (eV)",
153154
)
154155

155156

@@ -201,7 +202,9 @@
201202

202203
# %% spacegroup histogram
203204
fig = pmv.spacegroup_bar(
204-
df_r2scan[Key.spg_num], title="r2SCAN spacegroup histogram", log=True
205+
df_r2scan[Key.spg_num],
206+
title="r2SCAN spacegroup histogram",
207+
log=True,
205208
)
206209
fig.show()
207210
pmv.save_fig(fig, "r2scan-spacegroup-hist.pdf")
@@ -230,7 +233,9 @@
230233
for site, force in zip(struct, forces, strict=True)
231234
}
232235
for struct, forces in zip(
233-
df_r2scan[Key.structure], df_r2scan[Key.forces], strict=True
236+
df_r2scan[Key.structure],
237+
df_r2scan[Key.forces],
238+
strict=True,
234239
)
235240
).mean()
236241

examples/dataset_exploration/ricci_carrier_transport/convert_dtype+add_strucs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616
# %%
1717
df_carrier = pd.concat(
18-
[df_carrier, pd.json_normalize(df_carrier.data)], axis="columns"
18+
[df_carrier, pd.json_normalize(df_carrier.data)],
19+
axis="columns",
1920
).drop(columns=["data", "is_public", "project"])
2021

2122
df_carrier = df_carrier.set_index("identifier")
@@ -99,5 +100,6 @@
99100

100101
# %%
101102
df_carrier.to_json(
102-
"cleaned_ricci_boltztrap_mp_tabular.json.gz", default_handler=lambda x: x.as_dict()
103+
"cleaned_ricci_boltztrap_mp_tabular.json.gz",
104+
default_handler=lambda x: x.as_dict(),
103105
)

examples/dataset_exploration/ricci_carrier_transport/explore_carrier_transport.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@
4040

4141
# %%
4242
fig = ptable_heatmap(
43-
count_elements(df_carrier.pretty_formula.dropna()), log=True, return_type="figure"
43+
count_elements(df_carrier.pretty_formula.dropna()),
44+
log=True,
45+
return_type="figure",
4446
)
4547
fig.suptitle("Elemental prevalence in the Ricci Carrier Transport dataset")
4648
pmv.save_fig(fig, "carrier-transport-ptable-heatmap.pdf")
@@ -55,7 +57,7 @@
5557
# %%
5658
ax = df_carrier[["S.p [µV/K]", "S.n [µV/K]"]].hist(bins=50, log=True, figsize=[18, 8])
5759
plt.suptitle(
58-
"Ricci Carrier Transport dataset histograms for n- and p-type Seebeck coefficients"
60+
"Ricci Carrier Transport dataset histograms for n- and p-type Seebeck coefficients",
5961
)
6062
pmv.save_fig(ax, "carrier-transport-seebeck-n+p.pdf")
6163

examples/dataset_exploration/wbm/explore_wbm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313

1414
# %% download wbm-summary.csv (12 MB)
1515
df_wbm = pd.read_csv("https://figshare.com/ndownloader/files/44225498").set_index(
16-
Key.mat_id, drop=False
16+
Key.mat_id,
17+
drop=False,
1718
)
1819

1920
df_wbm["batch_idx"] = df_wbm.index.str.split("-").str[2].astype(int)
@@ -69,7 +70,8 @@
6970
fig.layout.margin = dict(b=10, l=10, r=10, t=50)
7071
fig.layout.showlegend = False
7172
fig.layout.xaxis = dict(
72-
tickvals=list(range(len(crystal_sys_order))), ticktext=list(x_ticks.values())
73+
tickvals=list(range(len(crystal_sys_order))),
74+
ticktext=list(x_ticks.values()),
7375
)
7476
fig.update_traces(hoverinfo="skip", hovertemplate=None)
7577

examples/diatomics/mace_pair_repulsion.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ def timer(label: str = "") -> Generator[None, None, None]:
3737

3838

3939
def generate_diatomics(
40-
symbol0: str, symbol1: str, distances: list[float] | np.ndarray
40+
symbol0: str,
41+
symbol1: str,
42+
distances: list[float] | np.ndarray,
4143
) -> list[Atoms]:
4244
"""Build diatomic molecules in vacuum for given distances.
4345
@@ -56,7 +58,10 @@ def generate_diatomics(
5658

5759

5860
def calc_one_pair(
59-
z0: int, z1: int, calc: MACECalculator, distances: list[float] | np.ndarray
61+
z0: int,
62+
z1: int,
63+
calc: MACECalculator,
64+
distances: list[float] | np.ndarray,
6065
) -> list[float]:
6166
"""Calculate potential energy for a pair of elements at given distances.
6267
@@ -72,7 +77,9 @@ def calc_one_pair(
7277
return [
7378
calc.get_potential_energy(at)
7479
for at in generate_diatomics(
75-
chemical_symbols[z0], chemical_symbols[z1], distances
80+
chemical_symbols[z0],
81+
chemical_symbols[z1],
82+
distances,
7683
)
7784
]
7885

examples/diatomics/plot.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@
2626

2727
# %%
2828
def plot_on_ax(
29-
ax: plt.Axes, distances: np.ndarray, energy: np.ndarray, formula: str
29+
ax: plt.Axes,
30+
distances: np.ndarray,
31+
energy: np.ndarray,
32+
formula: str,
3033
) -> None:
3134
"""Plot pair repulsion curve on a given axes.
3235
@@ -54,7 +57,7 @@ def plot_homo_nuclear(model_size: str) -> None:
5457
n_rows, n_columns, size_factor = 10, 18, 3
5558

5659
fig = plt.figure(
57-
figsize=(0.75 * n_columns * size_factor, 0.7 * n_rows * size_factor)
60+
figsize=(0.75 * n_columns * size_factor, 0.7 * n_rows * size_factor),
5861
)
5962
gs = plt.GridSpec(figure=fig, nrows=n_rows, ncols=n_columns)
6063

@@ -93,12 +96,12 @@ def plot_hetero_nuclear(model_size: str) -> None:
9396
[
9497
int(fn.name.split("-")[2])
9598
for fn in Path("simulations/").glob(f"results-{model_size}-*-X.json")
96-
]
99+
],
97100
)
98101
with PdfPages(f"{model_size}-hetero-nuclear.pdf") as pdf:
99102
for z_main in z_calculated:
100103
fig = plt.figure(
101-
figsize=(0.75 * n_columns * size_factor, 0.7 * n_rows * size_factor)
104+
figsize=(0.75 * n_columns * size_factor, 0.7 * n_rows * size_factor),
102105
)
103106
gs = plt.GridSpec(figure=fig, nrows=n_rows, ncols=n_columns)
104107
plot_element_heteronuclear(fig, gs, model_size, z_main)
@@ -113,7 +116,10 @@ def plot_hetero_nuclear(model_size: str) -> None:
113116

114117

115118
def plot_element_heteronuclear(
116-
fig: plt.Figure, gs: plt.GridSpec, model_size: str, atomic_number: int
119+
fig: plt.Figure,
120+
gs: plt.GridSpec,
121+
model_size: str,
122+
atomic_number: int,
117123
) -> None:
118124
"""Plot heteronuclear pair repulsion curves for a specific element.
119125

examples/make_assets/histogram.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@
2929

3030
# %% Histogram Plots
3131
ax = pmv.elements_hist(
32-
df_expt_gap[Key.composition], keep_top=15, v_offset=200, rotation=0, fontsize=12
32+
df_expt_gap[Key.composition],
33+
keep_top=15,
34+
v_offset=200,
35+
rotation=0,
36+
fontsize=12,
3337
)
3438
pmv.io.save_and_compress_svg(ax, "elements-hist")
3539

examples/make_assets/phonons.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@
4848

4949
fig = pmv.phonon_bands_and_dos(ph_bands, ph_doses)
5050
fig.layout.title = dict(
51-
text=f"Phonon Bands and DOS of {formula} ({mp_id})", x=0.5, y=0.98
51+
text=f"Phonon Bands and DOS of {formula} ({mp_id})",
52+
x=0.5,
53+
y=0.98,
5254
)
5355
fig.layout.margin = dict(l=0, r=0, b=0, t=40)
5456
pmv.io.save_and_compress_svg(fig, f"phonon-bands-and-dos-{mp_id}")

examples/make_assets/ptable/ptable_matplotlib.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@
4949

5050
# %%
5151
fig = pmv.ptable_heatmap_ratio(
52-
df_expt_gap[Key.composition], df_steels[Key.composition], log=True, value_fmt=".4g"
52+
df_expt_gap[Key.composition],
53+
df_steels[Key.composition],
54+
log=True,
55+
value_fmt=".4g",
5356
)
5457
title = "Element ratios in Matbench Experimental Band Gap vs Matbench Steel"
5558
fig.suptitle(title, y=0.96, fontsize=16, fontweight="bold")

examples/make_assets/ptable/ptable_plotly.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232

3333
# %%
3434
fig = pmv.ptable_heatmap_plotly(
35-
df_expt_gap[Key.composition], log=True, colorscale="viridis"
35+
df_expt_gap[Key.composition],
36+
log=True,
37+
colorscale="viridis",
3638
)
3739
title = "Elements in Matbench Experimental Bandgap (log scale)"
3840
fig.layout.title = dict(text=f"<b>{title}</b>", x=0.45, y=0.94, font_size=20)

0 commit comments

Comments
 (0)