Skip to content

Commit 622866a

Browse files
authored
Adds typed dict extract fields support (#1253)
Adds typeddict support for extract_fields This enables more ergonomic usage of TypedDict with extract fields. I skipped adding support for returning any `Mapping`. Though that should be an easy addition. ```python from typing import TypedDict from hamilton.function_modifiers import extract_fields class MyDict(TypedDict): foo: str bar: int @extract_fields() def some_function() -> MyDict: return MyDict(foo="s", bar=1) ``` The above will automatically extract the fields foo and bar. You can also do: ```python from typing import TypedDict from hamilton.function_modifiers import extract_fields class MyDict(TypedDict): foo: str bar: int @extract_fields({"foo": str}) def some_function()->MyDict: return MyDict(foo="s", bar=1) ``` To only expose a subset of the fields. Squashed commits: * Adds sketch of improving extract_fields with typeddict This in response to #1252. We should be able to handle typeddict better. This sketches some ideas: 1. field validation should happen in .validate() not the constructor. 2. extract_fields shouldn't need fields if the typeddict is the annotation type. 3. we properly check that typeddict can be a return type. * Adds typeddict tests * Adding validation to cover all extract_fields paths * Adds Typeddict Extract fields subclass type check and test for it
1 parent fc239a9 commit 622866a

File tree

2 files changed

+92
-7
lines changed

2 files changed

+92
-7
lines changed

hamilton/function_modifiers/expanders.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
import typing
66
from typing import Any, Callable, Collection, Dict, Tuple, Union
77

8+
import typing_extensions
89
import typing_inspect
910

10-
from hamilton import node, registry
11+
from hamilton import htypes, node, registry
1112
from hamilton.dev_utils import deprecation
1213
from hamilton.function_modifiers import base
1314
from hamilton.function_modifiers.dependencies import (
@@ -733,7 +734,7 @@ def _validate_extract_fields(fields: dict):
733734
class extract_fields(base.SingleNodeNodeTransformer):
734735
"""Extracts fields from a dictionary of output."""
735736

736-
def __init__(self, fields: dict, fill_with: Any = None):
737+
def __init__(self, fields: dict = None, fill_with: Any = None):
737738
"""Constructor for a modifier that expands a single function into the following nodes:
738739
739740
- n functions, each of which take in the original dict and output a specific field
@@ -745,7 +746,6 @@ def __init__(self, fields: dict, fill_with: Any = None):
745746
field value.
746747
"""
747748
super(extract_fields, self).__init__()
748-
_validate_extract_fields(fields)
749749
self.fields = fields
750750
self.fill_with = fill_with
751751

@@ -759,13 +759,32 @@ def validate(self, fn: Callable):
759759
if typing_inspect.is_generic_type(output_type):
760760
base_type = typing_inspect.get_origin(output_type)
761761
if base_type == dict or base_type == Dict:
762-
pass
762+
_validate_extract_fields(self.fields)
763763
else:
764764
raise base.InvalidDecoratorException(
765765
f"For extracting fields, output type must be a dict or typing.Dict, not: {output_type}"
766766
)
767767
elif output_type == dict:
768-
pass
768+
_validate_extract_fields(self.fields)
769+
elif typing_extensions.is_typeddict(output_type):
770+
if self.fields is None:
771+
self.fields = typing.get_type_hints(output_type)
772+
else:
773+
# check that fields is a subset of TypedDict that is defined
774+
typed_dict_fields = typing.get_type_hints(output_type)
775+
for field_name, field_type in self.fields.items():
776+
expected_type = typed_dict_fields.get(field_name, None)
777+
if expected_type == field_type:
778+
pass # we're definitely good
779+
elif expected_type is not None and htypes.custom_subclass_check(
780+
field_type, expected_type
781+
):
782+
pass
783+
else:
784+
raise base.InvalidDecoratorException(
785+
f"Error {self.fields} did not match a subset of the TypedDict annotation's fields {typed_dict_fields}."
786+
)
787+
_validate_extract_fields(self.fields)
769788
else:
770789
raise base.InvalidDecoratorException(
771790
f"For extracting fields, output type must be a dict or typing.Dict, not: {output_type}"

tests/function_modifiers/test_expanders.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sys
2-
from typing import Any, Dict, List, Optional, Type
2+
from typing import Any, Dict, List, Optional, Type, TypedDict
33

44
import numpy as np
55
import pandas as pd
@@ -313,13 +313,23 @@ def test_extract_fields_constructor_happy(fields):
313313
expanders._validate_extract_fields(fields)
314314

315315

316+
class MyDict(TypedDict):
317+
test: int
318+
test2: str
319+
320+
321+
class MyDictBad(TypedDict):
322+
test2: str
323+
324+
316325
@pytest.mark.parametrize(
317326
"return_type",
318327
[
319328
dict,
320329
Dict,
321330
Dict[str, str],
322331
Dict[str, Any],
332+
MyDict,
323333
],
324334
)
325335
def test_extract_fields_validate_happy(return_type):
@@ -330,7 +340,45 @@ def return_dict() -> return_type:
330340
annotation.validate(return_dict)
331341

332342

333-
@pytest.mark.parametrize("return_type", [(int), (list), (np.ndarray), (pd.DataFrame)])
343+
class SomeObject:
344+
pass
345+
346+
347+
class InheritedObject(SomeObject):
348+
pass
349+
350+
351+
class MyDictInheritance(TypedDict):
352+
test: SomeObject
353+
test2: str
354+
355+
356+
class MyDictInheritanceBadCase(TypedDict):
357+
test: InheritedObject
358+
test2: str
359+
360+
361+
def test_extract_fields_validate_happy_inheritance():
362+
def return_dict() -> MyDictInheritance:
363+
return {}
364+
365+
annotation = function_modifiers.extract_fields({"test": InheritedObject})
366+
annotation.validate(return_dict)
367+
368+
369+
def test_extract_fields_validate_not_subclass():
370+
def return_dict() -> MyDictInheritanceBadCase:
371+
return {}
372+
373+
annotation = function_modifiers.extract_fields({"test": SomeObject})
374+
with pytest.raises(base.InvalidDecoratorException):
375+
annotation.validate(return_dict)
376+
377+
378+
@pytest.mark.parametrize(
379+
"return_type",
380+
[(int), (list), (np.ndarray), (pd.DataFrame), (MyDictBad)],
381+
)
334382
def test_extract_fields_validate_errors(return_type):
335383
def return_dict() -> return_type:
336384
return {}
@@ -340,6 +388,24 @@ def return_dict() -> return_type:
340388
annotation.validate(return_dict)
341389

342390

391+
def test_extract_fields_typeddict_empty_fields():
392+
def return_dict() -> MyDict:
393+
return {}
394+
395+
# don't need fields for TypedDict
396+
annotation = function_modifiers.extract_fields()
397+
annotation.validate(return_dict)
398+
399+
400+
def test_extract_fields_typeddict_subset():
401+
def return_dict() -> MyDict:
402+
return {}
403+
404+
# test that a subset of fields is fine
405+
annotation = function_modifiers.extract_fields({"test2": str})
406+
annotation.validate(return_dict)
407+
408+
343409
def test_valid_extract_fields():
344410
"""Tests whole extract_fields decorator."""
345411
annotation = function_modifiers.extract_fields(

0 commit comments

Comments
 (0)