3
3
import base64
4
4
import functools
5
5
from abc import ABC , abstractmethod
6
- from collections .abc import AsyncIterator , Awaitable , Iterator , Sequence
7
- from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager , contextmanager
8
- from contextvars import ContextVar
6
+ from collections .abc import AsyncIterator , Awaitable , Sequence
7
+ from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager
9
8
from dataclasses import dataclass
10
9
from pathlib import Path
11
10
from types import TracebackType
@@ -61,22 +60,6 @@ class MCPServer(ABC):
61
60
_exit_stack : AsyncExitStack
62
61
sampling_model : models .Model | None = None
63
62
64
- def __post_init__ (self ):
65
- self ._override_sampling_model : ContextVar [models .Model | None ] = ContextVar (
66
- '_override_sampling_model' , default = None
67
- )
68
-
69
- @contextmanager
70
- def override_sampling_model (
71
- self ,
72
- model : models .Model ,
73
- ) -> Iterator [None ]:
74
- token = self ._override_sampling_model .set (model )
75
- try :
76
- yield
77
- finally :
78
- self ._override_sampling_model .reset (token )
79
-
80
63
@abstractmethod
81
64
@asynccontextmanager
82
65
async def client_streams (
@@ -201,8 +184,7 @@ async def _sampling_callback(
201
184
self , context : RequestContext [ClientSession , Any ], params : mcp_types .CreateMessageRequestParams
202
185
) -> mcp_types .CreateMessageResult | mcp_types .ErrorData :
203
186
"""MCP sampling callback."""
204
- sampling_model = self ._override_sampling_model .get () or self .sampling_model
205
- if sampling_model is None :
187
+ if self .sampling_model is None :
206
188
raise ValueError ('Sampling model is not set' ) # pragma: no cover
207
189
208
190
pai_messages = _mcp .map_from_mcp_params (params )
@@ -214,15 +196,15 @@ async def _sampling_callback(
214
196
if stop_sequences := params .stopSequences : # pragma: no branch
215
197
model_settings ['stop_sequences' ] = stop_sequences
216
198
217
- model_response = await sampling_model .request (
199
+ model_response = await self . sampling_model .request (
218
200
pai_messages ,
219
201
model_settings ,
220
202
models .ModelRequestParameters (),
221
203
)
222
204
return mcp_types .CreateMessageResult (
223
205
role = 'assistant' ,
224
206
content = _mcp .map_from_model_response (model_response ),
225
- model = sampling_model .model_name ,
207
+ model = self . sampling_model .model_name ,
226
208
)
227
209
228
210
def _map_tool_result_part (
0 commit comments