Skip to content

Commit 1ea5812

Browse files
authored
Fix step_timeout causing ParentCommand/GraphInterrupt exception to bubble up (#4950)
2 parents 161a1e3 + 3141155 commit 1ea5812

File tree

3 files changed

+76
-7
lines changed

3 files changed

+76
-7
lines changed

libs/langgraph/langgraph/pregel/runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,8 @@ def commit(
431431
writes.extend(resumes)
432432
self.put_writes()(task.id, writes) # type: ignore[misc]
433433
elif isinstance(exception, GraphBubbleUp):
434-
raise exception
434+
# exception will be raised in _panic_or_proceed
435+
pass
435436
else:
436437
# save error to checkpointer
437438
task.writes.append((ERROR, exception))

libs/langgraph/tests/test_pregel.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from langgraph.checkpoint.memory import InMemorySaver
4444
from langgraph.config import get_stream_writer
4545
from langgraph.constants import CONFIG_KEY_NODE_FINISHED, ERROR, PULL, START
46-
from langgraph.errors import InvalidUpdateError
46+
from langgraph.errors import InvalidUpdateError, ParentCommand
4747
from langgraph.func import entrypoint, task
4848
from langgraph.graph import END, StateGraph
4949
from langgraph.graph.message import MessagesState, add_messages
@@ -7920,9 +7920,10 @@ def my_workflow(number: int):
79207920
]
79217921

79227922

7923+
@pytest.mark.parametrize("with_timeout", [False, "inner", "outer", "both"])
79237924
@pytest.mark.parametrize("subgraph_persist", [True, False])
79247925
def test_parent_command_goto(
7925-
sync_checkpointer: BaseCheckpointSaver, subgraph_persist: bool
7926+
sync_checkpointer: BaseCheckpointSaver, subgraph_persist: bool, with_timeout: bool
79267927
) -> None:
79277928
class State(TypedDict):
79287929
dialog_state: Annotated[list[str], operator.add]
@@ -7943,6 +7944,8 @@ def node_b_child(state):
79437944
sub_builder.add_edge(START, "node_a_child")
79447945
sub_builder.add_edge("node_a_child", "node_b_child")
79457946
sub_graph = sub_builder.compile(checkpointer=subgraph_persist)
7947+
if with_timeout in ("inner", "both"):
7948+
sub_graph.step_timeout = 1
79467949

79477950
def node_b_parent(state):
79487951
return {"dialog_state": ["node_b_parent"]}
@@ -7951,10 +7954,40 @@ def node_b_parent(state):
79517954
main_builder.add_node(node_b_parent)
79527955
main_builder.add_edge(START, "subgraph_node")
79537956
main_builder.add_node("subgraph_node", sub_graph, destinations=("node_b_parent",))
7954-
79557957
main_graph = main_builder.compile(sync_checkpointer, name="parent")
7958+
if with_timeout in ("outer", "both"):
7959+
main_graph.step_timeout = 1
7960+
79567961
config = {"configurable": {"thread_id": 1}}
79577962

79587963
assert main_graph.invoke(input={"dialog_state": ["init_state"]}, config=config) == {
79597964
"dialog_state": ["init_state", "b_child_state", "node_b_parent"]
79607965
}
7966+
7967+
7968+
@pytest.mark.parametrize("with_timeout", [True, False])
7969+
def test_timeout_with_parent_command(
7970+
sync_checkpointer: BaseCheckpointSaver, with_timeout: bool
7971+
) -> None:
7972+
"""Test that parent commands are properly propagated during timeouts."""
7973+
7974+
class State(TypedDict):
7975+
value: str
7976+
7977+
def parent_command_node(state: State) -> State:
7978+
time.sleep(0.1) # Add some delay before raising
7979+
return Command(graph=Command.PARENT, goto="test_cmd", update={"key": "value"})
7980+
7981+
builder = StateGraph(State)
7982+
builder.add_node("parent_cmd", parent_command_node)
7983+
builder.set_entry_point("parent_cmd")
7984+
graph = builder.compile(checkpointer=sync_checkpointer)
7985+
if with_timeout:
7986+
graph.step_timeout = 1
7987+
7988+
# Should propagate parent command, not timeout
7989+
thread1 = {"configurable": {"thread_id": "1"}}
7990+
with pytest.raises(ParentCommand) as exc_info:
7991+
graph.invoke({"value": "start"}, thread1)
7992+
assert exc_info.value.args[0].goto == "test_cmd"
7993+
assert exc_info.value.args[0].update == {"key": "value"}

libs/langgraph/tests/test_pregel_async.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
)
4242
from langgraph.checkpoint.memory import InMemorySaver
4343
from langgraph.constants import CONFIG_KEY_NODE_FINISHED, ERROR, PULL, PUSH, START
44-
from langgraph.errors import InvalidUpdateError, NodeInterrupt
44+
from langgraph.errors import InvalidUpdateError, NodeInterrupt, ParentCommand
4545
from langgraph.func import entrypoint, task
4646
from langgraph.graph import END, StateGraph
4747
from langgraph.graph.message import MessagesState, add_messages
@@ -8655,9 +8655,12 @@ async def my_workflow(number: int):
86558655
]
86568656

