Skip to content

Commit 81809eb

Browse files
authored
Self return type on from_dict methods (#3702)
def from_dict(cls, dct: dict) -> Self:
1 parent 4b41edc commit 81809eb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+620
-517
lines changed

dev_scripts/update_pt_data.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -128,16 +128,16 @@ def parse_radii():
128128
def update_ionic_radii():
129129
data = loadfn(ptable_yaml_path)
130130

131-
for d in data.values():
132-
if "Ionic_radii" in d:
133-
d["Ionic radii"] = {k: v / 100 for k, v in d["Ionic_radii"].items()}
134-
del d["Ionic_radii"]
135-
if "Ionic_radii_hs" in d:
136-
d["Ionic radii hs"] = {k: v / 100 for k, v in d["Ionic_radii_hs"].items()}
137-
del d["Ionic_radii_hs"]
138-
if "Ionic_radii_ls" in d:
139-
d["Ionic radii ls"] = {k: v / 100 for k, v in d["Ionic_radii_ls"].items()}
140-
del d["Ionic_radii_ls"]
131+
for dct in data.values():
132+
if "Ionic_radii" in dct:
133+
dct["Ionic radii"] = {k: v / 100 for k, v in dct["Ionic_radii"].items()}
134+
del dct["Ionic_radii"]
135+
if "Ionic_radii_hs" in dct:
136+
dct["Ionic radii hs"] = {k: v / 100 for k, v in dct["Ionic_radii_hs"].items()}
137+
del dct["Ionic_radii_hs"]
138+
if "Ionic_radii_ls" in dct:
139+
dct["Ionic radii ls"] = {k: v / 100 for k, v in dct["Ionic_radii_ls"].items()}
140+
del dct["Ionic_radii_ls"]
141141
with open("periodic_table2.yaml", mode="w") as file:
142142
yaml.dump(data, file)
143143
with open("../pymatgen/core/periodic_table.json", mode="w") as file:

pymatgen/alchemy/filters.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
1414

1515
if TYPE_CHECKING:
16+
from typing_extensions import Self
17+
1618
from pymatgen.core import Structure
1719

1820

@@ -106,7 +108,7 @@ def as_dict(self):
106108
}
107109

108110
@classmethod
109-
def from_dict(cls, dct):
111+
def from_dict(cls, dct: dict) -> Self:
110112
"""
111113
Args:
112114
dct (dict): Dict representation.
@@ -165,7 +167,7 @@ def as_dict(self):
165167
}
166168

167169
@classmethod
168-
def from_dict(cls, dct):
170+
def from_dict(cls, dct: dict) -> Self:
169171
"""
170172
Args:
171173
dct (dict): Dict representation.

pymatgen/analysis/chemenv/connectivity/connected_components.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import itertools
66
import logging
7+
from typing import TYPE_CHECKING
78

89
import matplotlib.pyplot as plt
910
import networkx as nx
@@ -18,6 +19,9 @@
1819
from pymatgen.analysis.chemenv.utils.graph_utils import get_delta
1920
from pymatgen.analysis.chemenv.utils.math_utils import get_linearly_independent_vectors
2021

22+
if TYPE_CHECKING:
23+
from typing_extensions import Self
24+
2125

