Skip to content

Commit d3495fc

Browse files
committed
Adds Typeddict Extract fields subclass type check and test for it
1 parent 7266a26 commit d3495fc

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

hamilton/function_modifiers/expanders.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import typing_extensions
99
import typing_inspect
1010

11-
from hamilton import node, registry
11+
from hamilton import htypes, node, registry
1212
from hamilton.dev_utils import deprecation
1313
from hamilton.function_modifiers import base
1414
from hamilton.function_modifiers.dependencies import (
@@ -772,8 +772,15 @@ def validate(self, fn: Callable):
772772
else:
773773
# check that fields is a subset of TypedDict that is defined
774774
typed_dict_fields = typing.get_type_hints(output_type)
775-
for k, v in self.fields.items():
776-
if typed_dict_fields.get(k, None) != v:
775+
for field_name, field_type in self.fields.items():
776+
expected_type = typed_dict_fields.get(field_name, None)
777+
if expected_type == field_type:
778+
pass # we're definitely good
779+
elif expected_type is not None and htypes.custom_subclass_check(
780+
field_type, expected_type
781+
):
782+
pass
783+
else:
777784
raise base.InvalidDecoratorException(
778785
f"Error {self.fields} did not match a subset of the TypedDict annotation's fields {typed_dict_fields}."
779786
)

tests/function_modifiers/test_expanders.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,41 @@ def return_dict() -> return_type:
340340
annotation.validate(return_dict)
341341

342342

343+
class SomeObject:
344+
pass
345+
346+
347+
class InheritedObject(SomeObject):
348+
pass
349+
350+
351+
class MyDictInheritance(TypedDict):
352+
test: SomeObject
353+
test2: str
354+
355+
356+
class MyDictInheritanceBadCase(TypedDict):
357+
test: InheritedObject
358+
test2: str
359+
360+
361+
def test_extract_fields_validate_happy_inheritance():
362+
def return_dict() -> MyDictInheritance:
363+
return {}
364+
365+
annotation = function_modifiers.extract_fields({"test": InheritedObject})
366+
annotation.validate(return_dict)
367+
368+
369+
def test_extract_fields_validate_not_subclass():
370+
def return_dict() -> MyDictInheritanceBadCase:
371+
return {}
372+
373+
annotation = function_modifiers.extract_fields({"test": SomeObject})
374+
with pytest.raises(base.InvalidDecoratorException):
375+
annotation.validate(return_dict)
376+
377+
343378
@pytest.mark.parametrize(
344379
"return_type",
345380
[(int), (list), (np.ndarray), (pd.DataFrame), (MyDictBad)],

0 commit comments

Comments
 (0)