86578657

8658+
@pytest.mark.parametrize("with_timeout", [False, "inner", "outer", "both"])
86588659
@pytest.mark.parametrize("subgraph_persist", [True, False])
86598660
async def test_parent_command_goto(
8660-
async_checkpointer: BaseCheckpointSaver, subgraph_persist: bool
8661+
async_checkpointer: BaseCheckpointSaver,
8662+
subgraph_persist: bool,
8663+
with_timeout: Literal[False, "inner", "outer", "both"],
86618664
) -> None:
86628665
class State(TypedDict):
86638666
dialog_state: Annotated[list[str], operator.add]
@@ -8678,6 +8681,8 @@ async def node_b_child(state):
86788681
sub_builder.add_edge(START, "node_a_child")
86798682
sub_builder.add_edge("node_a_child", "node_b_child")
86808683
sub_graph = sub_builder.compile(checkpointer=subgraph_persist)
8684+
if with_timeout in ("inner", "both"):
8685+
sub_graph.step_timeout = 1
86818686

86828687
async def node_b_parent(state):
86838688
return {"dialog_state": ["node_b_parent"]}
@@ -8686,10 +8691,40 @@ async def node_b_parent(state):
86868691
main_builder.add_node(node_b_parent)
86878692
main_builder.add_edge(START, "subgraph_node")
86888693
main_builder.add_node("subgraph_node", sub_graph, destinations=("node_b_parent",))
8689-
86908694
main_graph = main_builder.compile(async_checkpointer, name="parent")
8695+
if with_timeout in ("outer", "both"):
8696+
main_graph.step_timeout = 1
8697+
86918698
config = {"configurable": {"thread_id": 1}}
86928699

86938700
assert await main_graph.ainvoke(
86948701
input={"dialog_state": ["init_state"]}, config=config
86958702
) == {"dialog_state": ["init_state", "b_child_state", "node_b_parent"]}
8703+
8704+
8705+
@pytest.mark.parametrize("with_timeout", [True, False])
8706+
async def test_timeout_with_parent_command(
8707+
async_checkpointer: BaseCheckpointSaver, with_timeout: bool
8708+
) -> None:
8709+
"""Test that parent commands are properly propagated during timeouts."""
8710+
8711+
class State(TypedDict):
8712+
value: str
8713+
8714+
async def parent_command_node(state: State) -> State:
8715+
await asyncio.sleep(0.1) # Add some delay before raising
8716+
return Command(graph=Command.PARENT, goto="test_cmd", update={"key": "value"})
8717+
8718+
builder = StateGraph(State)
8719+
builder.add_node("parent_cmd", parent_command_node)
8720+
builder.set_entry_point("parent_cmd")
8721+
graph = builder.compile(checkpointer=async_checkpointer)
8722+
if with_timeout:
8723+
graph.step_timeout = 1
8724+
8725+
# Should propagate parent command, not timeout
8726+
thread1 = {"configurable": {"thread_id": "1"}}
8727+
with pytest.raises(ParentCommand) as exc_info:
8728+
await graph.ainvoke({"value": "start"}, thread1)
8729+
assert exc_info.value.args[0].goto == "test_cmd"
8730+
assert exc_info.value.args[0].update == {"key": "value"}

0 commit comments

Comments
 (0)