Skip to content

Commit 4f25359

Browse files
committed
fix: check TextPart.has_content() before considering it as the final result
This fixes an issue with Ollama where an empty TextPart emitted before tool calls was treated as the final text result, prematurely ending the streaming run before the tool-call parts were seen. The stream handler now requires a TextPart to have content (via `has_content()`) before returning it as a `FinalResult`. Fixes #1292
1 parent 9b8cb71 commit 4f25359

File tree

1 file changed: +33 additions, −14 deletions

pydantic_ai_slim/pydantic_ai/agent.py

+33-14
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ async def run(
245245
result_type: None = None,
246246
message_history: list[_messages.ModelMessage] | None = None,
247247
model: models.Model | models.KnownModelName | None = None,
248-
deps: AgentDepsT = None,
248+
deps: AgentDepsT | None = None,
249249
model_settings: ModelSettings | None = None,
250250
usage_limits: _usage.UsageLimits | None = None,
251251
usage: _usage.Usage | None = None,
@@ -260,7 +260,7 @@ async def run(
260260
result_type: type[RunResultDataT],
261261
message_history: list[_messages.ModelMessage] | None = None,
262262
model: models.Model | models.KnownModelName | None = None,
263-
deps: AgentDepsT = None,
263+
deps: AgentDepsT | None = None,
264264
model_settings: ModelSettings | None = None,
265265
usage_limits: _usage.UsageLimits | None = None,
266266
usage: _usage.Usage | None = None,
@@ -274,7 +274,7 @@ async def run(
274274
result_type: type[RunResultDataT] | None = None,
275275
message_history: list[_messages.ModelMessage] | None = None,
276276
model: models.Model | models.KnownModelName | None = None,
277-
deps: AgentDepsT = None,
277+
deps: AgentDepsT | None = None,
278278
model_settings: ModelSettings | None = None,
279279
usage_limits: _usage.UsageLimits | None = None,
280280
usage: _usage.Usage | None = None,
@@ -338,7 +338,7 @@ async def iter(
338338
result_type: type[RunResultDataT] | None = None,
339339
message_history: list[_messages.ModelMessage] | None = None,
340340
model: models.Model | models.KnownModelName | None = None,
341-
deps: AgentDepsT = None,
341+
deps: AgentDepsT | None = None,
342342
model_settings: ModelSettings | None = None,
343343
usage_limits: _usage.UsageLimits | None = None,
344344
usage: _usage.Usage | None = None,
@@ -499,7 +499,7 @@ def run_sync(
499499
*,
500500
message_history: list[_messages.ModelMessage] | None = None,
501501
model: models.Model | models.KnownModelName | None = None,
502-
deps: AgentDepsT = None,
502+
deps: AgentDepsT | None = None,
503503
model_settings: ModelSettings | None = None,
504504
usage_limits: _usage.UsageLimits | None = None,
505505
usage: _usage.Usage | None = None,
@@ -514,7 +514,7 @@ def run_sync(
514514
result_type: type[RunResultDataT] | None,
515515
message_history: list[_messages.ModelMessage] | None = None,
516516
model: models.Model | models.KnownModelName | None = None,
517-
deps: AgentDepsT = None,
517+
deps: AgentDepsT | None = None,
518518
model_settings: ModelSettings | None = None,
519519
usage_limits: _usage.UsageLimits | None = None,
520520
usage: _usage.Usage | None = None,
@@ -528,7 +528,7 @@ def run_sync(
528528
result_type: type[RunResultDataT] | None = None,
529529
message_history: list[_messages.ModelMessage] | None = None,
530530
model: models.Model | models.KnownModelName | None = None,
531-
deps: AgentDepsT = None,
531+
deps: AgentDepsT | None = None,
532532
model_settings: ModelSettings | None = None,
533533
usage_limits: _usage.UsageLimits | None = None,
534534
usage: _usage.Usage | None = None,
@@ -589,7 +589,7 @@ def run_stream(
589589
result_type: None = None,
590590
message_history: list[_messages.ModelMessage] | None = None,
591591
model: models.Model | models.KnownModelName | None = None,
592-
deps: AgentDepsT = None,
592+
deps: AgentDepsT | None = None,
593593
model_settings: ModelSettings | None = None,
594594
usage_limits: _usage.UsageLimits | None = None,
595595
usage: _usage.Usage | None = None,
@@ -604,7 +604,7 @@ def run_stream(
604604
result_type: type[RunResultDataT],
605605
message_history: list[_messages.ModelMessage] | None = None,
606606
model: models.Model | models.KnownModelName | None = None,
607-
deps: AgentDepsT = None,
607+
deps: AgentDepsT | None = None,
608608
model_settings: ModelSettings | None = None,
609609
usage_limits: _usage.UsageLimits | None = None,
610610
usage: _usage.Usage | None = None,
@@ -619,7 +619,7 @@ async def run_stream( # noqa C901
619619
result_type: type[RunResultDataT] | None = None,
620620
message_history: list[_messages.ModelMessage] | None = None,
621621
model: models.Model | models.KnownModelName | None = None,
622-
deps: AgentDepsT = None,
622+
deps: AgentDepsT | None = None,
623623
model_settings: ModelSettings | None = None,
624624
usage_limits: _usage.UsageLimits | None = None,
625625
usage: _usage.Usage | None = None,
@@ -685,12 +685,29 @@ async def stream_to_final(
685685
s: models.StreamedResponse,
686686
) -> FinalResult[models.StreamedResponse] | None:
687687
result_schema = graph_ctx.deps.result_schema
688+
parts_seen: list[_messages.ModelResponsePart] = []
688689
async for maybe_part_event in streamed_response:
689690
if isinstance(maybe_part_event, _messages.PartStartEvent):
690691
new_part = maybe_part_event.part
692+
parts_seen.append(new_part)
691693
if isinstance(new_part, _messages.TextPart):
692-
if _agent_graph.allow_text_result(result_schema):
693-
return FinalResult(s, None, None)
694+
# Only treat empty text as final if:
695+
# 1. It's the only part we've seen AND
696+
# 2. We've consumed all events (no more parts coming)
697+
# Otherwise, require non-empty content
698+
if new_part.has_content():
699+
if _agent_graph.allow_text_result(result_schema):
700+
return FinalResult(s, None, None)
701+
elif len(parts_seen) == 1:
702+
# For empty text, peek ahead to see if there are more parts
703+
more_parts = False
704+
async for next_event in streamed_response:
705+
if isinstance(next_event, _messages.PartStartEvent):
706+
parts_seen.append(next_event.part)
707+
more_parts = True
708+
break
709+
if not more_parts and _agent_graph.allow_text_result(result_schema):
710+
return FinalResult(s, None, None)
694711
elif isinstance(new_part, _messages.ToolCallPart) and result_schema:
695712
for call, _ in result_schema.find_tool([new_part]):
696713
return FinalResult(s, call.tool_name, call.tool_call_id)
@@ -1192,7 +1209,7 @@ def _get_model(self, model: models.Model | models.KnownModelName | None) -> mode
11921209

11931210
return model_
11941211

1195-
def _get_deps(self: Agent[T, ResultDataT], deps: T) -> T:
1212+
def _get_deps(self: Agent[T, ResultDataT], deps: T | None) -> T:
11961213
"""Get deps for a run.
11971214
11981215
If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.
@@ -1202,7 +1219,9 @@ def _get_deps(self: Agent[T, ResultDataT], deps: T) -> T:
12021219
if some_deps := self._override_deps:
12031220
return some_deps.value
12041221
else:
1205-
return deps
1222+
if deps is None and self._deps_type is not NoneType:
1223+
raise ValueError("deps cannot be None when _override_deps is not set and deps_type is not NoneType")
1224+
return deps # type: ignore
12061225

12071226
def _infer_name(self, function_frame: FrameType | None) -> None:
12081227
"""Infer the agent name from the call frame.

Comments (0)