Skip to content

Commit 3a96d52

Browse files
committed
Better handling of before_call cancellation, closes #1148
1 parent d96ae4e commit 3a96d52

File tree

3 files changed

+122
-12
lines changed

3 files changed

+122
-12
lines changed

docs/python-api.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ response = model.chain(
178178
)
179179
print(response.text())
180180
```
181+
If you raise `llm.CancelToolCall` in the `before_call` function the model will be informed that the tool call was cancelled.
182+
181183
The `after_call=` parameter can be used to run a logging function after each tool call has been executed. The method signature is `def after_call(tool: llm.Tool, tool_call: llm.ToolCall, tool_result: llm.ToolResult)`. This continues the previous example:
182184
```python
183185
def after_call(tool: llm.Tool, tool_call: llm.ToolCall, tool_result: llm.ToolResult):

llm/models.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,12 +1013,23 @@ def execute_tool_calls(
10131013
# Tool could be None if the tool was not found in the prompt tools,
10141014
# but we still call the before_call method:
10151015
if before_call:
1016-
cb_result = before_call(tool, tool_call)
1017-
if inspect.isawaitable(cb_result):
1018-
raise TypeError(
1019-
"Asynchronous 'before_call' callback provided to a synchronous tool execution context. "
1020-
"Please use an async chain/response or a synchronous callback."
1016+
try:
1017+
cb_result = before_call(tool, tool_call)
1018+
if inspect.isawaitable(cb_result):
1019+
raise TypeError(
1020+
"Asynchronous 'before_call' callback provided to a synchronous tool execution context. "
1021+
"Please use an async chain/response or a synchronous callback."
1022+
)
1023+
except CancelToolCall as ex:
1024+
tool_results.append(
1025+
ToolResult(
1026+
name=tool_call.name,
1027+
output="Cancelled: " + str(ex),
1028+
tool_call_id=tool_call.tool_call_id,
1029+
exception=ex,
1030+
)
10211031
)
1032+
continue
10221033

10231034
if tool is None:
10241035
msg = 'tool "{}" does not exist'.format(tool_call.name)
@@ -1202,9 +1213,17 @@ async def execute_tool_calls(
12021213
async def run_async(tc=tc, tool=tool, idx=idx):
12031214
# before_call inside the task
12041215
if before_call:
1205-
cb = before_call(tool, tc)
1206-
if inspect.isawaitable(cb):
1207-
await cb
1216+
try:
1217+
cb = before_call(tool, tc)
1218+
if inspect.isawaitable(cb):
1219+
await cb
1220+
except CancelToolCall as ex:
1221+
return idx, ToolResult(
1222+
name=tc.name,
1223+
output="Cancelled: " + str(ex),
1224+
tool_call_id=tc.tool_call_id,
1225+
exception=ex,
1226+
)
12081227

12091228
exception = None
12101229
attachments = []
@@ -1245,9 +1264,23 @@ async def run_async(tc=tc, tool=tool, idx=idx):
12451264
else:
12461265
# Sync implementation: do hooks and call inline
12471266
if before_call:
1248-
cb = before_call(tool, tc)
1249-
if inspect.isawaitable(cb):
1250-
await cb
1267+
try:
1268+
cb = before_call(tool, tc)
1269+
if inspect.isawaitable(cb):
1270+
await cb
1271+
except CancelToolCall as ex:
1272+
indexed_results.append(
1273+
(
1274+
idx,
1275+
ToolResult(
1276+
name=tc.name,
1277+
output="Cancelled: " + str(ex),
1278+
tool_call_id=tc.tool_call_id,
1279+
exception=ex,
1280+
),
1281+
)
1282+
)
1283+
continue
12511284

12521285
exception = None
12531286
attachments = []

tests/test_tools.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from importlib.metadata import version
44
import json
55
import llm
6-
from llm import cli
6+
from llm import cli, CancelToolCall
77
from llm.migrations import migrate
88
from llm.tools import llm_time
99
import os
@@ -432,3 +432,78 @@ def test_tool_errors(async_):
432432
" Error: Error!<br>\n"
433433
" **Error**: Exception: Error!\n"
434434
) in log_text_result.output
435+
436+
437+
def test_chain_sync_cancel_only_first_of_two():
438+
model = llm.get_model("echo")
439+
440+
def t1() -> str:
441+
return "ran1"
442+
443+
def t2() -> str:
444+
return "ran2"
445+
446+
def before(tool, tool_call):
447+
if tool.name == "t1":
448+
raise CancelToolCall("skip1")
449+
# allow t2
450+
return None
451+
452+
calls = [
453+
{"name": "t1"},
454+
{"name": "t2"},
455+
]
456+
payload = json.dumps({"tool_calls": calls})
457+
chain = model.chain(payload, tools=[t1, t2], before_call=before)
458+
_ = chain.text()
459+
460+
# second response has two results
461+
second = chain._responses[1]
462+
results = second.prompt.tool_results
463+
assert len(results) == 2
464+
465+
# first cancelled, second executed
466+
assert results[0].name == "t1"
467+
assert results[0].output == "Cancelled: skip1"
468+
assert isinstance(results[0].exception, CancelToolCall)
469+
470+
assert results[1].name == "t2"
471+
assert results[1].output == "ran2"
472+
assert results[1].exception is None
473+
474+
475+
# 2c async equivalent
476+
@pytest.mark.asyncio
477+
async def test_chain_async_cancel_only_first_of_two():
478+
async_model = llm.get_async_model("echo")
479+
480+
def t1() -> str:
481+
return "ran1"
482+
483+
async def t2() -> str:
484+
return "ran2"
485+
486+
async def before(tool, tool_call):
487+
if tool.name == "t1":
488+
raise CancelToolCall("skip1")
489+
return None
490+
491+
calls = [
492+
{"name": "t1"},
493+
{"name": "t2"},
494+
]
495+
payload = json.dumps({"tool_calls": calls})
496+
chain = async_model.chain(payload, tools=[t1, t2], before_call=before)
497+
_ = await chain.text()
498+
499+
second = chain._responses[1]
500+
results = second.prompt.tool_results
501+
assert len(results) == 2
502+
503+
assert results[0].name == "t1"
504+
assert results[0].output == "Cancelled: skip1"
505+
assert isinstance(results[0].exception, CancelToolCall)
506+
507+
assert results[1].name == "t2"
508+
assert results[1].output == "ran2"
509+
assert results[1].exception is None

0 commit comments

Comments
 (0)