4 | 4 | import pytest
5 | 5 | from inline_snapshot import snapshot
6 | 6 |
7 |   | -from pydantic_ai import Agent |
  | 7 | +from pydantic_ai import Agent, HistoryProcessors |
8 | 8 | from pydantic_ai.messages import ModelMessage, ModelRequest, ModelRequestPart, ModelResponse, TextPart, UserPromptPart
9 | 9 | from pydantic_ai.models.function import AgentInfo, FunctionModel
10 | 10 | from pydantic_ai.tools import RunContext
@@ -301,3 +301,111 @@ class Deps:
301 | 301 | user_part = msg.parts[0]
302 | 302 | assert isinstance(user_part, UserPromptPart)
303 | 303 | assert cast(str, user_part.content).startswith('TEST: ')
| 304 | + |
| 305 | + |
| 306 | +async def test_history_processors_replace_history_true(function_model: FunctionModel): |
| 307 | + """Test HistoryProcessors with replace_history=True modifies original history.""" |
| 308 | + |
| 309 | + def keep_only_requests(messages: list[ModelMessage]) -> list[ModelMessage]: |
| 310 | + return [msg for msg in messages if isinstance(msg, ModelRequest)] |
| 311 | + |
| 312 | + processors = HistoryProcessors(funcs=[keep_only_requests], replace_history=True) # type: ignore |
| 313 | + agent = Agent(function_model, history_processors=processors) # type: ignore |
| 314 | + |
| 315 | + original_history = [ |
| 316 | + ModelRequest(parts=[UserPromptPart(content='Question 1')]), |
| 317 | + ModelResponse(parts=[TextPart(content='Answer 1')]), |
| 318 | + ModelRequest(parts=[UserPromptPart(content='Question 2')]), |
| 319 | + ModelResponse(parts=[TextPart(content='Answer 2')]), |
| 320 | + ] |
| 321 | + |
| 322 | + result = await agent.run('Question 3', message_history=original_history.copy()) |
| 323 | + |
| 324 | + # Verify the history was modified - responses should be removed |
| 325 | + all_messages = result.all_messages() |
| 326 | + requests = [msg for msg in all_messages if isinstance(msg, ModelRequest)] |
| 327 | + responses = [msg for msg in all_messages if isinstance(msg, ModelResponse)] |
| 328 | + |
| 329 | + # Should have 3 requests (2 original + 1 new) and 1 response (only the new one) |
| 330 | + assert len(requests) == 3 |
| 331 | + assert len(responses) == 1 |
| 332 | + |
| 333 | + |
| 334 | +async def test_history_processors_multiple_with_replace_history(function_model: FunctionModel): |
| 335 | + """Test multiple processors with replace_history=True.""" |
| 336 | + |
| 337 | + def remove_responses(messages: list[ModelMessage]) -> list[ModelMessage]: |
| 338 | + return [msg for msg in messages if isinstance(msg, ModelRequest)] |
| 339 | + |
| 340 | + def keep_recent(messages: list[ModelMessage]) -> list[ModelMessage]: |
| 341 | + return messages[-2:] if len(messages) > 2 else messages |
| 342 | + |
| 343 | + processors = HistoryProcessors( # type: ignore |
| 344 | + funcs=[remove_responses, keep_recent], replace_history=True |
| 345 | + ) |
| 346 | + agent = Agent(function_model, history_processors=processors) # type: ignore |
| 347 | + |
| 348 | + # Create history with 4 requests and 4 responses |
| 349 | + original_history: list[ModelMessage] = [] |
| 350 | + for i in range(4): |
| 351 | + original_history.append(ModelRequest(parts=[UserPromptPart(content=f'Question {i + 1}')])) |
| 352 | + original_history.append(ModelResponse(parts=[TextPart(content=f'Answer {i + 1}')])) |
| 353 | + |
| 354 | + result = await agent.run('Final question', message_history=original_history.copy()) |
| 355 | + |
| 356 | + # After processing: remove responses -> keep recent 2 -> add new exchange |
| 357 | + all_messages = result.all_messages() |
| 358 | + requests = [msg for msg in all_messages if isinstance(msg, ModelRequest)] |
| 359 | + responses = [msg for msg in all_messages if isinstance(msg, ModelResponse)] |
| 360 | + |
| 361 | +    # Should have 2 requests (1 kept from the original history + 1 new) and 1 response (only the new one) |
| 362 | + assert len(requests) == 2 |
| 363 | + assert len(responses) == 1 |
| 364 | + |
| 365 | + |
| 366 | +async def test_history_processors_streaming_with_replace_history(function_model: FunctionModel): |
| 367 | + """Test replace_history=True works with streaming runs.""" |
| 368 | + |
| 369 | + def summarize_history(messages: list[ModelMessage]) -> list[ModelMessage]: |
| 370 | + # Simple summarization - keep only the last message |
| 371 | + return messages[-1:] if messages else [] |
| 372 | + |
| 373 | + processors = HistoryProcessors(funcs=[summarize_history], replace_history=True) # type: ignore |
| 374 | + agent = Agent(function_model, history_processors=processors) # type: ignore |
| 375 | + |
| 376 | + original_history = [ |
| 377 | + ModelRequest(parts=[UserPromptPart(content='Question 1')]), |
| 378 | + ModelResponse(parts=[TextPart(content='Answer 1')]), |
| 379 | + ModelRequest(parts=[UserPromptPart(content='Question 2')]), |
| 380 | + ModelResponse(parts=[TextPart(content='Answer 2')]), |
| 381 | + ] |
| 382 | + |
| 383 | + async with agent.run_stream('Question 3', message_history=original_history.copy()) as result: |
| 384 | + async for _ in result.stream_text(): |
| 385 | + pass |
| 386 | + |
| 387 | + # Verify history was modified during streaming |
| 388 | + all_messages = result.all_messages() |
| 389 | + # Should only have: new request + new response = 2 total |
| 390 | + assert len(all_messages) == 2 |
| 391 | + |
| 392 | + |
| 393 | +async def test_history_processors_replace_history_false_default(function_model: FunctionModel): |
| 394 | + """Test HistoryProcessors with replace_history=False (default) preserves original history.""" |
| 395 | + |
| 396 | + def keep_only_requests(messages: list[ModelMessage]) -> list[ModelMessage]: |
| 397 | + return [msg for msg in messages if isinstance(msg, ModelRequest)] |
| 398 | + |
| 399 | + processors = HistoryProcessors(funcs=[keep_only_requests]) # replace_history=False by default # type: ignore |
| 400 | + agent = Agent(function_model, history_processors=processors) # type: ignore |
| 401 | + |
| 402 | + original_history = [ |
| 403 | + ModelRequest(parts=[UserPromptPart(content='Question 1')]), |
| 404 | + ModelResponse(parts=[TextPart(content='Answer 1')]), |
| 405 | + ] |
| 406 | + |
| 407 | + result = await agent.run('Question 2', message_history=original_history.copy()) |
| 408 | + |
| 409 | + # Verify original history is preserved |
| 410 | + all_messages = result.all_messages() |
| 411 | + assert len(all_messages) == 4 # 2 original + 1 new request + 1 new response |
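For reference, below is a minimal sketch of the HistoryProcessors container these tests exercise. It is an assumption reconstructed from the test code above (the funcs and replace_history names and the False default are inferred from the constructor calls and docstrings); the actual definition ships with the PR's library changes and may differ.

from dataclasses import dataclass, field
from typing import Callable, Sequence

from pydantic_ai.messages import ModelMessage

# Hypothetical sketch only: reconstructed from the tests, not the PR's real implementation.
HistoryProcessorFunc = Callable[[list[ModelMessage]], list[ModelMessage]]


@dataclass
class HistoryProcessors:
    # Callables applied in order to the message history before each model request.
    funcs: Sequence[HistoryProcessorFunc] = field(default_factory=list)
    # When True, the processed history also replaces the run's stored message history
    # (as asserted in test_history_processors_replace_history_true); assumed to default to False,
    # matching test_history_processors_replace_history_false_default.
    replace_history: bool = False

    def apply(self, messages: list[ModelMessage]) -> list[ModelMessage]:
        # Chain the processors: each one receives the output of the previous one.
        for func in self.funcs:
            messages = func(messages)
        return messages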