Skip to content

Commit 69b5285

Browse files
DanielYang59 and janosh authored
Speedup import and add regression check for import time (#238)
* Add test framework to monitor module import times with regression tests * Use time.perf_counter for accurate timing * Implement lazy imports across multiple modules to improve performance: - scipy - plotly.figure_factory - sklearn - pymatgen (Structure, NearNeighbors, PhononDos, PhononBands, Composition) * Add reference import times for all core modules * Configure tests to run only on main branch * Add grace and hard thresholds for import time regression --------- Co-authored-by: Janosh Riebesell <[email protected]>
1 parent 94ee9e9 commit 69b5285

File tree

12 files changed

+139
-20
lines changed

12 files changed

+139
-20
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ ci:
22
autoupdate_schedule: quarterly
33
skip: [pyright]
44

5-
default_stages: [commit]
5+
default_stages: [pre-commit]
66

77
default_install_hook_types: [pre-commit, commit-msg]
88

@@ -43,7 +43,7 @@ repos:
4343
rev: v2.3.0
4444
hooks:
4545
- id: codespell
46-
stages: [commit, commit-msg]
46+
stages: [pre-commit, commit-msg]
4747
exclude_types: [csv, svg, html, yaml, jupyter]
4848
args: [--ignore-words-list, "hist,mape,te,nd,fpr", --check-filenames]
4949

pymatviz/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@
1414
import builtins
1515
from importlib.metadata import PackageNotFoundError, version
1616

17-
import matplotlib.pyplot as plt
1817
import plotly.express as px
19-
import plotly.graph_objects as go
20-
import plotly.io as pio
2118

2219
from pymatviz import (
2320
bar,

pymatviz/coordination.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,31 @@
11
"""Visualizations of coordination numbers distributions."""
22

3+
from __future__ import annotations
4+
35
import math
46
from collections import Counter
5-
from collections.abc import Callable, Sequence
7+
from collections.abc import Sequence
68
from inspect import isclass
7-
from typing import Any, Literal
9+
from typing import TYPE_CHECKING
810

911
import numpy as np
1012
import plotly.graph_objects as go
1113
from plotly.colors import label_rgb
1214
from plotly.subplots import make_subplots
1315
from pymatgen.analysis.local_env import NearNeighbors
14-
from pymatgen.core import PeriodicSite, Structure
1516

1617
from pymatviz.colors import ELEM_COLORS_JMOL, ELEM_COLORS_VESTA
1718
from pymatviz.enums import ElemColorScheme, LabelEnum
1819
from pymatviz.utils import normalize_to_dict
1920

2021

22+
if TYPE_CHECKING:
23+
from collections.abc import Callable
24+
from typing import Any, Literal
25+
26+
from pymatgen.core import PeriodicSite, Structure
27+
28+
2129
class SplitMode(LabelEnum):
2230
"""How to split the coordination number histogram into subplots."""
2331

@@ -57,10 +65,12 @@ def normalize_get_neighbors(
5765
# Prepare the neighbor-finding strategy
5866
if isinstance(strategy, int | float):
5967
return lambda site, structure: structure.get_neighbors(site, strategy)
68+
6069
if isinstance(strategy, NearNeighbors):
6170
return lambda site, structure: strategy.get_nn_info(
6271
structure, structure.index(site)
6372
)
73+
6474
if isclass(strategy) and issubclass(strategy, NearNeighbors):
6575
nn_instance = strategy()
6676
return lambda site, structure: nn_instance.get_nn_info(
@@ -417,6 +427,7 @@ def coordination_vs_cutoff_line(
417427
and {*map(type, strategy)} <= {int, float}
418428
):
419429
cutoff_range = strategy
430+
420431
elif isinstance(strategy, NearNeighbors) or (
421432
isclass(strategy) and issubclass(strategy, NearNeighbors)
422433
):
@@ -428,6 +439,7 @@ def coordination_vs_cutoff_line(
428439
else:
429440
raise AttributeError(f"Could not determine cutoff for {nn_instance=}")
430441
cutoff_range = (0, max_cutoff)
442+
431443
else:
432444
raise TypeError(
433445
f"Invalid {strategy=}. Expected float, tuple of floats, NearNeighbors "

pymatviz/process_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def count_elements(
6565

6666
if is_numeric_dtype(srs):
6767
pass
68+
6869
elif is_string_dtype(srs) or {*map(type, srs)} <= {str, Composition}:
6970
# all items are formula strings or Composition objects
7071
if count_mode == "occurrence":

pymatviz/ptable/ptable_plotly.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import numpy as np
99
import pandas as pd
1010
import plotly.express as px
11-
import plotly.figure_factory as ff
1211

1312
from pymatviz.enums import ElemCountMode
1413
from pymatviz.process_data import count_elements
@@ -272,6 +271,8 @@ def ptable_heatmap_plotly(
272271
zmax = max(non_nan_values) if cscale_range[1] is None else cscale_range[1]
273272
car_multiplier = 100 if heat_mode == "percent" else 1
274273

274+
import plotly.figure_factory as ff # slow import
275+
275276
fig = ff.create_annotated_heatmap(
276277
car_multiplier * heatmap_values,
277278
annotation_text=tile_texts,

pymatviz/structure_viz/helpers.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,10 @@
66
import itertools
77
import math
88
import warnings
9-
from collections.abc import Callable, Sequence
10-
from typing import TYPE_CHECKING, Any, Literal
9+
from typing import TYPE_CHECKING
1110

1211
import numpy as np
1312
import pandas as pd
14-
import plotly.graph_objects as go
15-
from pymatgen.analysis.local_env import NearNeighbors
1613
from pymatgen.core import Composition, Lattice, PeriodicSite, Species, Structure
1714

1815
from pymatviz.colors import ELEM_COLORS_JMOL, ELEM_COLORS_VESTA

pymatviz/structure_viz/mpl.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@
88

99
import math
1010
import warnings
11-
from collections.abc import Callable, Sequence
1211
from itertools import product
13-
from typing import TYPE_CHECKING, Any, Literal
12+
from typing import TYPE_CHECKING
1413

1514
import matplotlib.pyplot as plt
1615
import numpy as np

pymatviz/structure_viz/plotly.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44

55
import math
66
import warnings
7-
from collections.abc import Callable, Sequence
8-
from typing import TYPE_CHECKING, Any, Literal
7+
from typing import TYPE_CHECKING
98

109
import numpy as np
11-
import plotly.graph_objects as go
1210
from plotly.subplots import make_subplots
1311
from pymatgen.analysis.local_env import CrystalNN, NearNeighbors
1412
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

pymatviz/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,7 @@ def normalize_to_dict(
701701
"""
702702
if isinstance(inputs, cls):
703703
return {"": inputs}
704+
704705
if (
705706
isinstance(inputs, list | tuple)
706707
and all(isinstance(obj, cls) for obj in inputs)

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ dependencies = [
3333
"pandas[output-formatting]>=2.2",
3434
"plotly>=5.23",
3535
"pymatgen>=2024.7.18",
36+
# TODO: pmv doesn't actually depend on monty, however latest monty
37+
# includes a critical import patch; remove this once pymatgen bumps its monty dependency
38+
"monty>=2024.10.21",
3639
"scikit-learn>=1.5",
3740
"scipy>=1.14",
3841
]
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""
2+
Test import time of core modules to avoid regression.
3+
"""
4+
5+
# ruff: noqa: T201 (check for print statement)
6+
7+
from __future__ import annotations
8+
9+
import os
10+
import subprocess
11+
import time
12+
import warnings
13+
14+
import pytest
15+
16+
17+
GEN_REF_TIME = False # switch for generating reference time
18+
19+
# Last update: 2024-10-23
20+
REF_IMPORT_TIME: dict[str, float] = {
21+
"pymatviz": 4085.73,
22+
"pymatviz.coordination": 4135.77,
23+
"pymatviz.cumulative": 4108.06,
24+
"pymatviz.histogram": 4110.41,
25+
"pymatviz.phonons": 4109.97,
26+
"pymatviz.powerups": 4066.31,
27+
"pymatviz.ptable": 4092.35,
28+
"pymatviz.rainclouds": 4098.33,
29+
"pymatviz.rdf": 4144.26,
30+
"pymatviz.relevance": 4126.54,
31+
"pymatviz.sankey": 4135.17,
32+
"pymatviz.scatter": 4087.62,
33+
"pymatviz.structure_viz": 4105.33,
34+
"pymatviz.sunburst": 4133.78,
35+
"pymatviz.uncertainty": 4179.99,
36+
"pymatviz.xrd": 4156.52,
37+
}
38+
39+
40+
@pytest.mark.skipif(
    not GEN_REF_TIME, reason="Set GEN_REF_TIME to generate reference import time."
)
def test_get_ref_import_time() -> None:
    """Intentionally failing helper that prints fresh reference import times.

    Skipped unless GEN_REF_TIME is switched on. Every module listed in
    REF_IMPORT_TIME is timed with measure_import_time, the rounded results are
    printed as a dict that can be pasted back into REF_IMPORT_TIME, and the
    test then fails on purpose.
    """
    measured: dict[str, float] = {}
    for name in REF_IMPORT_TIME:
        measured[name] = round(measure_import_time(name), 2)

    # Emit the timings in a form that can be copied into REF_IMPORT_TIME
    print("\nCopyable import time dictionary:")
    print(measured)

    pytest.fail("Generate reference import times.")
55+
56+
57+
def measure_import_time(module_name: str, repeats: int = 3) -> float:
    """Measure import time of a module in milliseconds across several runs.

    Each run launches a new interpreter process that executes
    ``import <module_name>``, so the measured wall time also includes
    process/interpreter startup.

    Args:
        module_name (str): name of the module to test.
        repeats (int): Number of runs to average. Must be at least 1.

    Returns:
        float: average import time in milliseconds.

    Raises:
        ValueError: if repeats is less than 1.
        subprocess.CalledProcessError: if the import subprocess exits non-zero.
    """
    import sys  # local import: only needed to locate the running interpreter

    if repeats < 1:
        raise ValueError(f"{repeats=} must be at least 1")

    # Use the interpreter that is running the tests instead of whatever
    # "python" resolves to on PATH (may be missing or a different environment).
    cmd = [sys.executable, "-c", f"import {module_name}"]
    total_time = 0.0

    for _ in range(repeats):
        start_time = time.perf_counter()
        subprocess.run(cmd, check=True)  # noqa: S603
        total_time += time.perf_counter() - start_time

    return total_time / repeats * 1000
75+
76+
77+
@pytest.mark.skipif(
    os.getenv("GITHUB_REF") != "refs/heads/main", reason="Only run on the main branch"
)
@pytest.mark.skipif(GEN_REF_TIME, reason="Generating reference import time.")
def test_import_time(grace_percent: float = 0.20, hard_percent: float = 0.50) -> None:
    """Test the import time of core modules to avoid regression in performance.

    Each module in REF_IMPORT_TIME is timed with measure_import_time and
    compared against its reference value scaled by ``1 + grace_percent`` and
    ``1 + hard_percent``.

    Args:
        grace_percent (float): Maximum allowed fractional increase in import
            time (0.20 means 20%) before a warning is raised.
        hard_percent (float): Maximum allowed fractional increase in import
            time (0.50 means 50%) before the test fails.
    """
    for module_name, ref_time in REF_IMPORT_TIME.items():
        current_time = measure_import_time(module_name)

        # Calculate grace and hard thresholds
        grace_threshold = ref_time * (1 + grace_percent)
        hard_threshold = ref_time * (1 + hard_percent)

        # Within the grace band: nothing to report for this module
        if current_time <= grace_threshold:
            continue

        # Report the measured time alongside the threshold so the size of the
        # regression is visible in the test output.
        if current_time > hard_threshold:
            pytest.fail(
                f"{module_name} import too slow! {current_time=:.2f} ms, "
                f"{hard_threshold=:.2f} ms"
            )

        warnings.warn(
            f"{module_name} import slightly slower: {current_time=:.2f} ms, "
            f"{grace_threshold=:.2f} ms",
            stacklevel=2,
        )

tests/test_coordination.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from __future__ import annotations
2+
13
import re
2-
from collections.abc import Sequence
3-
from typing import Any
4+
from typing import TYPE_CHECKING
45

56
import pytest
67
from pymatgen.analysis.local_env import CrystalNN, NearNeighbors, VoronoiNN
@@ -15,6 +16,11 @@
1516
)
1617

1718

19+
if TYPE_CHECKING:
20+
from collections.abc import Sequence
21+
from typing import Any
22+
23+
1824
def test_coordination_hist_single_structure(structures: Sequence[Structure]) -> None:
1925
"""Test coordination_hist with a single structure."""
2026
fig = coordination_hist(structures[0])

0 commit comments

Comments
 (0)