Skip to content

Speed up import and add regression check for import time #238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
b6bee81
pre-commit migrate-config
DanielYang59 Oct 20, 2024
5c8032d
avoid import Structure for type check
DanielYang59 Oct 20, 2024
212d573
lazily import scipy
DanielYang59 Oct 20, 2024
d4e25e7
avoid import Structure in utils
DanielYang59 Oct 20, 2024
865a4fa
copy helper func _check_type from monty
DanielYang59 Oct 20, 2024
82a31c4
lazy import plotly.figure_factory
DanielYang59 Oct 20, 2024
e60c770
lazy import NearNeighbors
DanielYang59 Oct 20, 2024
2777c0f
lazy import PhononDos and PhononBands
DanielYang59 Oct 20, 2024
13960b4
more lazy import Structure
DanielYang59 Oct 20, 2024
07e084c
clean up some duplicate imports
DanielYang59 Oct 20, 2024
5c01d25
lazy import sklearn
DanielYang59 Oct 20, 2024
5918aeb
lazy import pmg composition
DanielYang59 Oct 20, 2024
1161bf0
remove unused import from root __init__
DanielYang59 Oct 20, 2024
2b04381
relocate scikit learn import
DanielYang59 Oct 20, 2024
1059be9
revert hacky type check changes
DanielYang59 Oct 21, 2024
92a4821
bump monty the hard way
DanielYang59 Oct 23, 2024
23cbe32
WIP: add draft import time checker
DanielYang59 Oct 23, 2024
fe1e32d
add more test modules
DanielYang59 Oct 23, 2024
b019517
reduce default average count to 3, as the check seems very slow
DanielYang59 Oct 23, 2024
6bd0bb8
tweak gen ref time logic
DanielYang59 Oct 23, 2024
bf57cd0
update ref time
DanielYang59 Oct 23, 2024
acd25f9
use perf_counter over time()
DanielYang59 Oct 23, 2024
ed48b13
lazy import plotly.figure_factory, reducing import time by 0.2s (~10%)
DanielYang59 Oct 23, 2024
2245483
update ref import time
DanielYang59 Oct 23, 2024
f64739a
use standard time format
DanielYang59 Oct 23, 2024
5143765
only run on main branch
DanielYang59 Oct 23, 2024
77f3f59
use warnings.warn
DanielYang59 Oct 26, 2024
b0006c8
use perf_counter_ns
DanielYang59 Oct 26, 2024
c337e54
rename measure_import_time_in_ms -> measure_import_time
janosh Nov 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ ci:
autoupdate_schedule: quarterly
skip: [pyright]

default_stages: [commit]
default_stages: [pre-commit]

default_install_hook_types: [pre-commit, commit-msg]

Expand Down Expand Up @@ -43,7 +43,7 @@ repos:
rev: v2.3.0
hooks:
- id: codespell
stages: [commit, commit-msg]
stages: [pre-commit, commit-msg]
exclude_types: [csv, svg, html, yaml, jupyter]
args: [--ignore-words-list, "hist,mape,te,nd,fpr", --check-filenames]

Expand Down
17 changes: 11 additions & 6 deletions pymatviz/bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -12,16 +11,18 @@
import plotly.graph_objects as go
from matplotlib import transforms
from matplotlib.ticker import FixedLocator
from pymatgen.core import Structure
from pymatgen.symmetry.groups import SpaceGroup

from pymatviz.enums import Key
from pymatviz.utils import PLOTLY, Backend, crystal_sys_from_spg_num, si_fmt_int


if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Any, Literal

from pymatgen.core import Structure


def spacegroup_bar(
data: Sequence[int | str | Structure] | pd.Series,
Expand Down Expand Up @@ -61,10 +62,14 @@ def spacegroup_bar(
Returns:
plt.Axes | go.Figure: matplotlib Axes or plotly Figure depending on backend.
"""
if isinstance(next(iter(data)), Structure):
# TODO: use this hacky type check to avoid expensive import of Structure, #209
obj = next(iter(data))
if (
obj.__class__.__module__ == "pymatgen.core.structure"
and obj.__class__.__qualname__ == "Structure"
):
# if 1st sequence item is structure, assume all are
data = cast(Sequence[Structure], data)
series = pd.Series(struct.get_space_group_info()[1] for struct in data)
series = pd.Series(struct.get_space_group_info()[1] for struct in data) # type: ignore[union-attr]
else:
series = pd.Series(data)

Expand Down
14 changes: 10 additions & 4 deletions pymatviz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
import scipy.stats
from matplotlib.colors import to_rgb
from matplotlib.offsetbox import AnchoredText
from matplotlib.ticker import FormatStrFormatter, PercentFormatter, ScalarFormatter
from pymatgen.core import Structure


if TYPE_CHECKING:
Expand Down Expand Up @@ -235,6 +233,8 @@ def bin_df_cols(
)

if density_col:
import scipy.stats # expensive import

# compute kernel density estimate for each bin
values = df_in[bin_by_cols].dropna().T
gaussian_kde = scipy.stats.gaussian_kde(values.astype(float))
Expand Down Expand Up @@ -679,7 +679,7 @@ def _get_matplotlib_font_color(fig: plt.Figure | plt.Axes) -> str:

def normalize_to_dict(
inputs: T | Sequence[T] | dict[str, T],
cls: type[T] = Structure,
cls: type[T] | None = None,
key_gen: Callable[[T], str] = lambda obj: getattr(
obj, "formula", type(obj).__name__
),
Expand All @@ -699,8 +699,14 @@ def normalize_to_dict(
Raises:
TypeError: If the input format is invalid.
"""
if cls is None:
from pymatgen.core import Structure # costly import

cls = Structure

if isinstance(inputs, cls):
return {"": inputs}
return {"": inputs} # type: ignore[dict-item]

if (
isinstance(inputs, list | tuple)
and all(isinstance(obj, cls) for obj in inputs)
Expand Down