Skip to content

Commit 3147dc5

Browse files
committed
revert hacky type check changes
1 parent 2b04381 commit 3147dc5

File tree

9 files changed

+27
-74
lines changed

9 files changed

+27
-74
lines changed

pymatviz/bar.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING
5+
from collections.abc import Sequence
6+
from typing import TYPE_CHECKING, cast
67

78
import matplotlib.pyplot as plt
89
import numpy as np
@@ -11,24 +12,16 @@
1112
import plotly.graph_objects as go
1213
from matplotlib import transforms
1314
from matplotlib.ticker import FixedLocator
15+
from pymatgen.core.structure import Structure
1416
from pymatgen.symmetry.groups import SpaceGroup
1517

1618
from pymatviz.enums import Key
17-
from pymatviz.utils import (
18-
PLOTLY,
19-
Backend,
20-
_check_type,
21-
crystal_sys_from_spg_num,
22-
si_fmt_int,
23-
)
19+
from pymatviz.utils import PLOTLY, Backend, crystal_sys_from_spg_num, si_fmt_int
2420

2521

2622
if TYPE_CHECKING:
27-
from collections.abc import Sequence
2823
from typing import Any, Literal
2924

30-
from pymatgen.core import Structure
31-
3225

3326
def spacegroup_bar(
3427
data: Sequence[int | str | Structure] | pd.Series,
@@ -68,10 +61,10 @@ def spacegroup_bar(
6861
Returns:
6962
plt.Axes | go.Figure: matplotlib Axes or plotly Figure depending on backend.
7063
"""
71-
# TODO: use this hacky type check to avoid expensive import of Structure, #209
72-
if _check_type(next(iter(data)), "pymatgen.core.structure.Structure"):
64+
if isinstance(next(iter(data)), Structure):
7365
# if 1st sequence item is structure, assume all are
74-
series = pd.Series(struct.get_space_group_info()[1] for struct in data) # type: ignore[union-attr]
66+
data = cast(Sequence[Structure], data)
67+
series = pd.Series(struct.get_space_group_info()[1] for struct in data)
7568
else:
7669
series = pd.Series(data)
7770

pymatviz/coordination.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@
1212
import plotly.graph_objects as go
1313
from plotly.colors import label_rgb
1414
from plotly.subplots import make_subplots
15+
from pymatgen.analysis.local_env import NearNeighbors
1516

1617
from pymatviz.colors import ELEM_COLORS_JMOL, ELEM_COLORS_VESTA
1718
from pymatviz.enums import ElemColorScheme, LabelEnum
18-
from pymatviz.utils import _check_type, normalize_to_dict
19+
from pymatviz.utils import normalize_to_dict
1920

2021

2122
if TYPE_CHECKING:
2223
from collections.abc import Callable
2324
from typing import Any, Literal
2425

25-
from pymatgen.analysis.local_env import NearNeighbors
2626
from pymatgen.core import PeriodicSite, Structure
2727

2828

@@ -66,13 +66,11 @@ def normalize_get_neighbors(
6666
if isinstance(strategy, int | float):
6767
return lambda site, structure: structure.get_neighbors(site, strategy)
6868

69-
if _check_type(strategy, "pymatgen.analysis.local_env.NearNeighbors"):
69+
if isinstance(strategy, NearNeighbors):
7070
return lambda site, structure: strategy.get_nn_info(
7171
structure, structure.index(site)
7272
)
7373

74-
from pymatgen.analysis.local_env import NearNeighbors # costly import
75-
7674
if isclass(strategy) and issubclass(strategy, NearNeighbors):
7775
nn_instance = strategy()
7876
return lambda site, structure: nn_instance.get_nn_info(
@@ -430,8 +428,6 @@ def coordination_vs_cutoff_line(
430428
"""
431429
structures = normalize_to_dict(structures)
432430

433-
from pymatgen.analysis.local_env import NearNeighbors
434-
435431
# Determine cutoff range based on strategy
436432
if (
437433
isinstance(strategy, tuple)

pymatviz/phonons.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import plotly.graph_objects as go
1111
import scipy.constants as const
1212
from plotly.subplots import make_subplots
13+
from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine
14+
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine as PhononBands
1315
from pymatgen.util.string import htmlify
1416

1517

@@ -19,11 +21,10 @@
1921

2022
import numpy as np
2123
from pymatgen.core import Structure
22-
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine as PhononBands
2324
from pymatgen.phonon.dos import PhononDos
2425
from typing_extensions import Self
2526

26-
# AnyBandStructure: TypeAlias = BandStructureSymmLine | PhononBands
27+
AnyBandStructure: TypeAlias = BandStructureSymmLine | PhononBands
2728
YMin: TypeAlias = float | Literal["y_min"]
2829
YMax: TypeAlias = float | Literal["y_max"]
2930
BranchMode: TypeAlias = Literal["union", "intersection"]

pymatviz/powerups/both.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
import matplotlib.pyplot as plt
99
import numpy as np
1010
import plotly.graph_objects as go
11+
import sklearn
12+
from sklearn.metrics import mean_absolute_percentage_error as mape
13+
from sklearn.metrics import r2_score
1114

1215
from pymatviz.utils import (
1316
BACKENDS,
@@ -76,11 +79,6 @@ def annotate_metrics(
7679

7780
backend: Backend = PLOTLY if isinstance(fig, go.Figure) else MATPLOTLIB
7881

79-
# Lazily import costly scikit-learn
80-
import sklearn
81-
from sklearn.metrics import mean_absolute_percentage_error as mape
82-
from sklearn.metrics import r2_score
83-
8482
funcs = {
8583
"MAE": lambda x, y: np.abs(x - y).mean(),
8684
"RMSE": lambda x, y: (((x - y) ** 2).mean()) ** 0.5,

pymatviz/process_data.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import pandas as pd
99
from pandas.api.types import is_numeric_dtype, is_string_dtype
10+
from pymatgen.core import Composition
1011

1112
from pymatviz.enums import ElemCountMode, Key
1213
from pymatviz.utils import ElemValues, df_ptable
@@ -62,8 +63,6 @@ def count_elements(
6263
# Ensure values is Series if we got dict/list/tuple
6364
srs = pd.Series(values)
6465

65-
from pymatgen.core import Composition # costly import
66-
6766
if is_numeric_dtype(srs):
6867
pass
6968

pymatviz/ptable/ptable_plotly.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import pandas as pd
1010
import plotly.express as px
11+
import plotly.figure_factory as ff
1112

1213
from pymatviz.enums import ElemCountMode
1314
from pymatviz.process_data import count_elements
@@ -271,8 +272,6 @@ def ptable_heatmap_plotly(
271272
zmax = max(non_nan_values) if cscale_range[1] is None else cscale_range[1]
272273
car_multiplier = 100 if heat_mode == "percent" else 1
273274

274-
import plotly.figure_factory as ff # costly import
275-
276275
fig = ff.create_annotated_heatmap(
277276
car_multiplier * heatmap_values,
278277
annotation_text=tile_texts,

pymatviz/relevance.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import TYPE_CHECKING
66

77
import matplotlib.pyplot as plt
8+
import sklearn.metrics as skm
89

910
from pymatviz.utils import df_to_arrays
1011

@@ -36,8 +37,6 @@ def roc_curve(
3637
ax = ax or plt.gca()
3738

3839
# get the metrics
39-
import sklearn.metrics as skm
40-
4140
false_pos_rate, true_pos_rate, _ = skm.roc_curve(targets, proba_pos)
4241
roc_auc = skm.roc_auc_score(targets, proba_pos)
4342

@@ -72,8 +71,6 @@ def precision_recall_curve(
7271
ax = ax or plt.gca()
7372

7473
# get the metrics
75-
import sklearn.metrics as skm
76-
7774
precision, recall, _ = skm.precision_recall_curve(targets, proba_pos)
7875

7976
# proba_pos.round() converts class probabilities to integer class labels

pymatviz/utils.py

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
import pandas as pd
1717
import plotly.graph_objects as go
1818
import plotly.io as pio
19+
import scipy.stats
1920
from matplotlib.colors import to_rgb
2021
from matplotlib.offsetbox import AnchoredText
2122
from matplotlib.ticker import FormatStrFormatter, PercentFormatter, ScalarFormatter
23+
from pymatgen.core.structure import Structure
2224

2325

2426
if TYPE_CHECKING:
@@ -233,8 +235,6 @@ def bin_df_cols(
233235
)
234236

235237
if density_col:
236-
import scipy.stats # expensive import
237-
238238
# compute kernel density estimate for each bin
239239
values = df_in[bin_by_cols].dropna().T
240240
gaussian_kde = scipy.stats.gaussian_kde(values.astype(float))
@@ -679,7 +679,7 @@ def _get_matplotlib_font_color(fig: plt.Figure | plt.Axes) -> str:
679679

680680
def normalize_to_dict(
681681
inputs: T | Sequence[T] | dict[str, T],
682-
cls: type[T] | None = None,
682+
cls: type[T] = Structure,
683683
key_gen: Callable[[T], str] = lambda obj: getattr(
684684
obj, "formula", type(obj).__name__
685685
),
@@ -699,13 +699,8 @@ def normalize_to_dict(
699699
Raises:
700700
TypeError: If the input format is invalid.
701701
"""
702-
if cls is None:
703-
from pymatgen.core import Structure # costly import
704-
705-
cls = Structure
706-
707702
if isinstance(inputs, cls):
708-
return {"": inputs} # type: ignore[dict-item]
703+
return {"": inputs}
709704

710705
if (
711706
isinstance(inputs, list | tuple)
@@ -728,23 +723,3 @@ def normalize_to_dict(
728723
raise TypeError(
729724
f"Invalid {inputs=}, expected {cls_name} or dict/list/tuple of {cls_name}"
730725
)
731-
732-
733-
def _check_type(obj: object, type_str: tuple[str, ...] | str) -> bool:
734-
"""Alternative to isinstance that avoids imports.
735-
736-
Todo:
737-
Taken from monty.json, use until monty.json import fix merged.
738-
739-
Note for future developers: the type_str is not always obvious for an
740-
object. For example, pandas.DataFrame is actually pandas.core.frame.DataFrame.
741-
To find out the type_str for an object, run type(obj).mro(). This will
742-
list all the types that an object can resolve to in order of generality
743-
(all objects have the builtins.object as the last one).
744-
"""
745-
type_str = type_str if isinstance(type_str, tuple) else (type_str,)
746-
try:
747-
mro = type(obj).mro()
748-
except TypeError:
749-
return False
750-
return any(f"{o.__module__}.{o.__name__}" == ts for o in mro for ts in type_str)

pymatviz/xrd.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,14 @@
88
import plotly.graph_objects as go
99
from plotly.subplots import make_subplots
1010
from pymatgen.analysis.diffraction.xrd import DiffractionPattern, XRDCalculator
11-
12-
from pymatviz.utils import _check_type
11+
from pymatgen.core import Structure
1312

1413

1514
if TYPE_CHECKING:
1615
from typing import Any, TypeAlias
1716

18-
from pymatgen.core import Structure
19-
20-
PatternOrStruct: TypeAlias = DiffractionPattern | Structure
2117

18+
PatternOrStruct: TypeAlias = DiffractionPattern | Structure
2219
HklFormat: TypeAlias = Literal["compact", "full", None]
2320
ValidHklFormats = HklCompact, HklFull, HklNone = get_args(HklFormat)
2421

@@ -101,9 +98,7 @@ def xrd_pattern( # noqa: D417
10198
)
10299

103100
# Convert single object to dict for uniform processing
104-
if isinstance(patterns, DiffractionPattern) or _check_type(
105-
patterns, "pymatgen.core.structure.Structure"
106-
):
101+
if isinstance(patterns, DiffractionPattern | Structure):
107102
patterns = {"XRD Pattern": patterns}
108103
elif not isinstance(patterns, dict):
109104
raise TypeError(
@@ -141,7 +136,7 @@ def xrd_pattern( # noqa: D417
141136
else:
142137
pattern_or_struct, trace_kwargs = pattern_data, {}
143138

144-
if _check_type(pattern_or_struct, "pymatgen.core.structure.Structure"):
139+
if isinstance(pattern_or_struct, Structure):
145140
xrd_calculator = XRDCalculator(wavelength=wavelength)
146141
diffraction_pattern = xrd_calculator.get_pattern(pattern_or_struct)
147142
elif isinstance(pattern_or_struct, DiffractionPattern):

0 commit comments

Comments
 (0)