Skip to content

Commit ff864e3

Browse files
committed
Bugfix pipe_output after config.when resolution
In case no pipe_output functions meet confgi.when conditions return original node and skip pipe_output.
1 parent 6367a45 commit ff864e3

File tree

3 files changed

+100
-0
lines changed

3 files changed

+100
-0
lines changed

hamilton/function_modifiers/macros.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,6 +1314,10 @@ def __identity(foo: Any) -> Any:
13141314
fn=fn,
13151315
)
13161316

1317+
# In case config resolves to no pipe functions applied we return the original node and skip pipe
1318+
if len(nodes) == 1:
1319+
return [node_]
1320+
13171321
last_node = nodes[-1].copy_with(name=f"{node_.name}", typ=nodes[-2].type)
13181322

13191323
out = [original_node]

tests/function_modifiers/test_macros.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,71 @@ def test_pipe_output_end_to_end():
976976
assert result["chain_2_using_pipe_output"] == result["chain_2_not_using_pipe_output"]
977977

978978

979+
def test_pipe_output_end_to_end_with_config():
980+
inputs = {
981+
"input_1": 10,
982+
"input_2": 20,
983+
"input_3": 30,
984+
}
985+
986+
dr = (
987+
driver.Builder()
988+
.with_modules(tests.resources.pipe_output)
989+
.with_adapter(base.DefaultAdapter())
990+
.with_config({"key": "Yes"})
991+
.build()
992+
)
993+
994+
result = dr.execute(
995+
[
996+
"chain_3_using_pipe_output",
997+
"chain_3_not_using_pipe_output_config_true",
998+
],
999+
inputs=inputs,
1000+
)
1001+
assert (
1002+
result["chain_3_using_pipe_output"] == result["chain_3_not_using_pipe_output_config_true"]
1003+
)
1004+
1005+
dr = (
1006+
driver.Builder()
1007+
.with_modules(tests.resources.pipe_output)
1008+
.with_adapter(base.DefaultAdapter())
1009+
.with_config({"key": "No"})
1010+
.build()
1011+
)
1012+
1013+
result = dr.execute(
1014+
[
1015+
"chain_3_using_pipe_output",
1016+
"chain_3_not_using_pipe_output_config_false",
1017+
],
1018+
inputs=inputs,
1019+
)
1020+
assert (
1021+
result["chain_3_using_pipe_output"] == result["chain_3_not_using_pipe_output_config_false"]
1022+
)
1023+
1024+
dr = (
1025+
driver.Builder()
1026+
.with_modules(tests.resources.pipe_output)
1027+
.with_adapter(base.DefaultAdapter())
1028+
.with_config({"key": "skip"})
1029+
.build()
1030+
)
1031+
result = dr.execute(
1032+
[
1033+
"chain_3_using_pipe_output",
1034+
"chain_3_not_using_pipe_output_config_no_conditions_met",
1035+
],
1036+
inputs=inputs,
1037+
)
1038+
assert (
1039+
result["chain_3_using_pipe_output"]
1040+
== result["chain_3_not_using_pipe_output_config_no_conditions_met"]
1041+
)
1042+
1043+
9791044
# Mutate will mark the modules (and leave a mark).
9801045
# Thus calling it a second time (for instance through pmultiple tests) might mess it up slightly...
9811046
# Using fixtures just to be sure.

tests/resources/pipe_output.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,34 @@ def chain_2_not_using_pipe_output(v: int, input_3: int, calc_c: bool = False) ->
105105
d = _add_n(c, n=input_3) # Assuming "upstream" refers to the same value as "v" here
106106
e = _add_two(d)
107107
return e
108+
109+
110+
@pipe_output(
111+
step(_square).named("a").when(key="Yes"),
112+
step(_multiply_n, n=value(2)).named("b").when(key="No"),
113+
step(_add_n, n=10).named("c").when(key="Yes"),
114+
step(_add_n, n=source("input_3")).named("d").when(key="No"),
115+
step(_add_two).named("e").when(key="Yes"),
116+
)
117+
def chain_3_using_pipe_output(v: int) -> int:
118+
return v + 10
119+
120+
121+
def chain_3_not_using_pipe_output_config_true(v: int, input_3: int) -> int:
122+
start = v + 10
123+
a = _square(start)
124+
c = _add_n(a, n=10)
125+
e = _add_two(c)
126+
return e
127+
128+
129+
def chain_3_not_using_pipe_output_config_false(v: int, input_3: int) -> int:
130+
start = v + 10
131+
b = _multiply_n(start, n=2)
132+
d = _add_n(b, n=input_3)
133+
return d
134+
135+
136+
def chain_3_not_using_pipe_output_config_no_conditions_met(v: int, input_3: int) -> int:
137+
start = v + 10
138+
return start

0 commit comments

Comments
 (0)