Skip to content

Commit 23cbe32

Browse files
committed
WIP: add draft import time checker
1 parent 92a4821 commit 23cbe32

File tree

2 files changed

+98
-2
lines changed

2 files changed

+98
-2
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
"""
2+
Test import time of core modules to avoid regression.
3+
"""
4+
5+
# ruff: noqa: T201 (check for print statement)
6+
7+
from __future__ import annotations
8+
9+
import subprocess
10+
import time
11+
12+
import pytest
13+
14+
15+
# TODO: add more test modules
16+
REF_IMPORT_TIME: dict[str, float] = {
17+
"pymatviz": None,
18+
"pymatviz.ptable": None,
19+
"pymatviz.io": None,
20+
"pymatviz.scatter": None,
21+
"pymatviz.phonons": None,
22+
}
23+
24+
25+
# @pytest.mark.skip(reason="Unskip to generate reference import time.")
26+
def test_get_ref_import_time() -> None:
27+
"""A dummy test that would always fail, used to generate copyable reference time."""
28+
# Measure import time for each module
29+
import_times = {
30+
module_name: measure_import_time_in_ms(module_name)
31+
for module_name in REF_IMPORT_TIME
32+
}
33+
34+
# Print out the import times in a copyable format
35+
print("\nCopyable import time dictionary:")
36+
print("{")
37+
for module_name, import_time in import_times.items():
38+
print(f' "{module_name}": {import_time:.2f},')
39+
print("}")
40+
41+
pytest.fail("Generated reference import times.")
42+
43+
44+
def measure_import_time_in_ms(module_name: str, count: int = 10) -> float:
45+
"""Measure import time of a module in milliseconds across several runs.
46+
47+
Args:
48+
module_name (str): name of the module to test.
49+
count (int): Number of runs to average.
50+
51+
Returns:
52+
float: import time in milliseconds.
53+
"""
54+
total_time = 0.0
55+
56+
for _ in range(count):
57+
start_time = time.time()
58+
subprocess.run(["python", "-c", f"import {module_name}"], check=True) # noqa: S603, S607
59+
total_time += time.time() - start_time
60+
61+
return (total_time / count) * 1000
62+
63+
64+
def test_import_time(grace_percent: float = 0.20, hard_percent: float = 0.50) -> None:
65+
"""Test the import time of core modules to avoid regression in performance.
66+
67+
Args:
68+
grace_percentage (float): Maximum allowed percentage increase in import time
69+
before a warning is raised.
70+
hard_percentage (float): Maximum allowed percentage increase in import time
71+
before the test fails.
72+
"""
73+
for module_name, ref_time in REF_IMPORT_TIME.items():
74+
if ref_time is None:
75+
pytest.skip(f"No reference import time for {module_name}")
76+
77+
current_time = measure_import_time_in_ms(module_name)
78+
79+
# Calculate grace and hard thresholds
80+
grace_threshold = ref_time * (1 + grace_percent)
81+
hard_threshold = ref_time * (1 + hard_percent)
82+
83+
if current_time > grace_threshold:
84+
if current_time > hard_threshold:
85+
pytest.fail(f"{module_name} import too slow! {hard_threshold=:.2f} ms")
86+
else:
87+
pytest.warns(
88+
UserWarning,
89+
f"{module_name} import slightly slower: {grace_threshold=:.2f} ms",
90+
)

tests/test_coordination.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from __future__ import annotations
2+
13
import re
2-
from collections.abc import Sequence
3-
from typing import Any
4+
from typing import TYPE_CHECKING
45

56
import pytest
67
from pymatgen.analysis.local_env import CrystalNN, NearNeighbors, VoronoiNN
@@ -15,6 +16,11 @@
1516
)
1617

1718

19+
if TYPE_CHECKING:
20+
from collections.abc import Sequence
21+
from typing import Any
22+
23+
1824
def test_coordination_hist_single_structure(structures: Sequence[Structure]) -> None:
1925
"""Test coordination_hist with a single structure."""
2026
fig = coordination_hist(structures[0])

0 commit comments

Comments
 (0)