Commit b1adec6
fix: check TextPart.has_content() before considering it as final result
This fixes an issue with Ollama where empty TextPart responses before tool calls would prematurely stop the streaming process. Fixes #1292
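
For context, a rough sketch of the kind of run that hit #1292 (illustrative only: `get_weather` is a made-up tool, the model name is a stand-in, and the wiring for a local Ollama server is elided because it depends on the pydantic-ai version in use):

import asyncio

from pydantic_ai import Agent

# Stand-in model name; in the original report the model was served by a local
# Ollama instance, whose exact configuration is omitted here.
agent = Agent('openai:gpt-4o')

@agent.tool_plain
def get_weather(city: str) -> str:
    """A made-up tool that the model is expected to call."""
    return f'Sunny in {city}'

async def main() -> None:
    # Before this commit, a model that opened its response with an empty
    # TextPart and only then emitted a ToolCallPart caused run_stream() to
    # treat the empty text as the final result, so streaming stopped before
    # the tool call was handled.
    async with agent.run_stream('What is the weather in Paris?') as result:
        async for text in result.stream_text():
            print(text)

asyncio.run(main())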
Parent: 9b8cb71

1 file changed: +23 -16 lines

pydantic_ai_slim/pydantic_ai/agent.py (+23 -16)
@@ -245,7 +245,7 @@ async def run(
         result_type: None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDepsT = None,
+        deps: AgentDepsT | None = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -260,7 +260,7 @@ async def run(
         result_type: type[RunResultDataT],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDepsT = None,
+        deps: AgentDepsT | None = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -274,7 +274,7 @@ async def run(
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDepsT = None,
+        deps: AgentDepsT | None = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -338,7 +338,7 @@ async def iter(
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDepsT = None,
+        deps: AgentDepsT | None = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -499,7 +499,7 @@ def run_sync(
         *,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDepsT = None,
+        deps: AgentDepsT | None = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -514,7 +514,7 @@ def run_sync(
         result_type: type[RunResultDataT] | None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDepsT = None,
+        deps: AgentDepsT | None = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -528,7 +528,7 @@ def run_sync(
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDepsT = None,
+        deps: AgentDepsT | None = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -589,7 +589,7 @@ def run_stream(
         result_type: None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDepsT = None,
+        deps: AgentDepsT | None = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -604,7 +604,7 @@ def run_stream(
         result_type: type[RunResultDataT],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDepsT = None,
+        deps: AgentDepsT | None = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -619,7 +619,7 @@ async def run_stream( # noqa C901
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDepsT = None,
+        deps: AgentDepsT | None = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -685,15 +685,20 @@ async def stream_to_final(
                             s: models.StreamedResponse,
                         ) -> FinalResult[models.StreamedResponse] | None:
                             result_schema = graph_ctx.deps.result_schema
+                            parts_seen: list[_messages.ModelResponsePart] = []
+                            has_tool_call = False
                             async for maybe_part_event in streamed_response:
                                 if isinstance(maybe_part_event, _messages.PartStartEvent):
                                     new_part = maybe_part_event.part
-                                    if isinstance(new_part, _messages.TextPart):
-                                        if _agent_graph.allow_text_result(result_schema):
-                                            return FinalResult(s, None, None)
-                                    elif isinstance(new_part, _messages.ToolCallPart) and result_schema:
+                                    parts_seen.append(new_part)
+                                    if isinstance(new_part, _messages.ToolCallPart) and result_schema:
+                                        has_tool_call = True
                                         for call, _ in result_schema.find_tool([new_part]):
                                             return FinalResult(s, call.tool_name, call.tool_call_id)
+                            # Only check for final result after seeing all parts
+                            if not has_tool_call and len(parts_seen) == 1 and isinstance(parts_seen[0], _messages.TextPart):
+                                if _agent_graph.allow_text_result(result_schema):
+                                    return FinalResult(s, None, None)
                             return None

                         final_result_details = await stream_to_final(streamed_response)
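
Distilled outside the library, the new rule is roughly: collect every part of the streamed response first, and only a response consisting of a single text part, with no tool call, counts as a final text result. A standalone sketch with stand-in part types (not the real pydantic-ai classes, and ignoring the allow_text_result check):

from dataclasses import dataclass

@dataclass
class TextPart:
    content: str

@dataclass
class ToolCallPart:
    tool_name: str

def is_final_text(parts: list[object]) -> bool:
    # Decide only after the whole response has been seen.
    has_tool_call = any(isinstance(p, ToolCallPart) for p in parts)
    return not has_tool_call and len(parts) == 1 and isinstance(parts[0], TextPart)

# The Ollama-style response from #1292: an empty text part, then a tool call.
assert not is_final_text([TextPart(''), ToolCallPart('get_weather')])
# A response that is a single text part is still treated as final.
assert is_final_text([TextPart('The weather is sunny.')])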
@@ -1192,7 +1197,7 @@ def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model:

         return model_

-    def _get_deps(self: Agent[T, ResultDataT], deps: T) -> T:
+    def _get_deps(self: Agent[T, ResultDataT], deps: T | None) -> T:
         """Get deps for a run.

         If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.
@@ -1202,7 +1207,9 @@ def _get_deps(self: Agent[T, ResultDataT], deps: T) -> T:
         if some_deps := self._override_deps:
             return some_deps.value
         else:
-            return deps
+            if deps is None and self._deps_type is not NoneType:
+                raise ValueError("deps cannot be None when _override_deps is not set and deps_type is not NoneType")
+            return deps  # type: ignore

     def _infer_name(self, function_frame: FrameType | None) -> None:
         """Infer the agent name from the call frame.
