Skip to content

Speedup import and add regression check for import time #238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
b6bee81
pre-commit migrate-config
DanielYang59 Oct 20, 2024
5c8032d
avoid import Structure for type check
DanielYang59 Oct 20, 2024
212d573
lazily import scipy
DanielYang59 Oct 20, 2024
d4e25e7
avoid import Structure in utils
DanielYang59 Oct 20, 2024
865a4fa
copy helper func _check_type from monty
DanielYang59 Oct 20, 2024
82a31c4
lazy import plotly.figure_factory
DanielYang59 Oct 20, 2024
e60c770
lazy import NearNeighbors
DanielYang59 Oct 20, 2024
2777c0f
lazy import PhononDos and PhononBands
DanielYang59 Oct 20, 2024
13960b4
more lazy import Structure
DanielYang59 Oct 20, 2024
07e084c
clean up some duplicate imports
DanielYang59 Oct 20, 2024
5c01d25
lazy import sklearn
DanielYang59 Oct 20, 2024
5918aeb
lazy import pmg composition
DanielYang59 Oct 20, 2024
1161bf0
remove unused import from root __init__
DanielYang59 Oct 20, 2024
2b04381
relocate scikit learn import
DanielYang59 Oct 20, 2024
1059be9
revert hacky type check changes
DanielYang59 Oct 21, 2024
92a4821
bump monty the hard way
DanielYang59 Oct 23, 2024
23cbe32
WIP: add draft import time checker
DanielYang59 Oct 23, 2024
fe1e32d
add more test modules
DanielYang59 Oct 23, 2024
b019517
reduce default average count to 3, it seems very slow
DanielYang59 Oct 23, 2024
6bd0bb8
tweak gen ref time logic
DanielYang59 Oct 23, 2024
bf57cd0
update ref time
DanielYang59 Oct 23, 2024
acd25f9
use perf_counter over time()
DanielYang59 Oct 23, 2024
ed48b13
lazy import plotly.figure_factory, reduce 0.2s 10%
DanielYang59 Oct 23, 2024
2245483
update ref import time
DanielYang59 Oct 23, 2024
f64739a
use standard time format
DanielYang59 Oct 23, 2024
5143765
only run on main branch
DanielYang59 Oct 23, 2024
77f3f59
use warnings.warn
DanielYang59 Oct 26, 2024
b0006c8
use perf_counter_ns
DanielYang59 Oct 26, 2024
c337e54
rename measure_import_time_in_ms -> measure_import_time
janosh Nov 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ ci:
autoupdate_schedule: quarterly
skip: [pyright]

default_stages: [commit]
default_stages: [pre-commit]

default_install_hook_types: [pre-commit, commit-msg]

Expand Down Expand Up @@ -43,7 +43,7 @@ repos:
rev: v2.3.0
hooks:
- id: codespell
stages: [commit, commit-msg]
stages: [pre-commit, commit-msg]
exclude_types: [csv, svg, html, yaml, jupyter]
args: [--ignore-words-list, "hist,mape,te,nd,fpr", --check-filenames]

Expand Down
3 changes: 0 additions & 3 deletions pymatviz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@
import builtins
from importlib.metadata import PackageNotFoundError, version

import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

