
Commit 60d540a

b4sjoo authored and rithin-pullela-aws committed
Use model type to check local or remote model (opensearch-project#3597)
* use model type to check local or remote model
* spotless
* Ignore test resource
* Add java doc
* Handle when model not in cache
* Handle when model not in cache

Signed-off-by: Sicheng Song <[email protected]>
(cherry picked from commit 696b1e1)
1 parent f68b0ed commit 60d540a

File tree

2 files changed (+34, -23 lines)

plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java (+27, -18)

@@ -18,6 +18,7 @@
 import java.io.IOException;
 import java.util.List;
 import java.util.Locale;
+import java.util.Objects;
 import java.util.Optional;
 
 import org.opensearch.client.node.NodeClient;
@@ -82,27 +83,30 @@ public List<Route> routes() {
 
     @Override
     public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
-        String algorithm = request.param(PARAMETER_ALGORITHM);
+        String userAlgorithm = request.param(PARAMETER_ALGORITHM);
         String modelId = getParameterId(request, PARAMETER_MODEL_ID);
         Optional<FunctionName> functionName = modelManager.getOptionalModelFunctionName(modelId);
 
-        if (algorithm == null && functionName.isPresent()) {
-            algorithm = functionName.get().name();
-        }
-
-        if (algorithm != null) {
-            MLPredictionTaskRequest mlPredictionTaskRequest = getRequest(modelId, algorithm, request);
-            return channel -> client
-                .execute(MLPredictionTaskAction.INSTANCE, mlPredictionTaskRequest, new RestToXContentListener<>(channel));
+        // check if the model is in cache
+        if (functionName.isPresent()) {
+            MLPredictionTaskRequest predictionRequest = getRequest(
+                modelId,
+                functionName.get().name(),
+                Objects.requireNonNullElse(userAlgorithm, functionName.get().name()),
+                request
+            );
+            return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel));
         }
 
+        // If the model isn't in cache
         return channel -> {
             ActionListener<MLModel> listener = ActionListener.wrap(mlModel -> {
-                String algoName = mlModel.getAlgorithm().name();
+                String modelType = mlModel.getAlgorithm().name();
+                String modelAlgorithm = Objects.requireNonNullElse(userAlgorithm, mlModel.getAlgorithm().name());
                 client
                     .execute(
                         MLPredictionTaskAction.INSTANCE,
-                        getRequest(modelId, algoName, request),
+                        getRequest(modelId, modelType, modelAlgorithm, request),
                         new RestToXContentListener<>(channel)
                     );
             }, e -> {
@@ -120,17 +124,22 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
     }
 
     /**
-     * Creates a MLPredictionTaskRequest from a RestRequest
+     * Creates a MLPredictionTaskRequest from a RestRequest. This method validates the request based on
+     * enabled features and model types, and parses the input data for prediction.
      *
-     * @param request RestRequest
-     * @return MLPredictionTaskRequest
+     * @param modelId The ID of the ML model to use for prediction
+     * @param modelType The type of the ML model, extracted from the model cache to specify whether it is a remote model or a local model
+     * @param userAlgorithm The algorithm specified by the user for prediction; this is used to determine the interface of the model
+     * @param request The REST request containing prediction input data
+     * @return MLPredictionTaskRequest configured with the model and input parameters
      */
     @VisibleForTesting
-    MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException {
+    MLPredictionTaskRequest getRequest(String modelId, String modelType, String userAlgorithm, RestRequest request) throws IOException {
         ActionType actionType = ActionType.from(getActionTypeFromRestRequest(request));
-        if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
+        if (FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
             throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
-        } else if (FunctionName.isDLModel(FunctionName.from(algorithm.toUpperCase())) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
+        } else if (FunctionName.isDLModel(FunctionName.from(modelType.toUpperCase(Locale.ROOT)))
+            && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
             throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
         } else if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
             throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG);
@@ -140,7 +149,7 @@ MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest
 
         XContentParser parser = request.contentParser();
         ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
-        MLInput mlInput = MLInput.parse(parser, algorithm, actionType);
+        MLInput mlInput = MLInput.parse(parser, userAlgorithm, actionType);
         return new MLPredictionTaskRequest(modelId, mlInput, null);
     }
 
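For context on the behavior this change introduces: the feature-flag checks in getRequest now key off the model type (taken from the model cache, or from the fetched MLModel when the cache misses), while MLInput parsing keys off the user-supplied algorithm, which defaults to the model type when the request omits the algorithm parameter. The sketch below only illustrates that Objects.requireNonNullElse fallback; the resolveAlgorithm helper and the class name are hypothetical and not part of the plugin.

import java.util.Objects;

public final class AlgorithmFallbackSketch {

    // Hypothetical helper mirroring the fallback used in prepareRequest: the
    // user-supplied algorithm wins, and the cached model type is the default.
    static String resolveAlgorithm(String userAlgorithm, String modelType) {
        return Objects.requireNonNullElse(userAlgorithm, modelType);
    }

    public static void main(String[] args) {
        String cachedModelType = "REMOTE"; // function name found in the model cache
        // Request without an ?algorithm= parameter: parsing falls back to the model type.
        System.out.println(resolveAlgorithm(null, cachedModelType));             // REMOTE
        // Request that sets the algorithm explicitly: the user's choice is kept.
        System.out.println(resolveAlgorithm("text_embedding", cachedModelType)); // text_embedding
    }
}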

plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java (+7, -5)

@@ -69,7 +69,7 @@ public class RestMLPredictionActionTests extends OpenSearchTestCase {
     @Before
     public void setup() {
         MockitoAnnotations.openMocks(this);
-        when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.empty());
+        when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.of(FunctionName.REMOTE));
         when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true);
         when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true);
         restMLPredictionAction = new RestMLPredictionAction(modelManager, mlFeatureEnabledSetting);
@@ -121,7 +121,8 @@ public void testRoutes_Batch() {
 
     public void testGetRequest() throws IOException {
         RestRequest request = getRestRequest_PredictModel();
-        MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.KMEANS.name(), request);
+        MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
+            .getRequest("modelId", FunctionName.KMEANS.name(), FunctionName.KMEANS.name(), request);
 
         MLInput mlInput = mlPredictionTaskRequest.getMlInput();
         verifyParsedKMeansMLInput(mlInput);
@@ -133,7 +134,8 @@ public void testGetRequest_RemoteInferenceDisabled() throws IOException {
 
         when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false);
         RestRequest request = getRestRequest_PredictModel();
-        MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.REMOTE.name(), request);
+        MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
+            .getRequest("modelId", FunctionName.REMOTE.name(), "text_embedding", request);
     }
 
     public void testGetRequest_LocalModelInferenceDisabled() throws IOException {
@@ -143,7 +145,7 @@ public void testGetRequest_LocalModelInferenceDisabled() throws IOException {
         when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false);
         RestRequest request = getRestRequest_PredictModel();
         MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
-            .getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), request);
+            .getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), "text_embedding", request);
     }
 
     public void testPrepareRequest() throws Exception {
@@ -182,7 +184,7 @@ public void testPrepareBatchRequest_WrongActionType() throws Exception {
         thrown.expectMessage("Wrong Action Type");
 
         RestRequest request = getBatchRestRequest_WrongActionType();
-        restMLPredictionAction.getRequest("model id", "remote", request);
+        restMLPredictionAction.getRequest("model id", "remote", "text_embedding", request);
     }
 
     @Ignore
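Because the shared setup() stub now returns Optional.of(FunctionName.REMOTE), the existing prepareRequest tests exercise the cache-hit branch by default. Below is a hypothetical sketch (not part of this commit) of a test that exercises the cache-miss branch instead; it assumes it is added to RestMLPredictionActionTests and that channel and client are the RestChannel and NodeClient mocks the other prepareRequest tests already use.

    // Hypothetical test sketch: override the shared stub so prepareRequest takes the
    // cache-miss branch and fetches the model before dispatching the prediction request.
    public void testPrepareRequest_ModelNotInCache() throws Exception {
        when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.empty());
        RestRequest request = getRestRequest_PredictModel();
        restMLPredictionAction.handleRequest(request, channel, client);
        // With no cached function name, the action should look the model up first and only
        // then dispatch MLPredictionTaskAction with the resolved model type and algorithm.
    }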
