41
41
)
42
42
from langgraph .checkpoint .memory import InMemorySaver
43
43
from langgraph .constants import CONFIG_KEY_NODE_FINISHED , ERROR , PULL , PUSH , START
44
- from langgraph .errors import InvalidUpdateError , NodeInterrupt
44
+ from langgraph .errors import InvalidUpdateError , NodeInterrupt , ParentCommand
45
45
from langgraph .func import entrypoint , task
46
46
from langgraph .graph import END , StateGraph
47
47
from langgraph .graph .message import MessagesState , add_messages
@@ -8655,9 +8655,12 @@ async def my_workflow(number: int):
8655
8655
]
8656
8656
8657
8657
8658
+ @pytest .mark .parametrize ("with_timeout" , [False , "inner" , "outer" , "both" ])
8658
8659
@pytest .mark .parametrize ("subgraph_persist" , [True , False ])
8659
8660
async def test_parent_command_goto (
8660
- async_checkpointer : BaseCheckpointSaver , subgraph_persist : bool
8661
+ async_checkpointer : BaseCheckpointSaver ,
8662
+ subgraph_persist : bool ,
8663
+ with_timeout : Literal [False , "inner" , "outer" , "both" ],
8661
8664
) -> None :
8662
8665
class State (TypedDict ):
8663
8666
dialog_state : Annotated [list [str ], operator .add ]
@@ -8678,6 +8681,8 @@ async def node_b_child(state):
8678
8681
sub_builder .add_edge (START , "node_a_child" )
8679
8682
sub_builder .add_edge ("node_a_child" , "node_b_child" )
8680
8683
sub_graph = sub_builder .compile (checkpointer = subgraph_persist )
8684
+ if with_timeout in ("inner" , "both" ):
8685
+ sub_graph .step_timeout = 1
8681
8686
8682
8687
async def node_b_parent (state ):
8683
8688
return {"dialog_state" : ["node_b_parent" ]}
@@ -8686,10 +8691,40 @@ async def node_b_parent(state):
8686
8691
main_builder .add_node (node_b_parent )
8687
8692
main_builder .add_edge (START , "subgraph_node" )
8688
8693
main_builder .add_node ("subgraph_node" , sub_graph , destinations = ("node_b_parent" ,))
8689
-
8690
8694
main_graph = main_builder .compile (async_checkpointer , name = "parent" )
8695
+ if with_timeout in ("outer" , "both" ):
8696
+ main_graph .step_timeout = 1
8697
+
8691
8698
config = {"configurable" : {"thread_id" : 1 }}
8692
8699
8693
8700
assert await main_graph .ainvoke (
8694
8701
input = {"dialog_state" : ["init_state" ]}, config = config
8695
8702
) == {"dialog_state" : ["init_state" , "b_child_state" , "node_b_parent" ]}
8703
+
8704
+
8705
+ @pytest .mark .parametrize ("with_timeout" , [True , False ])
8706
+ async def test_timeout_with_parent_command (
8707
+ async_checkpointer : BaseCheckpointSaver , with_timeout : bool
8708
+ ) -> None :
8709
+ """Test that parent commands are properly propagated during timeouts."""
8710
+
8711
+ class State (TypedDict ):
8712
+ value : str
8713
+
8714
+ async def parent_command_node (state : State ) -> State :
8715
+ await asyncio .sleep (0.1 ) # Add some delay before raising
8716
+ return Command (graph = Command .PARENT , goto = "test_cmd" , update = {"key" : "value" })
8717
+
8718
+ builder = StateGraph (State )
8719
+ builder .add_node ("parent_cmd" , parent_command_node )
8720
+ builder .set_entry_point ("parent_cmd" )
8721
+ graph = builder .compile (checkpointer = async_checkpointer )
8722
+ if with_timeout :
8723
+ graph .step_timeout = 1
8724
+
8725
+ # Should propagate parent command, not timeout
8726
+ thread1 = {"configurable" : {"thread_id" : "1" }}
8727
+ with pytest .raises (ParentCommand ) as exc_info :
8728
+ await graph .ainvoke ({"value" : "start" }, thread1 )
8729
+ assert exc_info .value .args [0 ].goto == "test_cmd"
8730
+ assert exc_info .value .args [0 ].update == {"key" : "value" }
0 commit comments