Skip to content

Commit cb3d9e3

Browse files
Fix unmarshalling of forward references on Python ≥ 3.12.4 (#252)
Fixes the need for specifing recursive_guard as named arg from kwarg _evaluate from python 3.12.4+ --------- Co-authored-by: Andrew Snare <[email protected]>
1 parent fad76e8 commit cb3d9e3

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

src/databricks/labs/blueprint/installation.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010
import os.path
1111
import re
12+
import sys
1213
import threading
1314
import types
1415
import typing
@@ -668,13 +669,30 @@ class _FromDict(Protocol):
668669
def from_dict(cls, raw: dict):
669670
pass
670671

672+
# Internal utility for evaluating forward references; mypy handles version checks, but only top-level.
673+
if sys.version_info >= (3, 12, 4):
674+
# Since Python 3.12.4, `ForwardRef._evaluate` requires recursive_guard as a keyword, and has an additional
675+
# parameter for type information.
676+
@staticmethod
677+
def _evaluate_forward_ref(type_ref: typing.ForwardRef) -> type:
678+
"""Evaluate a forward reference to a type."""
679+
# pylint: disable-next=protected-access
680+
return type_ref._evaluate(globals(), locals(), (), recursive_guard=frozenset()) # type: ignore[arg-type,misc,return-value]
681+
else:
682+
# Older versions of Python do
683+
@staticmethod
684+
def _evaluate_forward_ref(type_ref: typing.ForwardRef) -> type:
685+
"""Evaluate a forward reference to a type."""
686+
# pylint: disable-next=protected-access
687+
return type_ref._evaluate(globals(), locals(), recursive_guard=frozenset()) # type: ignore[return-value]
688+
671689
@classmethod
672690
def _unmarshal(cls, inst: Any, path: list[str], type_ref: type[T]) -> T | None:
673691
"""The `_unmarshal` method is a private method that is used to deserialize a dictionary to an object of type
674692
`type_ref`. This method is called by the `load` method."""
675693
# Forward-references aren't always resolved, so we need to handle them. (Assumes reference is visible here.)
676694
if isinstance(type_ref, typing.ForwardRef):
677-
type_ref = type_ref._evaluate(globals(), locals(), frozenset()) # pylint: disable=protected-access
695+
type_ref = cls._evaluate_forward_ref(type_ref)
678696
if dataclasses.is_dataclass(type_ref):
679697
return cls._unmarshal_dataclass(inst, path, type_ref)
680698
if isinstance(type_ref, enum.EnumMeta):

tests/unit/test_installation.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,3 +732,26 @@ class SampleClass:
732732
installation = MockInstallation({"something.json": raw_data})
733733
loaded = installation.load(SampleClass, filename="something.json")
734734
assert loaded == expected
735+
736+
737+
def test_forward_referencing_class() -> None:
738+
"""Test that a class with forward-referenced fields. This simulates the behavior of future annotations."""
739+
740+
@dataclass
741+
class ForwardReferencingClass:
742+
field_str: "str" = "foo"
743+
field_int: "int" = 20
744+
field_bool: "bool" = False
745+
field_float: "float" = 2.3
746+
field_dict: "dict[str, int]" = dataclasses.field(default_factory=dict)
747+
field_list: "list[str]" = dataclasses.field(default_factory=list)
748+
field_optional: "str | None" = None
749+
field_json: "JsonValue" = None
750+
751+
instance = ForwardReferencingClass(field_dict={"a": 1, "b": 2}, field_list=["x", "y"], field_json={"key": "value"})
752+
753+
installation = MockInstallation()
754+
installation.save(instance, filename="saved.yml")
755+
756+
loaded = installation.load(ForwardReferencingClass, filename="saved.yml")
757+
assert instance == loaded

0 commit comments

Comments
 (0)