Skip to content

Commit 7c05295

Browse files
authored
support MCP session management (#3803)
* support MCP session management Signed-off-by: zane-neo <[email protected]> * Addressing comments Signed-off-by: zane-neo <[email protected]> * add feature flag for mcp server and renaming mcp connector feature flag Signed-off-by: zane-neo <[email protected]> * Address critical comments in #3781 Signed-off-by: zane-neo <[email protected]> --------- Signed-off-by: zane-neo <[email protected]>
1 parent c51c5e4 commit 7c05295

File tree

34 files changed

+865
-227
lines changed

34 files changed

+865
-227
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import java.util.Set;
99

1010
import org.opensearch.Version;
11-
import org.opensearch.common.settings.Setting;
1211

1312
import com.google.common.collect.ImmutableSet;
1413

@@ -46,6 +45,7 @@ public class CommonValue {
4645
public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
4746
public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words";
4847
public static final String TASK_POLLING_JOB_INDEX = ".ml_commons_task_polling_job";
48+
public static final String MCP_SESSION_MANAGEMENT_INDEX = ".plugins-ml-mcp-session-management";
4949
public static final Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");
5050
public static final String TOOL_PARAMETERS_PREFIX = "tools.parameters.";
5151

@@ -59,6 +59,7 @@ public class CommonValue {
5959
public static final String ML_AGENT_INDEX_MAPPING_PATH = "index-mappings/ml_agent.json";
6060
public static final String ML_MEMORY_META_INDEX_MAPPING_PATH = "index-mappings/ml_memory_meta.json";
6161
public static final String ML_MEMORY_MESSAGE_INDEX_MAPPING_PATH = "index-mappings/ml_memory_message.json";
62+
public static final String ML_MCP_SESSION_MANAGEMENT_INDEX_MAPPING_PATH = "index-mappings/ml_mcp_session_management.json";
6263

6364
// Calculate Versions independently of OpenSearch core version
6465
public static final Version VERSION_2_11_0 = Version.fromString("2.11.0");
@@ -97,11 +98,6 @@ public class CommonValue {
9798
public static final String MCP_CONNECTORS_FIELD = "mcp_connectors";
9899
public static final String MCP_CONNECTOR_ID_FIELD = "mcp_connector_id";
99100

100-
public static final Setting<Boolean> ML_COMMONS_MCP_FEATURE_ENABLED = Setting
101-
.boolSetting("plugins.ml_commons.mcp_feature_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
102-
public static final String ML_COMMONS_MCP_FEATURE_DISABLED_MESSAGE =
103-
"The MCP feature is not enabled. To enable, please update the setting " + ML_COMMONS_MCP_FEATURE_ENABLED.getKey();
104-
105101
// TOOL Constants
106102
public static final String TOOL_INPUT_SCHEMA_FIELD = "input_schema";
107103
}

common/src/main/java/org/opensearch/ml/common/MLIndex.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX_MAPPING_PATH;
1414
import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX;
1515
import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX_MAPPING_PATH;
16+
import static org.opensearch.ml.common.CommonValue.ML_MCP_SESSION_MANAGEMENT_INDEX_MAPPING_PATH;
1617
import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX;
1718
import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX_MAPPING_PATH;
1819
import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX;
@@ -38,7 +39,8 @@ public enum MLIndex {
3839
CONTROLLER(ML_CONTROLLER_INDEX, false, ML_CONTROLLER_INDEX_MAPPING_PATH),
3940
AGENT(ML_AGENT_INDEX, false, ML_AGENT_INDEX_MAPPING_PATH),
4041
MEMORY_META(ML_MEMORY_META_INDEX, false, ML_MEMORY_META_INDEX_MAPPING_PATH),
41-
MEMORY_MESSAGE(ML_MEMORY_MESSAGE_INDEX, false, ML_MEMORY_MESSAGE_INDEX_MAPPING_PATH);
42+
MEMORY_MESSAGE(ML_MEMORY_MESSAGE_INDEX, false, ML_MEMORY_MESSAGE_INDEX_MAPPING_PATH),
43+
MCP_SESSION_MANAGEMENT(ML_MCP_SESSION_MANAGEMENT_INDEX_MAPPING_PATH, false, ML_MCP_SESSION_MANAGEMENT_INDEX_MAPPING_PATH);
4244

4345
private final String indexName;
4446
// whether we use an alias for the index

common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import org.opensearch.common.settings.Setting;
1717
import org.opensearch.core.common.unit.ByteSizeUnit;
1818
import org.opensearch.core.common.unit.ByteSizeValue;
19-
import org.opensearch.ml.common.CommonValue;
2019

2120
import com.google.common.collect.ImmutableList;
2221

@@ -217,7 +216,15 @@ private MLCommonsSettings() {}
217216
public static final Setting<Boolean> ML_COMMONS_MEMORY_FEATURE_ENABLED = Setting
218217
.boolSetting("plugins.ml_commons.memory_feature_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);
219218

220-
public static final Setting<Boolean> ML_COMMONS_MCP_FEATURE_ENABLED = CommonValue.ML_COMMONS_MCP_FEATURE_ENABLED;
219+
public static final Setting<Boolean> ML_COMMONS_MCP_CONNECTOR_ENABLED = Setting
220+
.boolSetting("plugins.ml_commons.mcp_connector_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
221+
public static final String ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE =
222+
"The MCP connector is not enabled. To enable, please update the setting " + ML_COMMONS_MCP_CONNECTOR_ENABLED.getKey();
223+
224+
public static final Setting<Boolean> ML_COMMONS_MCP_SERVER_ENABLED = Setting
225+
.boolSetting("plugins.ml_commons.mcp_server_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
226+
public static final String ML_COMMONS_MCP_SERVER_DISABLED_MESSAGE =
227+
"The MCP server is not enabled. To enable, please update the setting " + ML_COMMONS_MCP_SERVER_ENABLED.getKey();
221228

222229
// Feature flag for enabling search processors for Retrieval Augmented Generation using OpenSearch and Remote Inference.
223230
public static final Setting<Boolean> ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED = Setting
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.mcpserver.action;
7+
8+
import org.opensearch.action.ActionType;
9+
import org.opensearch.action.support.clustermanager.AcknowledgedResponse;
10+
11+
public class MLMcpMessageAction extends ActionType<AcknowledgedResponse> {
12+
public static MLMcpMessageAction INSTANCE = new MLMcpMessageAction();
13+
public static final String NAME = "cluster:admin/opensearch/ml/mcp/message";
14+
15+
private MLMcpMessageAction() {
16+
super(NAME, AcknowledgedResponse::new);
17+
}
18+
19+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.mcpserver.action;
7+
8+
import org.opensearch.action.ActionType;
9+
import org.opensearch.action.support.clustermanager.AcknowledgedResponse;
10+
11+
public class MLMcpMessageDispatchAction extends ActionType<AcknowledgedResponse> {
12+
public static MLMcpMessageDispatchAction INSTANCE = new MLMcpMessageDispatchAction();
13+
public static final String NAME = "cluster:admin/opensearch/ml/mcp/message/dispatch";
14+
15+
private MLMcpMessageDispatchAction() {
16+
super(NAME, AcknowledgedResponse::new);
17+
}
18+
19+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.mcpserver.requests.message;
7+
8+
import java.io.ByteArrayInputStream;
9+
import java.io.ByteArrayOutputStream;
10+
import java.io.IOException;
11+
import java.io.UncheckedIOException;
12+
13+
import org.opensearch.action.ActionRequest;
14+
import org.opensearch.action.ActionRequestValidationException;
15+
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
16+
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
17+
import org.opensearch.core.common.io.stream.StreamInput;
18+
import org.opensearch.core.common.io.stream.StreamOutput;
19+
import org.opensearch.transport.TransportRequest;
20+
21+
import lombok.Builder;
22+
import lombok.Getter;
23+
24+
@Getter
25+
public class MLMcpMessageRequest extends ActionRequest {
26+
27+
private final String nodeId;
28+
29+
private final String sessionId;
30+
31+
private final String requestBody;
32+
33+
public MLMcpMessageRequest(StreamInput in) throws IOException {
34+
super(in);
35+
this.nodeId = in.readString();
36+
this.sessionId = in.readString();
37+
this.requestBody = in.readString();
38+
}
39+
40+
@Builder
41+
public MLMcpMessageRequest(String nodeId, String sessionId, String requestBody) {
42+
super();
43+
this.nodeId = nodeId;
44+
this.sessionId = sessionId;
45+
this.requestBody = requestBody;
46+
}
47+
48+
@Override
49+
public void writeTo(StreamOutput out) throws IOException {
50+
super.writeTo(out);
51+
out.writeString(nodeId);
52+
out.writeString(sessionId);
53+
out.writeString(requestBody);
54+
}
55+
56+
public static MLMcpMessageRequest fromActionRequest(TransportRequest actionRequest) {
57+
if (actionRequest instanceof MLMcpMessageRequest) {
58+
return (MLMcpMessageRequest) actionRequest;
59+
}
60+
61+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
62+
actionRequest.writeTo(osso);
63+
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
64+
return new MLMcpMessageRequest(input);
65+
}
66+
} catch (IOException e) {
67+
throw new UncheckedIOException("Failed to parse ActionRequest into MLMcpMessageRequest", e);
68+
}
69+
}
70+
71+
@Override
72+
public ActionRequestValidationException validate() {
73+
return null;
74+
}
75+
}

common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/register/McpTool.java

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -30,42 +30,43 @@
3030
@Log4j2
3131
@Data
3232
public class McpTool implements ToXContentObject, Writeable {
33-
private static final String NAME_FIELD = "name";
33+
private static final String TYPE_FIELD = "type";
3434
private static final String DESCRIPTION_FIELD = "description";
35-
private static final String PARAMS_FIELD = "params";
36-
private static final String SCHEMA_FIELD = "schema";
37-
private final String name;
35+
private static final String PARAMS_FIELD = "parameters";
36+
private static final String ATTRIBUTES_FIELD = "attributes";
37+
public static final String SCHEMA_FIELD = "input_schema";
38+
private final String type;
3839
private final String description;
39-
private Map<String, Object> params;
40-
private Map<String, Object> schema;
41-
private static final String nameNotShownExceptionMessage = "name field required";
40+
private Map<String, Object> parameters;
41+
private Map<String, Object> attributes;
42+
private static final String TYPE_NOT_SHOWN_EXCEPTION_MESSAGE = "type field required";
4243

4344
public McpTool(StreamInput streamInput) throws IOException {
44-
name = streamInput.readString();
45-
if (name == null) {
46-
throw new IllegalArgumentException(nameNotShownExceptionMessage);
45+
type = streamInput.readString();
46+
if (type == null) {
47+
throw new IllegalArgumentException(TYPE_NOT_SHOWN_EXCEPTION_MESSAGE);
4748
}
4849
description = streamInput.readOptionalString();
4950
if (streamInput.readBoolean()) {
50-
params = streamInput.readMap(StreamInput::readString, StreamInput::readGenericValue);
51+
parameters = streamInput.readMap(StreamInput::readString, StreamInput::readGenericValue);
5152
}
5253
if (streamInput.readBoolean()) {
53-
schema = streamInput.readMap(StreamInput::readString, StreamInput::readGenericValue);
54+
attributes = streamInput.readMap(StreamInput::readString, StreamInput::readGenericValue);
5455
}
5556
}
5657

57-
public McpTool(String name, String description, Map<String, Object> params, Map<String, Object> schema) {
58-
if (name == null) {
59-
throw new IllegalArgumentException(nameNotShownExceptionMessage);
58+
public McpTool(String type, String description, Map<String, Object> parameters, Map<String, Object> attributes) {
59+
if (type == null) {
60+
throw new IllegalArgumentException(TYPE_NOT_SHOWN_EXCEPTION_MESSAGE);
6061
}
61-
this.name = name;
62+
this.type = type;
6263
this.description = description;
63-
this.params = params;
64-
this.schema = schema;
64+
this.parameters = parameters;
65+
this.attributes = attributes;
6566
}
6667

6768
public static McpTool parse(XContentParser parser) throws IOException {
68-
String name = null;
69+
String type = null;
6970
String description = null;
7071
Map<String, Object> params = null;
7172
Map<String, Object> schema = null;
@@ -75,8 +76,8 @@ public static McpTool parse(XContentParser parser) throws IOException {
7576
parser.nextToken();
7677

7778
switch (fieldName) {
78-
case NAME_FIELD:
79-
name = parser.text();
79+
case TYPE_FIELD:
80+
type = parser.text();
8081
break;
8182
case DESCRIPTION_FIELD:
8283
description = parser.text();
@@ -92,26 +93,26 @@ public static McpTool parse(XContentParser parser) throws IOException {
9293
break;
9394
}
9495
}
95-
if (name == null) {
96-
throw new IllegalArgumentException(nameNotShownExceptionMessage);
96+
if (type == null) {
97+
throw new IllegalArgumentException(TYPE_NOT_SHOWN_EXCEPTION_MESSAGE);
9798
}
98-
return new McpTool(name, description, params, schema);
99+
return new McpTool(type, description, params, schema);
99100
}
100101

101102
@Override
102103
public void writeTo(StreamOutput streamOutput) throws IOException {
103-
streamOutput.writeString(name);
104+
streamOutput.writeString(type);
104105
streamOutput.writeOptionalString(description);
105-
if (params != null) {
106+
if (parameters != null) {
106107
streamOutput.writeBoolean(true);
107-
streamOutput.writeMap(params, StreamOutput::writeString, StreamOutput::writeGenericValue);
108+
streamOutput.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeGenericValue);
108109
} else {
109110
streamOutput.writeBoolean(false);
110111
}
111112

112-
if (schema != null) {
113+
if (attributes != null) {
113114
streamOutput.writeBoolean(true);
114-
streamOutput.writeMap(schema, StreamOutput::writeString, StreamOutput::writeGenericValue);
115+
streamOutput.writeMap(attributes, StreamOutput::writeString, StreamOutput::writeGenericValue);
115116
} else {
116117
streamOutput.writeBoolean(false);
117118
}
@@ -120,15 +121,15 @@ public void writeTo(StreamOutput streamOutput) throws IOException {
120121
@Override
121122
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params xcontentParams) throws IOException {
122123
builder.startObject();
123-
builder.field(NAME_FIELD, name);
124+
builder.field(TYPE_FIELD, type);
124125
if (description != null) {
125126
builder.field(DESCRIPTION_FIELD, description);
126127
}
127-
if (params != null && !params.isEmpty()) {
128-
builder.field(PARAMS_FIELD, params);
128+
if (parameters != null && !parameters.isEmpty()) {
129+
builder.field(PARAMS_FIELD, parameters);
129130
}
130-
if (schema != null && !schema.isEmpty()) {
131-
builder.field(SCHEMA_FIELD, schema);
131+
if (attributes != null && !attributes.isEmpty()) {
132+
builder.field(SCHEMA_FIELD, attributes);
132133
}
133134
builder.endObject();
134135
return builder;
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
{
2+
"_meta": {
3+
"schema_version": 1
4+
},
5+
"properties": {
6+
"node_id": {
7+
"type": "keyword"
8+
},
9+
"status": {
10+
"type": "keyword"
11+
},
12+
"create_time": {
13+
"type": "date",
14+
"format": "strict_date_time||epoch_millis"
15+
},
16+
"last_updated_time": {
17+
"type": "date",
18+
"format": "strict_date_time||epoch_millis"
19+
}
20+
}
21+
}

common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/register/MLMcpToolsRegisterNodeRequestTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public void testStreamSerialization() throws IOException {
5959

6060
assertEquals(1, deserializedRequest.getMcpTools().getTools().size());
6161
assertEquals(timestamp, deserializedRequest.getMcpTools().getCreatedTime());
62-
assertEquals("test_tool", deserializedRequest.getMcpTools().getTools().get(0).getName());
62+
assertEquals("test_tool", deserializedRequest.getMcpTools().getTools().get(0).getType());
6363
}
6464

6565
@Test
@@ -82,7 +82,7 @@ public void writeTo(StreamOutput out) throws IOException {
8282
MLMcpToolsRegisterNodeRequest result = MLMcpToolsRegisterNodeRequest.fromActionRequest(transportRequest);
8383

8484
assertNotNull("Converted request should not be null", result);
85-
assertEquals("test_tool", result.getMcpTools().getTools().get(0).getName());
85+
assertEquals("test_tool", result.getMcpTools().getTools().get(0).getType());
8686
assertEquals(timestamp, result.getMcpTools().getCreatedTime());
8787
}
8888

common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/register/MLMcpToolsRegisterNodesRequestTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public void testConstructorWithNodeIds() {
4343

4444
assertArrayEquals(nodeIds, request.nodesIds());
4545
assertEquals(1, request.getMcpTools().getTools().size());
46-
assertEquals("metric_analyzer", request.getMcpTools().getTools().get(0).getName());
46+
assertEquals("metric_analyzer", request.getMcpTools().getTools().get(0).getType());
4747
}
4848

4949
@Test
@@ -58,7 +58,7 @@ public void testStreamSerialization() throws IOException {
5858

5959
assertArrayEquals(nodeIds, deserialized.nodesIds());
6060
assertEquals(sampleTools.getCreatedTime(), deserialized.getMcpTools().getCreatedTime());
61-
assertEquals("metric_analyzer", deserialized.getMcpTools().getTools().get(0).getName());
61+
assertEquals("metric_analyzer", deserialized.getMcpTools().getTools().get(0).getType());
6262
}
6363

6464
@Test
@@ -89,7 +89,7 @@ public void writeTo(StreamOutput out) throws IOException {
8989

9090
MLMcpToolsRegisterNodesRequest converted = MLMcpToolsRegisterNodesRequest.fromActionRequest(wrappedRequest);
9191

92-
assertEquals("metric_analyzer", converted.getMcpTools().getTools().get(0).getName());
92+
assertEquals("metric_analyzer", converted.getMcpTools().getTools().get(0).getType());
9393
assertArrayEquals(nodeIds, converted.nodesIds());
9494
}
9595

0 commit comments

Comments
 (0)