@@ -3,20 +3,21 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js'
3
3
import { StdioClientTransport , StdioServerParameters } from '@modelcontextprotocol/sdk/client/stdio.js'
4
4
import { BaseToolkit , tool , Tool } from '@langchain/core/tools'
5
5
import { z } from 'zod'
6
+ import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
7
+ import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
6
8
7
9
export class MCPToolkit extends BaseToolkit {
8
10
tools : Tool [ ] = [ ]
9
11
_tools : ListToolsResult | null = null
10
12
model_config : any
11
- transport : StdioClientTransport | null = null
13
+ transport : StdioClientTransport | SSEClientTransport | StreamableHTTPClientTransport | null = null
12
14
client : Client | null = null
13
- constructor ( serverParams : StdioServerParameters | any , transport : 'stdio' | 'sse' ) {
15
+ serverParams : StdioServerParameters | any
16
+ transportType : 'stdio' | 'sse'
17
+ constructor ( serverParams : StdioServerParameters | any , transportType : 'stdio' | 'sse' ) {
14
18
super ( )
15
- if ( transport === 'stdio' ) {
16
- this . transport = new StdioClientTransport ( serverParams as StdioServerParameters )
17
- } else {
18
- // TODO: this.transport = new SSEClientTransport(serverParams.url);
19
- }
19
+ this . serverParams = serverParams
20
+ this . transportType = transportType
20
21
}
21
22
async initialize ( ) {
22
23
if ( this . _tools === null ) {
@@ -29,10 +30,30 @@ export class MCPToolkit extends BaseToolkit {
29
30
capabilities : { }
30
31
}
31
32
)
32
- if ( this . transport === null ) {
33
- throw new Error ( 'Transport is not initialized' )
33
+ if ( this . transportType === 'stdio' ) {
34
+ // Compatible with overridden PATH configuration
35
+ this . serverParams . env = {
36
+ ...( this . serverParams . env || { } ) ,
37
+ PATH : process . env . PATH
38
+ }
39
+
40
+ this . transport = new StdioClientTransport ( this . serverParams as StdioServerParameters )
41
+ await this . client . connect ( this . transport )
42
+ } else {
43
+ if ( this . serverParams . url === undefined ) {
44
+ throw new Error ( 'URL is required for SSE transport' )
45
+ }
46
+
47
+ const baseUrl = new URL ( this . serverParams . url )
48
+ try {
49
+ this . transport = new StreamableHTTPClientTransport ( baseUrl )
50
+ await this . client . connect ( this . transport )
51
+ } catch ( error ) {
52
+ this . transport = new SSEClientTransport ( baseUrl )
53
+ await this . client . connect ( this . transport )
54
+ }
34
55
}
35
- await this . client . connect ( this . transport )
56
+
36
57
this . _tools = await this . client . request ( { method : 'tools/list' } , ListToolsResultSchema )
37
58
38
59
this . tools = await this . get_tools ( )
0 commit comments