2226
def draw_network(env_graph, pos, ax, sg=None, periodicity_vectors=None):
2327
"""Draw network of environments in a matplotlib figure axes.
@@ -827,29 +831,29 @@ def as_dict(self):
827831
}
828832

829833
@classmethod
830-
def from_dict(cls, d):
834+
def from_dict(cls, dct: dict) -> Self:
831835
"""
832836
Reconstructs the ConnectedComponent object from a dict representation of the
833837
ConnectedComponent object created using the as_dict method.
834838
835839
Args:
836-
d (dict): dict representation of the ConnectedComponent object
840+
dct (dict): dict representation of the ConnectedComponent object
837841
838842
Returns:
839843
ConnectedComponent: The connected component representing the links of a given set of environments.
840844
"""
841845
nodes_map = {
842-
inode_str: EnvironmentNode.from_dict(nodedict) for inode_str, (nodedict, nodedata) in d["nodes"].items()
846+
inode_str: EnvironmentNode.from_dict(nodedict) for inode_str, (nodedict, nodedata) in dct["nodes"].items()
843847
}
844-
nodes_data = {inode_str: nodedata for inode_str, (nodedict, nodedata) in d["nodes"].items()}
845-
dod = {}
846-
for e1, e1dict in d["graph"].items():
847-
dod[e1] = {}
848+
nodes_data = {inode_str: nodedata for inode_str, (nodedict, nodedata) in dct["nodes"].items()}
849+
nested_dict: dict[str, dict] = {}
850+
for e1, e1dict in dct["graph"].items():
851+
nested_dict[e1] = {}
848852
for e2, e2dict in e1dict.items():
849-
dod[e1][e2] = {
853+
nested_dict[e1][e2] = {
850854
cls._edgedictkey_to_edgekey(ied): cls._retuplify_edgedata(edata) for ied, edata in e2dict.items()
851855
}
852-
graph = nx.from_dict_of_dicts(dod, create_using=nx.MultiGraph, multigraph_input=True)
856+
graph = nx.from_dict_of_dicts(nested_dict, create_using=nx.MultiGraph, multigraph_input=True)
853857
nx.set_node_attributes(graph, nodes_data)
854858
nx.relabel_nodes(graph, nodes_map, copy=False)
855859
return cls(graph=graph)

pymatgen/analysis/chemenv/connectivity/structure_connectivity.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import collections
66
import logging
7+
from typing import TYPE_CHECKING
78

89
import networkx as nx
910
import numpy as np
@@ -13,6 +14,9 @@
1314
from pymatgen.analysis.chemenv.connectivity.environment_nodes import get_environment_node
1415
from pymatgen.analysis.chemenv.coordination_environments.structure_environments import LightStructureEnvironments
1516

17+
if TYPE_CHECKING:
18+
from typing_extensions import Self
19+
1620
__author__ = "David Waroquiers"
1721
__copyright__ = "Copyright 2012, The Materials Project"
1822
__credits__ = "Geoffroy Hautier"
@@ -299,26 +303,30 @@ def as_dict(self):
299303
}
300304

301305
@classmethod
302-
def from_dict(cls, d):
306+
def from_dict(cls, dct: dict) -> Self:
303307
"""
304308
Args:
305-
d ():
309+
dct (dict):
306310
307311
Returns:
308312
StructureConnectivity
309313
"""
310314
# Reconstructs the graph with integer as nodes (json's as_dict replaces integer keys with str keys)
311-
cgraph = nx.from_dict_of_dicts(d["connectivity_graph"], create_using=nx.MultiGraph, multigraph_input=True)
312-
cgraph = nx.relabel_nodes(cgraph, int) # Just relabel the nodes using integer casting (maps str->int)
315+
connect_graph = nx.from_dict_of_dicts(
316+
dct["connectivity_graph"], create_using=nx.MultiGraph, multigraph_input=True
317+
)
318+
connect_graph = nx.relabel_nodes(
319+
connect_graph, int
320+
) # Just relabel the nodes using integer casting (maps str->int)
313321
# Relabel multi-edges (removes multi-edges with str keys and adds them back with int keys)
314-
edges = set(cgraph.edges())
322+
edges = set(connect_graph.edges())
315323
for n1, n2 in edges:
316-
new_edges = {int(iedge): edata for iedge, edata in cgraph[n1][n2].items()}
317-
cgraph.remove_edges_from([(n1, n2, iedge) for iedge, edata in cgraph[n1][n2].items()])
318-
cgraph.add_edges_from([(n1, n2, iedge, edata) for iedge, edata in new_edges.items()])
324+
new_edges = {int(iedge): edata for iedge, edata in connect_graph[n1][n2].items()}
325+
connect_graph.remove_edges_from([(n1, n2, iedge) for iedge, edata in connect_graph[n1][n2].items()])
326+
connect_graph.add_edges_from([(n1, n2, iedge, edata) for iedge, edata in new_edges.items()])
319327
return cls(
320-
LightStructureEnvironments.from_dict(d["light_structure_environments"]),
321-
connectivity_graph=cgraph,
328+
LightStructureEnvironments.from_dict(dct["light_structure_environments"]),
329+
connectivity_graph=connect_graph,
322330
environment_subgraphs=None,
323331
)
324332
# TODO: also deserialize the environment_subgraphs

pymatgen/analysis/chemenv/coordination_environments/chemenv_strategies.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,13 @@ def as_dict(self):
8282
}
8383

8484
@classmethod
85-
def from_dict(cls, d: dict) -> Self:
85+
def from_dict(cls, dct: dict) -> Self:
8686
"""Initialize distance cutoff from dict.
8787
8888
Args:
89-
d: Dict representation of the distance cutoff.
89+
dct (dict): Dict representation of the distance cutoff.
9090
"""
91-
return cls(d["value"])
91+
return cls(dct["value"])
9292

9393

9494
class AngleCutoffFloat(float, StrategyOption):
@@ -116,13 +116,13 @@ def as_dict(self):
116116
}
117117

118118
@classmethod
119-
def from_dict(cls, d):
119+
def from_dict(cls, dct: dict) -> Self:
120120
"""Initialize angle cutoff from dict.
121121
122122
Args:
123-
d: Dict representation of the angle cutoff.
123+
dct (dict): Dict representation of the angle cutoff.
124124
"""
125-
return cls(d["value"])
125+
return cls(dct["value"])
126126

127127

128128
class CSMFloat(float, StrategyOption):
@@ -154,7 +154,7 @@ def from_dict(cls, dct: dict) -> Self:
154154
"""Initialize CSM from dict.
155155
156156
Args:
157-
d: Dict representation of the CSM.
157+
dct (dict): Dict representation of the CSM.
158158
"""
159159
return cls(dct["value"])
160160

@@ -188,7 +188,7 @@ def from_dict(cls, dct: dict) -> Self:
188188
"""Initialize additional condition from dict.
189189
190190
Args:
191-
d: Dict representation of the additional condition.
191+
dct (dict): Dict representation of the additional condition.
192192
"""
193193
return cls(dct["value"])
194194

@@ -1350,7 +1350,7 @@ def as_dict(self):
13501350
}
13511351

13521352
@classmethod
1353-
def from_dict(cls, dct):
1353+
def from_dict(cls, dct: dict) -> Self:
13541354
"""Construct AngleNbSetWeight from dict representation."""
13551355
return cls(aa=dct["aa"])
13561356

@@ -1418,7 +1418,7 @@ def from_dict(cls, dct: dict) -> Self:
14181418
"""Initialize from dict.
14191419
14201420
Args:
1421-
dct: Dict representation of NormalizedAngleDistanceNbSetWeight.
1421+
dct (dict): Dict representation of NormalizedAngleDistanceNbSetWeight.
14221422
14231423
Returns:
14241424
NormalizedAngleDistanceNbSetWeight.
@@ -1722,7 +1722,7 @@ def from_dict(cls, dct: dict) -> Self:
17221722
"""Initialize from dict.
17231723
17241724
Args:
1725-
dct: Dict representation of SelfCSMNbSetWeight.
1725+
dct (dict): Dict representation of SelfCSMNbSetWeight.
17261726
17271727
Returns:
17281728
SelfCSMNbSetWeight.
@@ -1960,7 +1960,7 @@ def from_dict(cls, dct: dict) -> Self:
19601960
"""Initialize from dict.
19611961
19621962
Args:
1963-
dct: Dict representation of DeltaCSMNbSetWeight.
1963+
dct (dict): Dict representation of DeltaCSMNbSetWeight.
19641964
19651965
Returns:
19661966
DeltaCSMNbSetWeight.
@@ -2026,7 +2026,7 @@ def from_dict(cls, dct: dict) -> Self:
20262026
"""Initialize from dict.
20272027
20282028
Args:
2029-
dct: Dict representation of CNBiasNbSetWeight.
2029+
dct (dict): Dict representation of CNBiasNbSetWeight.
20302030
20312031
Returns:
20322032
CNBiasNbSetWeight.
@@ -2096,7 +2096,7 @@ def from_description(cls, dct: dict) -> Self:
20962096
"""Initialize weights from description.
20972097
20982098
Args:
2099-
dct: Dictionary description.
2099+
dct (dict): Dictionary description.
21002100
21012101
Returns:
21022102
CNBiasNbSetWeight.
@@ -2334,7 +2334,7 @@ def from_dict(cls, dct: dict) -> Self:
23342334
"""Initialize from dict.
23352335
23362336
Args:
2337-
dct: Dict representation of DistanceAngleAreaNbSetWeight.
2337+
dct (dict): Dict representation of DistanceAngleAreaNbSetWeight.
23382338
23392339
Returns:
23402340
DistanceAngleAreaNbSetWeight.
@@ -2404,7 +2404,7 @@ def from_dict(cls, dct: dict) -> Self:
24042404
"""Initialize from dict.
24052405
24062406
Args:
2407-
dct: Dict representation of DistancePlateauNbSetWeight.
2407+
dct (dict): Dict representation of DistancePlateauNbSetWeight.
24082408
24092409
Returns:
24102410
DistancePlateauNbSetWeight.
@@ -2471,7 +2471,7 @@ def from_dict(cls, dct: dict) -> Self:
24712471
"""Initialize from dict.
24722472
24732473
Args:
2474-
dct: Dict representation of AnglePlateauNbSetWeight.
2474+
dct (dict): Dict representation of AnglePlateauNbSetWeight.
24752475
24762476
Returns:
24772477
AnglePlateauNbSetWeight.
@@ -2552,7 +2552,7 @@ def from_dict(cls, dct: dict) -> Self:
25522552
"""Initialize from dict.
25532553
25542554
Args:
2555-
dct: Dict representation of DistanceNbSetWeight.
2555+
dct (dict): Dict representation of DistanceNbSetWeight.
25562556
25572557
Returns:
25582558
DistanceNbSetWeight.
@@ -2636,7 +2636,7 @@ def from_dict(cls, dct: dict) -> Self:
26362636
"""Initialize from dict.
26372637
26382638
Args:
2639-
dct: Dict representation of DeltaDistanceNbSetWeight.
2639+
dct (dict): Dict representation of DeltaDistanceNbSetWeight.
26402640
26412641
Returns:
26422642
DeltaDistanceNbSetWeight.

pymatgen/analysis/chemenv/coordination_environments/coordination_geometries.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,15 @@
1313
import itertools
1414
import json
1515
import os
16+
from typing import TYPE_CHECKING
1617

1718
import numpy as np
1819
from monty.json import MontyDecoder, MSONable
1920
from scipy.special import factorial
2021

22+
if TYPE_CHECKING:
23+
from typing_extensions import Self
24+
2125
__author__ = "David Waroquiers"
2226
__copyright__ = "Copyright 2012, The Materials Project"
2327
__credits__ = "Geoffroy Hautier"
@@ -109,7 +113,7 @@ def as_dict(self):
109113
}
110114

111115
@classmethod
112-
def from_dict(cls, dct):
116+
def from_dict(cls, dct: dict) -> Self:
113117
"""
114118
Reconstruct ExplicitPermutationsAlgorithm from its JSON-serializable dict representation.
115119
"""
@@ -324,7 +328,7 @@ def as_dict(self):
324328
}
325329

326330
@classmethod
327-
def from_dict(cls, dct):
331+
def from_dict(cls, dct: dict) -> Self:
328332
"""
329333
Reconstructs the SeparationPlane algorithm from its JSON-serializable dict representation.
330334
@@ -497,7 +501,7 @@ def as_dict(self):
497501
return {"hints_type": self.hints_type, "options": self.options}
498502

