|
5 | 5 |
|
6 | 6 | package org.opensearch.ml.engine.algorithms.remote;
|
7 | 7 |
|
8 |
| -import static org.opensearch.ml.common.CommonValue.MCP_EXECUTOR_SERVICE; |
9 | 8 | import static org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT;
|
10 | 9 | import static org.opensearch.ml.common.CommonValue.MCP_TOOLS_FIELD;
|
11 | 10 | import static org.opensearch.ml.common.CommonValue.MCP_TOOL_DESCRIPTION_FIELD;
|
|
15 | 14 | import static org.opensearch.ml.common.connector.ConnectorProtocols.MCP_SSE;
|
16 | 15 |
|
17 | 16 | import java.net.http.HttpRequest;
|
18 |
| -import java.security.AccessController; |
19 |
| -import java.security.PrivilegedExceptionAction; |
20 | 17 | import java.time.Duration;
|
21 | 18 | import java.util.ArrayList;
|
22 | 19 | import java.util.Collections;
|
23 | 20 | import java.util.HashMap;
|
24 | 21 | import java.util.List;
|
25 | 22 | import java.util.Map;
|
26 |
| -import java.util.concurrent.ExecutorService; |
27 |
| -import java.util.concurrent.SynchronousQueue; |
28 |
| -import java.util.concurrent.ThreadPoolExecutor; |
29 |
| -import java.util.concurrent.TimeUnit; |
30 | 23 | import java.util.function.Consumer;
|
31 | 24 |
|
32 | 25 | import org.apache.logging.log4j.Logger;
|
@@ -85,71 +78,57 @@ public List<MLToolSpec> getMcpToolSpecs() {
|
85 | 78 | }
|
86 | 79 | List<MLToolSpec> mcpToolSpecs = new ArrayList<>();
|
87 | 80 | try {
|
88 |
| - AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> { |
89 |
| - |
90 |
| - // TODO: USE DEFAULT EXECUTOR AFTER JSM SHUTDOWN |
91 |
| - // Create a privileged executor service |
92 |
| - // TODO: Make these hardcoded numbers configurable |
93 |
| - ExecutorService executor = new ThreadPoolExecutor(2, 10, 60L, TimeUnit.SECONDS, new SynchronousQueue<>(), r -> { |
94 |
| - Thread thread = new Thread(r); |
95 |
| - thread.setDaemon(true); |
96 |
| - return thread; |
97 |
| - }); |
98 |
| - |
99 |
| - Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); |
100 |
| - Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); |
101 |
| - |
102 |
| - Consumer<HttpRequest.Builder> headerConfig = builder -> { |
103 |
| - builder.header("Content-Type", "application/json"); |
104 |
| - |
105 |
| - for (Map.Entry<String, String> entry : connector.getDecryptedHeaders().entrySet()) { |
106 |
| - builder.header(entry.getKey(), entry.getValue()); |
107 |
| - } |
108 |
| - }; |
109 |
| - |
110 |
| - // Create transport |
111 |
| - McpClientTransport transport = HttpClientSseClientTransport.builder(mcpServerUrl).customizeClient(clientBuilder -> { |
112 |
| - clientBuilder.executor(executor).connectTimeout(connectionTimeout); |
113 |
| - }).customizeRequest(headerConfig).build(); |
114 |
| - |
115 |
| - // Create and initialize client |
116 |
| - McpSyncClient client = McpClient |
117 |
| - .sync(transport) |
118 |
| - .requestTimeout(readTimeout) |
119 |
| - .capabilities(McpSchema.ClientCapabilities.builder().roots(false).build()) |
120 |
| - .build(); |
| 81 | + Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); |
| 82 | + Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); |
| 83 | + |
| 84 | + Consumer<HttpRequest.Builder> headerConfig = builder -> { |
| 85 | + builder.header("Content-Type", "application/json"); |
121 | 86 |
|
122 |
| - client.initialize(); |
123 |
| - McpSchema.ListToolsResult tools = client.listTools(); |
124 |
| - |
125 |
| - // Process the results |
126 |
| - Gson gson = new Gson(); |
127 |
| - String json = gson.toJson(tools, McpSchema.ListToolsResult.class); |
128 |
| - Map<String, Object> map = gson.fromJson(json, Map.class); |
129 |
| - |
130 |
| - List<Object> mcpTools = (List<Object>) map.get(MCP_TOOLS_FIELD); |
131 |
| - |
132 |
| - for (Object tool : mcpTools) { |
133 |
| - Map<String, Object> toolMap = (Map<String, Object>) tool; |
134 |
| - Map<String, String> attributes = new HashMap<>(); |
135 |
| - attributes.put(TOOL_INPUT_SCHEMA_FIELD, StringUtils.toJson(toolMap.get(MCP_TOOL_INPUT_SCHEMA_FIELD))); |
136 |
| - |
137 |
| - String description = (toolMap.containsKey(MCP_TOOL_DESCRIPTION_FIELD)) |
138 |
| - ? StringUtils.processTextDoc(toolMap.get(MCP_TOOL_DESCRIPTION_FIELD).toString()) |
139 |
| - : McpSseTool.DEFAULT_DESCRIPTION; |
140 |
| - MLToolSpec mlToolSpec = MLToolSpec |
141 |
| - .builder() |
142 |
| - .type(McpSseTool.TYPE) |
143 |
| - .name(toolMap.get(MCP_TOOL_NAME_FIELD).toString()) |
144 |
| - .description(description) |
145 |
| - .attributes(attributes) |
146 |
| - .build(); |
147 |
| - mlToolSpec.addRuntimeResource(MCP_SYNC_CLIENT, client); |
148 |
| - mlToolSpec.addRuntimeResource(MCP_EXECUTOR_SERVICE, executor); |
149 |
| - mcpToolSpecs.add(mlToolSpec); |
| 87 | + for (Map.Entry<String, String> entry : connector.getDecryptedHeaders().entrySet()) { |
| 88 | + builder.header(entry.getKey(), entry.getValue()); |
150 | 89 | }
|
151 |
| - return null; |
152 |
| - }); |
| 90 | + }; |
| 91 | + |
| 92 | + // Create transport |
| 93 | + McpClientTransport transport = HttpClientSseClientTransport.builder(mcpServerUrl).customizeClient(clientBuilder -> { |
| 94 | + clientBuilder.connectTimeout(connectionTimeout); |
| 95 | + }).customizeRequest(headerConfig).build(); |
| 96 | + |
| 97 | + // Create and initialize client |
| 98 | + McpSyncClient client = McpClient |
| 99 | + .sync(transport) |
| 100 | + .requestTimeout(readTimeout) |
| 101 | + .capabilities(McpSchema.ClientCapabilities.builder().roots(false).build()) |
| 102 | + .build(); |
| 103 | + |
| 104 | + client.initialize(); |
| 105 | + McpSchema.ListToolsResult tools = client.listTools(); |
| 106 | + |
| 107 | + // Process the results |
| 108 | + Gson gson = new Gson(); |
| 109 | + String json = gson.toJson(tools, McpSchema.ListToolsResult.class); |
| 110 | + Map<String, Object> map = gson.fromJson(json, Map.class); |
| 111 | + |
| 112 | + List<Object> mcpTools = (List<Object>) map.get(MCP_TOOLS_FIELD); |
| 113 | + |
| 114 | + for (Object tool : mcpTools) { |
| 115 | + Map<String, Object> toolMap = (Map<String, Object>) tool; |
| 116 | + Map<String, String> attributes = new HashMap<>(); |
| 117 | + attributes.put(TOOL_INPUT_SCHEMA_FIELD, StringUtils.toJson(toolMap.get(MCP_TOOL_INPUT_SCHEMA_FIELD))); |
| 118 | + |
| 119 | + String description = (toolMap.containsKey(MCP_TOOL_DESCRIPTION_FIELD)) |
| 120 | + ? StringUtils.processTextDoc(toolMap.get(MCP_TOOL_DESCRIPTION_FIELD).toString()) |
| 121 | + : McpSseTool.DEFAULT_DESCRIPTION; |
| 122 | + MLToolSpec mlToolSpec = MLToolSpec |
| 123 | + .builder() |
| 124 | + .type(McpSseTool.TYPE) |
| 125 | + .name(toolMap.get(MCP_TOOL_NAME_FIELD).toString()) |
| 126 | + .description(description) |
| 127 | + .attributes(attributes) |
| 128 | + .build(); |
| 129 | + mlToolSpec.addRuntimeResource(MCP_SYNC_CLIENT, client); |
| 130 | + mcpToolSpecs.add(mlToolSpec); |
| 131 | + } |
153 | 132 |
|
154 | 133 | return mcpToolSpecs;
|
155 | 134 | } catch (Exception e) {
|
|
0 commit comments