Skip to content

Commit 622ce27

Browse files
committed
Tests for model.conversation(before_call=, after_call=)
1 parent bf6fcb9 commit 622ce27

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

tests/test_tools.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,3 +343,45 @@ async def return_attachment() -> llm.Attachment:
343343
output = await chain_response.text()
344344
assert '"type": "image/png"' in output
345345
assert '"output": "Output"' in output
346+
347+
348+
def test_tool_conversation_settings():
349+
model = llm.get_model("echo")
350+
before_collected = []
351+
after_collected = []
352+
353+
def before(*args):
354+
before_collected.append(args)
355+
356+
def after(*args):
357+
after_collected.append(args)
358+
359+
conversation = model.conversation(
360+
tools=[llm_time], before_call=before, after_call=after
361+
)
362+
# Run two things
363+
conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text()
364+
conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text()
365+
assert len(before_collected) == 2
366+
assert len(after_collected) == 2
367+
368+
369+
@pytest.mark.asyncio
370+
async def test_tool_conversation_settings_async():
371+
model = llm.get_async_model("echo")
372+
before_collected = []
373+
after_collected = []
374+
375+
async def before(*args):
376+
before_collected.append(args)
377+
378+
async def after(*args):
379+
after_collected.append(args)
380+
381+
conversation = model.conversation(
382+
tools=[llm_time], before_call=before, after_call=after
383+
)
384+
await conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text()
385+
await conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text()
386+
assert len(before_collected) == 2
387+
assert len(after_collected) == 2

0 commit comments

Comments
 (0)