499503
@classmethod
500-
def from_dict(cls, dct):
504+
def from_dict(cls, dct: dict) -> Self:
501505
"""Reconstructs the NeighborsSetsHints from its JSON-serializable dict representation."""
502506
return cls(hints_type=dct["hints_type"], options=dct["options"])
503507

@@ -592,7 +596,7 @@ def as_dict(self):
592596
}
593597

594598
@classmethod
595-
def from_dict(cls, dct):
599+
def from_dict(cls, dct: dict) -> Self:
596600
"""
597601
Reconstructs the CoordinationGeometry from its JSON-serializable dict representation.
598602
@@ -602,7 +606,6 @@ def from_dict(cls, dct):
602606
Returns:
603607
CoordinationGeometry
604608
"""
605-
dec = MontyDecoder()
606609
return cls(
607610
mp_symbol=dct["mp_symbol"],
608611
name=dct["name"],
@@ -620,7 +623,7 @@ def from_dict(cls, dct):
620623
deactivate=dct["deactivate"],
621624
faces=dct["_faces"],
622625
edges=dct["_edges"],
623-
algorithms=[dec.process_decoded(algo_d) for algo_d in dct["_algorithms"]]
626+
algorithms=[MontyDecoder().process_decoded(algo_d) for algo_d in dct["_algorithms"]]
624627
if dct["_algorithms"] is not None
625628
else None,
626629
equivalent_indices=dct.get("equivalent_indices"),

0 commit comments

Comments
 (0)