Skip to content

Commit 82a9d57

Browse files
authored
feat: Better support for dataclasses
Instead of generating parameters on the fly by (wrongly) checking attributes of the class, we always load a Griffe extension that re-creates `__init__` methods and their parameters. Issue-33: #233 Issue-34: #234 Issue-38: #238 Issue-39: #239 PR-240: #240
1 parent 9efda88 commit 82a9d57

File tree

9 files changed

+409
-46
lines changed

9 files changed

+409
-46
lines changed

src/griffe/agents/inspector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from griffe.dataclasses import Alias, Attribute, Class, Docstring, Function, Module, Parameter, Parameters
3434
from griffe.enumerations import ObjectKind, ParameterKind
3535
from griffe.expressions import safe_get_annotation
36-
from griffe.extensions.base import Extensions
36+
from griffe.extensions.base import Extensions, load_extensions
3737
from griffe.importer import dynamic_import
3838

3939
if TYPE_CHECKING:
@@ -77,7 +77,7 @@ def inspect(
7777
return Inspector(
7878
module_name,
7979
filepath,
80-
extensions or Extensions(),
80+
extensions or load_extensions(),
8181
parent,
8282
docstring_parser=docstring_parser,
8383
docstring_options=docstring_options,

src/griffe/agents/visitor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
safe_get_condition,
3636
safe_get_expression,
3737
)
38-
from griffe.extensions.base import Extensions
38+
from griffe.extensions.base import Extensions, load_extensions
3939

4040
if TYPE_CHECKING:
4141
from pathlib import Path
@@ -92,7 +92,7 @@ def visit(
9292
module_name,
9393
filepath,
9494
code,
95-
extensions or Extensions(),
95+
extensions or load_extensions(),
9696
parent,
9797
docstring_parser=docstring_parser,
9898
docstring_options=docstring_options,

src/griffe/dataclasses.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,16 @@ def __str__(self) -> str:
203203
def __repr__(self) -> str:
204204
return f"Parameter(name={self.name!r}, annotation={self.annotation!r}, kind={self.kind!r}, default={self.default!r})"
205205

206+
def __eq__(self, __value: object) -> bool:
207+
if not isinstance(__value, Parameter):
208+
return NotImplemented
209+
return (
210+
self.name == __value.name
211+
and self.annotation == __value.annotation
212+
and self.kind == __value.kind
213+
and self.default == __value.default
214+
)
215+
206216
@property
207217
def required(self) -> bool:
208218
"""Whether this parameter is required."""
@@ -1561,14 +1571,6 @@ def parameters(self) -> Parameters:
15611571
try:
15621572
return self.all_members["__init__"].parameters # type: ignore[union-attr]
15631573
except KeyError:
1564-
if "dataclass" in self.labels:
1565-
return Parameters(
1566-
*[
1567-
Parameter(attr.name, annotation=attr.annotation, default=attr.value)
1568-
for attr in self.attributes.values()
1569-
if "property" not in attr.labels
1570-
],
1571-
)
15721574
return Parameters()
15731575

15741576
@cached_property

src/griffe/expressions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,11 @@ class ExprCall(Expr):
243243
arguments: Sequence[str | Expr]
244244
"""Passed arguments."""
245245

246+
@property
247+
def canonical_path(self) -> str:
248+
"""The canonical path of this subscript's left part."""
249+
return self.function.canonical_path
250+
246251
def iterate(self, *, flat: bool = True) -> Iterator[str | Expr]: # noqa: D102
247252
yield from _yield(self.function, flat=flat)
248253
yield "("

src/griffe/extensions/base.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ def call(self, event: str, **kwargs: Any) -> None:
347347

348348
builtin_extensions: set[str] = {
349349
"hybrid",
350+
"dataclasses",
350351
}
351352

352353

@@ -454,7 +455,9 @@ def _load_extension(
454455
return [ext(**options) for ext in extensions]
455456

456457

457-
def load_extensions(exts: Sequence[str | dict[str, Any] | ExtensionType | type[ExtensionType]]) -> Extensions:
458+
def load_extensions(
459+
exts: Sequence[str | dict[str, Any] | ExtensionType | type[ExtensionType]] | None = None,
460+
) -> Extensions:
458461
"""Load configured extensions.
459462
460463
Parameters:
@@ -464,12 +467,23 @@ def load_extensions(exts: Sequence[str | dict[str, Any] | ExtensionType | type[E
464467
An extensions container.
465468
"""
466469
extensions = Extensions()
467-
for extension in exts:
470+
for extension in exts or ():
468471
ext = _load_extension(extension)
469472
if isinstance(ext, list):
470473
extensions.add(*ext)
471474
else:
472475
extensions.add(ext)
476+
477+
# TODO: Deprecate and remove at some point?
478+
# Always add our built-in dataclasses extension.
479+
from griffe.extensions.dataclasses import DataclassesExtension
480+
481+
for ext in extensions._extensions:
482+
if type(ext) == DataclassesExtension:
483+
break
484+
else:
485+
extensions.add(*_load_extension("dataclasses")) # type: ignore[misc]
486+
473487
return extensions
474488

475489

src/griffe/extensions/dataclasses.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
"""Built-in extension adding support for dataclasses.
2+
3+
This extension re-creates `__init__` methods of dataclasses
4+
during static analysis.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import ast
10+
from contextlib import suppress
11+
from functools import lru_cache
12+
from typing import Any, cast
13+
14+
from griffe.dataclasses import Attribute, Class, Decorator, Function, Module, Parameter, Parameters
15+
from griffe.enumerations import ParameterKind
16+
from griffe.expressions import (
17+
Expr,
18+
ExprAttribute,
19+
ExprCall,
20+
ExprDict,
21+
)
22+
from griffe.extensions.base import Extension
23+
24+
25+
def _dataclass_decorator(decorators: list[Decorator]) -> Expr | None:
26+
for decorator in decorators:
27+
if isinstance(decorator.value, Expr) and decorator.value.canonical_path == "dataclasses.dataclass":
28+
return decorator.value
29+
return None
30+
31+
32+
def _expr_args(expr: Expr) -> dict[str, str | Expr]:
33+
args = {}
34+
if isinstance(expr, ExprCall):
35+
for argument in expr.arguments:
36+
try:
37+
args[argument.name] = argument.value # type: ignore[union-attr]
38+
except AttributeError:
39+
# Argument is a unpacked variable.
40+
with suppress(Exception):
41+
collection = expr.function.parent.modules_collection # type: ignore[attr-defined]
42+
var = collection[argument.value.canonical_path] # type: ignore[union-attr]
43+
args.update(_expr_args(var.value))
44+
elif isinstance(expr, ExprDict):
45+
args.update({ast.literal_eval(str(key)): value for key, value in zip(expr.keys, expr.values)})
46+
return args
47+
48+
49+
def _dataclass_arguments(decorators: list[Decorator]) -> dict[str, Any]:
50+
if (expr := _dataclass_decorator(decorators)) and isinstance(expr, ExprCall):
51+
return _expr_args(expr)
52+
return {}
53+
54+
55+
def _field_arguments(attribute: Attribute) -> dict[str, Any]:
56+
if attribute.value:
57+
value = attribute.value
58+
if isinstance(value, ExprAttribute):
59+
value = value.last
60+
if isinstance(value, ExprCall) and value.canonical_path == "dataclasses.field":
61+
return _expr_args(value)
62+
return {}
63+
64+
65+
@lru_cache(maxsize=None)
66+
def _dataclass_parameters(class_: Class) -> list[Parameter]:
67+
# Fetch `@dataclass` arguments if any.
68+
dec_args = _dataclass_arguments(class_.decorators)
69+
70+
# Parameters not added to `__init__`, return empty list.
71+
if dec_args.get("init") == "False":
72+
return []
73+
74+
# All parameters marked as keyword-only.
75+
kw_only = dec_args.get("kw_only") == "True"
76+
77+
# Iterate on current attributes to find parameters.
78+
parameters = []
79+
for member in class_.members.values():
80+
if member.is_attribute:
81+
member = cast(Attribute, member)
82+
83+
# Start of keyword-only parameters.
84+
if isinstance(member.annotation, Expr) and member.annotation.canonical_path == "dataclasses.KW_ONLY":
85+
kw_only = True
86+
continue
87+
88+
# Fetch `field` arguments if any.
89+
field_args = _field_arguments(member)
90+
91+
# Parameter not added to `__init__`, skip it.
92+
if field_args.get("init") == "False":
93+
continue
94+
95+
# Determine parameter kind.
96+
kind = (
97+
ParameterKind.keyword_only
98+
if kw_only or field_args.get("kw_only") == "True"
99+
else ParameterKind.positional_or_keyword
100+
)
101+
102+
# Determine parameter default.
103+
if "default_factory" in field_args:
104+
default = ExprCall(function=field_args["default_factory"], arguments=[])
105+
else:
106+
default = field_args.get("default", None if field_args else member.value)
107+
108+
# Add parameter to the list.
109+
parameters.append(
110+
Parameter(
111+
member.name,
112+
annotation=member.annotation,
113+
kind=kind,
114+
default=default,
115+
),
116+
)
117+
118+
return parameters
119+
120+
121+
def _reorder_parameters(parameters: list[Parameter]) -> list[Parameter]:
122+
# De-duplicate, overwriting previous parameters.
123+
params_dict = {param.name: param for param in parameters}
124+
125+
# Re-order, putting positional-only in front and keyword-only at the end.
126+
pos_only = []
127+
pos_kw = []
128+
kw_only = []
129+
for param in params_dict.values():
130+
if param.kind is ParameterKind.positional_only:
131+
pos_only.append(param)
132+
elif param.kind is ParameterKind.keyword_only:
133+
kw_only.append(param)
134+
else:
135+
pos_kw.append(param)
136+
return pos_only + pos_kw + kw_only
137+
138+
139+
def _set_dataclass_init(class_: Class) -> None:
140+
# Retrieve parameters from all parent dataclasses.
141+
parameters = []
142+
try:
143+
mro = class_.mro()
144+
except ValueError:
145+
mro = () # type: ignore[assignment]
146+
for parent in reversed(mro):
147+
if _dataclass_decorator(parent.decorators):
148+
parameters.extend(_dataclass_parameters(parent))
149+
# At least one parent dataclass makes the current class a dataclass:
150+
# that's how `dataclasses.is_dataclass` works.
151+
class_.labels.add("dataclass")
152+
153+
# If the class is not decorated with `@dataclass`, skip it.
154+
if not _dataclass_decorator(class_.decorators):
155+
return
156+
157+
# Add current class parameters.
158+
parameters.extend(_dataclass_parameters(class_))
159+
160+
# Create `__init__` method with re-ordered parameters.
161+
init = Function(
162+
"__init__",
163+
lineno=0,
164+
endlineno=0,
165+
parent=class_,
166+
parameters=Parameters(
167+
Parameter(name="self", annotation=None, kind=ParameterKind.positional_or_keyword, default=None),
168+
*_reorder_parameters(parameters),
169+
),
170+
returns="None",
171+
)
172+
class_.set_member("__init__", init)
173+
174+
175+
def _apply_recursively(mod_cls: Module | Class, processed: set[str]) -> None:
176+
if mod_cls.canonical_path in processed:
177+
return
178+
processed.add(mod_cls.canonical_path)
179+
if isinstance(mod_cls, Class):
180+
if "__init__" not in mod_cls.members:
181+
_set_dataclass_init(mod_cls)
182+
for member in mod_cls.members.values():
183+
if not member.is_alias and member.is_class:
184+
_apply_recursively(member, processed) # type: ignore[arg-type]
185+
elif isinstance(mod_cls, Module):
186+
for member in mod_cls.members.values():
187+
if not member.is_alias and (member.is_module or member.is_class):
188+
_apply_recursively(member, processed) # type: ignore[arg-type]
189+
190+
191+
class DataclassesExtension(Extension):
192+
"""Built-in extension adding support for dataclasses.
193+
194+
This extension creates `__init__` methods of dataclasses
195+
if they don't already exist.
196+
"""
197+
198+
def on_package_loaded(self, *, pkg: Module) -> None:
199+
"""Hook for loaded packages.
200+
201+
Parameters:
202+
pkg: The loaded package.
203+
"""
204+
_apply_recursively(pkg, set())

src/griffe/loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from griffe.enumerations import Kind
2727
from griffe.exceptions import AliasResolutionError, CyclicAliasError, LoadingError, UnimportableModuleError
2828
from griffe.expressions import ExprName
29-
from griffe.extensions.base import Extensions
29+
from griffe.extensions.base import Extensions, load_extensions
3030
from griffe.finder import ModuleFinder, NamespacePackage, Package
3131
from griffe.git import tmp_worktree
3232
from griffe.logger import get_logger
@@ -69,7 +69,7 @@ def __init__(
6969
allow_inspection: Whether to allow inspecting modules when visiting them is not possible.
7070
store_source: Whether to store code source in the lines collection.
7171
"""
72-
self.extensions: Extensions = extensions or Extensions()
72+
self.extensions: Extensions = extensions or load_extensions()
7373
"""Loaded Griffe extensions."""
7474
self.docstring_parser: Parser | None = docstring_parser
7575
"""Selected docstring parser."""

0 commit comments

Comments
 (0)