Skip to content

Commit d10f036

Browse files
authored
Revert "Use contextvars for tracking the MCP sampling model" (#2132)
1 parent 09e0821 commit d10f036

File tree

3 files changed

+6
-26
lines changed

3 files changed

+6
-26
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1740,7 +1740,7 @@ async def run_mcp_servers(
17401740
try:
17411741
for mcp_server in self._mcp_servers:
17421742
if sampling_model is not None: # pragma: no branch
1743-
exit_stack.enter_context(mcp_server.override_sampling_model(sampling_model))
1743+
mcp_server.sampling_model = sampling_model
17441744
await exit_stack.enter_async_context(mcp_server)
17451745
yield
17461746
finally:

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
import base64
44
import functools
55
from abc import ABC, abstractmethod
6-
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
7-
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
8-
from contextvars import ContextVar
6+
from collections.abc import AsyncIterator, Awaitable, Sequence
7+
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
98
from dataclasses import dataclass
109
from pathlib import Path
1110
from types import TracebackType
@@ -61,22 +60,6 @@ class MCPServer(ABC):
6160
_exit_stack: AsyncExitStack
6261
sampling_model: models.Model | None = None
6362

64-
def __post_init__(self):
65-
self._override_sampling_model: ContextVar[models.Model | None] = ContextVar(
66-
'_override_sampling_model', default=None
67-
)
68-
69-
@contextmanager
70-
def override_sampling_model(
71-
self,
72-
model: models.Model,
73-
) -> Iterator[None]:
74-
token = self._override_sampling_model.set(model)
75-
try:
76-
yield
77-
finally:
78-
self._override_sampling_model.reset(token)
79-
8063
@abstractmethod
8164
@asynccontextmanager
8265
async def client_streams(
@@ -201,8 +184,7 @@ async def _sampling_callback(
201184
self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams
202185
) -> mcp_types.CreateMessageResult | mcp_types.ErrorData:
203186
"""MCP sampling callback."""
204-
sampling_model = self._override_sampling_model.get() or self.sampling_model
205-
if sampling_model is None:
187+
if self.sampling_model is None:
206188
raise ValueError('Sampling model is not set') # pragma: no cover
207189

208190
pai_messages = _mcp.map_from_mcp_params(params)
@@ -214,15 +196,15 @@ async def _sampling_callback(
214196
if stop_sequences := params.stopSequences: # pragma: no branch
215197
model_settings['stop_sequences'] = stop_sequences
216198

217-
model_response = await sampling_model.request(
199+
model_response = await self.sampling_model.request(
218200
pai_messages,
219201
model_settings,
220202
models.ModelRequestParameters(),
221203
)
222204
return mcp_types.CreateMessageResult(
223205
role='assistant',
224206
content=_mcp.map_from_model_response(model_response),
225-
model=sampling_model.model_name,
207+
model=self.sampling_model.model_name,
226208
)
227209

228210
def _map_tool_result_part(

tests/test_examples.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import shutil
77
import sys
88
from collections.abc import AsyncIterator, Iterable, Sequence
9-
from contextlib import nullcontext
109
from dataclasses import dataclass
1110
from inspect import FrameInfo
1211
from io import StringIO
@@ -259,7 +258,6 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str:
259258

260259
class MockMCPServer:
261260
is_running = True
262-
override_sampling_model = nullcontext
263261

264262
async def __aenter__(self) -> MockMCPServer:
265263
return self

0 commit comments

Comments
 (0)