18
18
import java .io .IOException ;
19
19
import java .util .List ;
20
20
import java .util .Locale ;
21
+ import java .util .Objects ;
21
22
import java .util .Optional ;
22
23
23
24
import org .opensearch .client .node .NodeClient ;
@@ -82,27 +83,30 @@ public List<Route> routes() {
82
83
83
84
@ Override
84
85
public RestChannelConsumer prepareRequest (RestRequest request , NodeClient client ) throws IOException {
85
- String algorithm = request .param (PARAMETER_ALGORITHM );
86
+ String userAlgorithm = request .param (PARAMETER_ALGORITHM );
86
87
String modelId = getParameterId (request , PARAMETER_MODEL_ID );
87
88
Optional <FunctionName > functionName = modelManager .getOptionalModelFunctionName (modelId );
88
89
89
- if (algorithm == null && functionName .isPresent ()) {
90
- algorithm = functionName .get ().name ();
91
- }
92
-
93
- if (algorithm != null ) {
94
- MLPredictionTaskRequest mlPredictionTaskRequest = getRequest (modelId , algorithm , request );
95
- return channel -> client
96
- .execute (MLPredictionTaskAction .INSTANCE , mlPredictionTaskRequest , new RestToXContentListener <>(channel ));
90
+ // check if the model is in cache
91
+ if (functionName .isPresent ()) {
92
+ MLPredictionTaskRequest predictionRequest = getRequest (
93
+ modelId ,
94
+ functionName .get ().name (),
95
+ Objects .requireNonNullElse (userAlgorithm , functionName .get ().name ()),
96
+ request
97
+ );
98
+ return channel -> client .execute (MLPredictionTaskAction .INSTANCE , predictionRequest , new RestToXContentListener <>(channel ));
97
99
}
98
100
101
+ // If the model isn't in cache
99
102
return channel -> {
100
103
ActionListener <MLModel > listener = ActionListener .wrap (mlModel -> {
101
- String algoName = mlModel .getAlgorithm ().name ();
104
+ String modelType = mlModel .getAlgorithm ().name ();
105
+ String modelAlgorithm = Objects .requireNonNullElse (userAlgorithm , mlModel .getAlgorithm ().name ());
102
106
client
103
107
.execute (
104
108
MLPredictionTaskAction .INSTANCE ,
105
- getRequest (modelId , algoName , request ),
109
+ getRequest (modelId , modelType , modelAlgorithm , request ),
106
110
new RestToXContentListener <>(channel )
107
111
);
108
112
}, e -> {
@@ -120,17 +124,22 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
120
124
}
121
125
122
126
/**
123
- * Creates a MLPredictionTaskRequest from a RestRequest
127
+ * Creates a MLPredictionTaskRequest from a RestRequest. This method validates the request based on
128
+ * enabled features and model types, and parses the input data for prediction.
124
129
*
125
- * @param request RestRequest
126
- * @return MLPredictionTaskRequest
130
+ * @param modelId The ID of the ML model to use for prediction
131
+ * @param modelType The type of the ML model, extracted from model cache to specify if its a remote model or a local model
132
+ * @param userAlgorithm The algorithm specified by the user for prediction, this is used todetermine the interface of the model
133
+ * @param request The REST request containing prediction input data
134
+ * @return MLPredictionTaskRequest configured with the model and input parameters
127
135
*/
128
136
@ VisibleForTesting
129
- MLPredictionTaskRequest getRequest (String modelId , String algorithm , RestRequest request ) throws IOException {
137
+ MLPredictionTaskRequest getRequest (String modelId , String modelType , String userAlgorithm , RestRequest request ) throws IOException {
130
138
ActionType actionType = ActionType .from (getActionTypeFromRestRequest (request ));
131
- if (FunctionName .REMOTE .name ().equals (algorithm ) && !mlFeatureEnabledSetting .isRemoteInferenceEnabled ()) {
139
+ if (FunctionName .REMOTE .name ().equals (modelType ) && !mlFeatureEnabledSetting .isRemoteInferenceEnabled ()) {
132
140
throw new IllegalStateException (REMOTE_INFERENCE_DISABLED_ERR_MSG );
133
- } else if (FunctionName .isDLModel (FunctionName .from (algorithm .toUpperCase ())) && !mlFeatureEnabledSetting .isLocalModelEnabled ()) {
141
+ } else if (FunctionName .isDLModel (FunctionName .from (modelType .toUpperCase (Locale .ROOT )))
142
+ && !mlFeatureEnabledSetting .isLocalModelEnabled ()) {
134
143
throw new IllegalStateException (LOCAL_MODEL_DISABLED_ERR_MSG );
135
144
} else if (ActionType .BATCH_PREDICT == actionType && !mlFeatureEnabledSetting .isOfflineBatchInferenceEnabled ()) {
136
145
throw new IllegalStateException (BATCH_INFERENCE_DISABLED_ERR_MSG );
@@ -140,7 +149,7 @@ MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest
140
149
141
150
XContentParser parser = request .contentParser ();
142
151
ensureExpectedToken (XContentParser .Token .START_OBJECT , parser .nextToken (), parser );
143
- MLInput mlInput = MLInput .parse (parser , algorithm , actionType );
152
+ MLInput mlInput = MLInput .parse (parser , userAlgorithm , actionType );
144
153
return new MLPredictionTaskRequest (modelId , mlInput , null );
145
154
}
146
155
0 commit comments