Skip to content

Commit 161b2a6

Browse files
skrawczelijahbenizzy
authored andcommitted
Adds wrapping a result builder
So that people can adjust the result accordingly.
1 parent 7f36808 commit 161b2a6

File tree

2 files changed

+88
-2
lines changed

2 files changed

+88
-2
lines changed

hamilton/plugins/h_threadpool.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from concurrent.futures import Future, ThreadPoolExecutor
2-
from typing import Any, Callable, Dict
2+
from typing import Any, Callable, Dict, List, Type
33

44
from hamilton import registry
55

@@ -45,14 +45,39 @@ class FutureAdapter(base.BaseDoRemoteExecute, lifecycle.ResultBuilder):
4545
4646
"""
4747

48-
def __init__(self, max_workers: int = None, thread_name_prefix: str = ""):
48+
def __init__(
49+
self,
50+
max_workers: int = None,
51+
thread_name_prefix: str = "",
52+
result_builder: lifecycle.ResultBuilder = None,
53+
):
4954
"""Constructor.
5055
:param max_workers: The maximum number of threads that can be used to execute the given calls.
5156
:param thread_name_prefix: An optional name prefix to give our threads.
57+
:param result_builder: Optional. Result builder to use for building the result.
5258
"""
5359
self.executor = ThreadPoolExecutor(
5460
max_workers=max_workers, thread_name_prefix=thread_name_prefix
5561
)
62+
self.result_builder = result_builder
63+
64+
def input_types(self) -> List[Type[Type]]:
65+
"""Gives the applicable types to this result builder.
66+
This is optional for backwards compatibility, but is recommended.
67+
68+
:return: A list of types that this can apply to.
69+
"""
70+
# since this wraps a potential result builder, expose the input types of the wrapped
71+
# result builder doesn't make sense.
72+
return [Any]
73+
74+
def output_type(self) -> Type:
75+
"""Returns the output type of this result builder
76+
:return: the type that this creates
77+
"""
78+
if self.result_builder:
79+
return self.result_builder.output_type()
80+
return Any
5681

5782
def do_remote_execute(
5883
self,
@@ -81,4 +106,6 @@ def build_result(self, **outputs: Any) -> Any:
81106
for k, v in outputs.items():
82107
if isinstance(v, Future):
83108
outputs[k] = v.result()
109+
if self.result_builder:
110+
return self.result_builder.build_result(**outputs)
84111
return outputs

tests/plugins/test_h_threadpool.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from concurrent.futures import Future
2+
from typing import Any
23

4+
from hamilton import lifecycle
35
from hamilton.plugins.h_threadpool import FutureAdapter, _new_fn
46

57

@@ -58,3 +60,60 @@ def test_future_adapter_build_result():
5860

5961
result = adapter.build_result(a=future_a, b=future_b)
6062
assert result == {"a": 1, "b": 2}
63+
64+
65+
def test_future_adapter_input_types():
66+
adapter = FutureAdapter()
67+
assert adapter.input_types() == [Any]
68+
69+
70+
def test_future_adapter_output_type():
71+
adapter = FutureAdapter()
72+
assert adapter.output_type() == Any
73+
74+
75+
def test_future_adapter_input_types_with_result_builder():
76+
"""Tests that we ignore exposing the input types of the wrapped result builder."""
77+
78+
class MockResultBuilder(lifecycle.ResultBuilder):
79+
def build_result(self, **outputs: Any) -> Any:
80+
pass
81+
82+
def input_types(self):
83+
return [int, str]
84+
85+
adapter = FutureAdapter(result_builder=MockResultBuilder())
86+
assert adapter.input_types() == [Any]
87+
88+
89+
def test_future_adapter_output_type_with_result_builder():
90+
class MockResultBuilder(lifecycle.ResultBuilder):
91+
def build_result(self, **outputs: Any) -> Any:
92+
pass
93+
94+
def output_type(self):
95+
return dict
96+
97+
adapter = FutureAdapter(result_builder=MockResultBuilder())
98+
assert adapter.output_type() == dict
99+
100+
101+
def test_future_adapter_build_result_with_result_builder():
102+
class MockResultBuilder(lifecycle.ResultBuilder):
103+
def build_result(self, **outputs):
104+
return sum(outputs.values())
105+
106+
def input_types(self):
107+
return [int]
108+
109+
def output_type(self):
110+
return int
111+
112+
adapter = FutureAdapter(result_builder=MockResultBuilder())
113+
future_a = Future()
114+
future_b = Future()
115+
future_a.set_result(1)
116+
future_b.set_result(2)
117+
118+
result = adapter.build_result(a=future_a, b=future_b)
119+
assert result == 3

0 commit comments

Comments
 (0)