from pymatviz import (
bar,
Expand Down
21 changes: 14 additions & 7 deletions pymatviz/bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -12,16 +11,24 @@
import plotly.graph_objects as go
from matplotlib import transforms
from matplotlib.ticker import FixedLocator
from pymatgen.core import Structure
from pymatgen.symmetry.groups import SpaceGroup

from pymatviz.enums import Key
from pymatviz.utils import PLOTLY, Backend, crystal_sys_from_spg_num, si_fmt_int
from pymatviz.utils import (
PLOTLY,
Backend,
_check_type,
crystal_sys_from_spg_num,
si_fmt_int,
)


if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Any, Literal

from pymatgen.core import Structure


def spacegroup_bar(
data: Sequence[int | str | Structure] | pd.Series,
Expand Down Expand Up @@ -61,10 +68,10 @@ def spacegroup_bar(
Returns:
plt.Axes | go.Figure: matplotlib Axes or plotly Figure depending on backend.
"""
if isinstance(next(iter(data)), Structure):
# TODO: use this hacky type check to avoid expensive import of Structure, #209
if _check_type(next(iter(data)), "pymatgen.core.structure.Structure"):
# if 1st sequence item is structure, assume all are
data = cast(Sequence[Structure], data)
series = pd.Series(struct.get_space_group_info()[1] for struct in data)
series = pd.Series(struct.get_space_group_info()[1] for struct in data) # type: ignore[union-attr]
else:
series = pd.Series(data)

Expand Down
28 changes: 22 additions & 6 deletions pymatviz/coordination.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
"""Visualizations of coordination numbers distributions."""

from __future__ import annotations

import math
from collections import Counter
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from inspect import isclass
from typing import Any, Literal
from typing import TYPE_CHECKING

import numpy as np
import plotly.graph_objects as go
from plotly.colors import label_rgb
from plotly.subplots import make_subplots
from pymatgen.analysis.local_env import NearNeighbors
from pymatgen.core import PeriodicSite, Structure

from pymatviz.colors import ELEM_COLORS_JMOL, ELEM_COLORS_VESTA
from pymatviz.enums import ElemColorScheme, LabelEnum
from pymatviz.utils import normalize_to_dict
from pymatviz.utils import _check_type, normalize_to_dict


if TYPE_CHECKING:
from collections.abc import Callable
from typing import Any, Literal

from pymatgen.analysis.local_env import NearNeighbors
from pymatgen.core import PeriodicSite, Structure


class SplitMode(LabelEnum):
Expand Down Expand Up @@ -57,10 +65,14 @@ def normalize_get_neighbors(
# Prepare the neighbor-finding strategy
if isinstance(strategy, int | float):
return lambda site, structure: structure.get_neighbors(site, strategy)
if isinstance(strategy, NearNeighbors):

if _check_type(strategy, "pymatgen.analysis.local_env.NearNeighbors"):
return lambda site, structure: strategy.get_nn_info(
structure, structure.index(site)
)

from pymatgen.analysis.local_env import NearNeighbors # costly import

if isclass(strategy) and issubclass(strategy, NearNeighbors):
nn_instance = strategy()
return lambda site, structure: nn_instance.get_nn_info(
Expand Down Expand Up @@ -418,13 +430,16 @@ def coordination_vs_cutoff_line(
"""
structures = normalize_to_dict(structures)

from pymatgen.analysis.local_env import NearNeighbors

# Determine cutoff range based on strategy
if (
isinstance(strategy, tuple)
and len(strategy) == 2
and {*map(type, strategy)} <= {int, float}
):
cutoff_range = strategy

elif isinstance(strategy, NearNeighbors) or (
isclass(strategy) and issubclass(strategy, NearNeighbors)
):
Expand All @@ -436,6 +451,7 @@ def coordination_vs_cutoff_line(
else:
raise AttributeError(f"Could not determine cutoff for {nn_instance=}")
cutoff_range = (0, max_cutoff)

else:
raise TypeError(
f"Invalid {strategy=}. Expected float, tuple of floats, NearNeighbors "
Expand Down
13 changes: 9 additions & 4 deletions pymatviz/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
import plotly.graph_objects as go
import scipy.constants as const
from plotly.subplots import make_subplots
from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine as PhononBands
from pymatgen.phonon.dos import PhononDos
from pymatgen.util.string import htmlify


Expand All @@ -22,9 +19,11 @@

import numpy as np
from pymatgen.core import Structure
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine as PhononBands
from pymatgen.phonon.dos import PhononDos
from typing_extensions import Self

AnyBandStructure: TypeAlias = BandStructureSymmLine | PhononBands
# AnyBandStructure: TypeAlias = BandStructureSymmLine | PhononBands
YMin: TypeAlias = float | Literal["y_min"]
YMax: TypeAlias = float | Literal["y_max"]
BranchMode: TypeAlias = Literal["union", "intersection"]
Expand Down Expand Up @@ -201,6 +200,9 @@ def phonon_bands(
f"Invalid {branch_mode=}, must be one of {get_args(BranchMode)}"
)

# costly import
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine as PhononBands

if type(band_structs) not in {PhononBands, dict}:
cls_name = PhononBands.__name__
raise TypeError(
Expand Down Expand Up @@ -356,6 +358,9 @@ def phonon_dos(
if normalize not in valid_normalize:
raise ValueError(f"Invalid {normalize=}, must be one of {valid_normalize}.")

# costly import
from pymatgen.phonon.dos import PhononDos

if type(doses) not in {PhononDos, dict}:
raise TypeError(
f"Only {PhononDos.__name__} or dict supported, got {type(doses).__name__}"
Expand Down
8 changes: 5 additions & 3 deletions pymatviz/powerups/both.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import sklearn
from sklearn.metrics import mean_absolute_percentage_error as mape
from sklearn.metrics import r2_score

from pymatviz.utils import (
BACKENDS,
Expand Down Expand Up @@ -79,6 +76,11 @@ def annotate_metrics(

backend: Backend = PLOTLY if isinstance(fig, go.Figure) else MATPLOTLIB

# Lazily import costly scikit-learn
import sklearn
from sklearn.metrics import mean_absolute_percentage_error as mape
from sklearn.metrics import r2_score

funcs = {
"MAE": lambda x, y: np.abs(x - y).mean(),
"RMSE": lambda x, y: (((x - y) ** 2).mean()) ** 0.5,
Expand Down
4 changes: 3 additions & 1 deletion pymatviz/process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import pandas as pd
from pandas.api.types import is_numeric_dtype, is_string_dtype
from pymatgen.core import Composition

from pymatviz.enums import ElemCountMode, Key
from pymatviz.utils import ElemValues, df_ptable
Expand Down Expand Up @@ -63,8 +62,11 @@ def count_elements(
# Ensure values is Series if we got dict/list/tuple
srs = pd.Series(values)

from pymatgen.core import Composition # costly import

if is_numeric_dtype(srs):
pass

elif is_string_dtype(srs) or {*map(type, srs)} <= {str, Composition}:
# all items are formula strings or Composition objects
if count_mode == "occurrence":
Expand Down
3 changes: 2 additions & 1 deletion pymatviz/ptable/ptable_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff

from pymatviz.enums import ElemCountMode
from pymatviz.process_data import count_elements
Expand Down Expand Up @@ -272,6 +271,8 @@ def ptable_heatmap_plotly(
zmax = max(non_nan_values) if cscale_range[1] is None else cscale_range[1]
car_multiplier = 100 if heat_mode == "percent" else 1

import plotly.figure_factory as ff # costly import

fig = ff.create_annotated_heatmap(
car_multiplier * heatmap_values,
annotation_text=tile_texts,
Expand Down
5 changes: 4 additions & 1 deletion pymatviz/relevance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import sklearn.metrics as skm

from pymatviz.utils import df_to_arrays

Expand Down Expand Up @@ -37,6 +36,8 @@ def roc_curve(
ax = ax or plt.gca()

# get the metrics
import sklearn.metrics as skm

false_pos_rate, true_pos_rate, _ = skm.roc_curve(targets, proba_pos)
roc_auc = skm.roc_auc_score(targets, proba_pos)

Expand Down Expand Up @@ -71,6 +72,8 @@ def precision_recall_curve(
ax = ax or plt.gca()

# get the metrics
import sklearn.metrics as skm

precision, recall, _ = skm.precision_recall_curve(targets, proba_pos)

# proba_pos.round() converts class probabilities to integer class labels
Expand Down
5 changes: 1 addition & 4 deletions pymatviz/structure_viz/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,10 @@
import itertools
import math
import warnings
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from pymatgen.analysis.local_env import NearNeighbors
from pymatgen.core import Composition, Lattice, PeriodicSite, Species, Structure

from pymatviz.colors import ELEM_COLORS_JMOL, ELEM_COLORS_VESTA
Expand Down
3 changes: 1 addition & 2 deletions pymatviz/structure_viz/mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@

import math
import warnings
from collections.abc import Callable, Sequence
from itertools import product
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
Expand Down
4 changes: 1 addition & 3 deletions pymatviz/structure_viz/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@

import math
import warnings
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pymatgen.analysis.local_env import CrystalNN, NearNeighbors
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
Expand Down
34 changes: 30 additions & 4 deletions pymatviz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
import scipy.stats
from matplotlib.colors import to_rgb
from matplotlib.offsetbox import AnchoredText
from matplotlib.ticker import FormatStrFormatter, PercentFormatter, ScalarFormatter
from pymatgen.core import Structure


if TYPE_CHECKING:
Expand Down Expand Up @@ -235,6 +233,8 @@ def bin_df_cols(
)

if density_col:
import scipy.stats # expensive import

# compute kernel density estimate for each bin
values = df_in[bin_by_cols].dropna().T
gaussian_kde = scipy.stats.gaussian_kde(values.astype(float))
Expand Down Expand Up @@ -679,7 +679,7 @@ def _get_matplotlib_font_color(fig: plt.Figure | plt.Axes) -> str:

def normalize_to_dict(
inputs: T | Sequence[T] | dict[str, T],
cls: type[T] = Structure,
cls: type[T] | None = None,
key_gen: Callable[[T], str] = lambda obj: getattr(
obj, "formula", type(obj).__name__
),
Expand All @@ -699,8 +699,14 @@ def normalize_to_dict(
Raises:
TypeError: If the input format is invalid.
"""
if cls is None:
from pymatgen.core import Structure # costly import

cls = Structure

if isinstance(inputs, cls):
return {"": inputs}
return {"": inputs} # type: ignore[dict-item]

if (
isinstance(inputs, list | tuple)
and all(isinstance(obj, cls) for obj in inputs)
Expand All @@ -722,3 +728,23 @@ def normalize_to_dict(
raise TypeError(
f"Invalid {inputs=}, expected {cls_name} or dict/list/tuple of {cls_name}"
)


def _check_type(obj: object, type_str: tuple[str, ...] | str) -> bool:
"""Alternative to isinstance that avoids imports.

Todo:
Taken from monty.json, use until monty.json import fix merged.

Note for future developers: the type_str is not always obvious for an
object. For example, pandas.DataFrame is actually pandas.core.frame.DataFrame.
To find out the type_str for an object, run type(obj).mro(). This will
list all the types that an object can resolve to in order of generality
(all objects have the builtins.object as the last one).
"""
type_str = type_str if isinstance(type_str, tuple) else (type_str,)
try:
mro = type(obj).mro()
except TypeError:
return False
return any(f"{o.__module__}.{o.__name__}" == ts for o in mro for ts in type_str)
Loading