Skip to content

Commit 085cc7a

Browse files
committed
Added support for msearch API to pass search pipeline name (opensearch-project#15923)
* Added support for search pipeline name in multi search API Signed-off-by: Owais <[email protected]> * Updated CHANGELOG Signed-off-by: Owais <[email protected]> * Pulled search pipeline in MultiSearchRequest and updated test Signed-off-by: Owais <[email protected]> * Updated test Signed-off-by: Owais <[email protected]> * Updated SearchRequest with search pipeline from source Signed-off-by: Owais <[email protected]> * Added tests for parseSearchRequest Signed-off-by: Owais <[email protected]> * Guard serialization with version check Signed-off-by: Owais <[email protected]> * Updated version and added another test for serialization Signed-off-by: Owais <[email protected]> --------- Signed-off-by: Owais <[email protected]>
1 parent a1846cf commit 085cc7a

File tree

7 files changed

+139
-7
lines changed

7 files changed

+139
-7
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1313
- Add successfulSearchShardIndices in searchRequestContext ([#15967](https://github.com/opensearch-project/OpenSearch/pull/15967))
1414
- Remove identity-related feature flagged code from the RestController ([#15430](https://github.com/opensearch-project/OpenSearch/pull/15430))
1515
- Fallback to Remote cluster-state on Term-Version check mismatch - ([#15424](https://github.com/opensearch-project/OpenSearch/pull/15424))
16+
- Add support for msearch API to pass search pipeline name - ([#15923](https://github.com/opensearch-project/OpenSearch/pull/15923))
1617

1718
### Dependencies
1819
- Bump `org.apache.logging.log4j:log4j-core` from 2.23.1 to 2.24.0 ([#15858](https://github.com/opensearch-project/OpenSearch/pull/15858))

server/src/main/java/org/opensearch/action/search/MultiSearchRequest.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,10 @@ public static void readMultiLineFormat(
310310
) {
311311
consumer.accept(searchRequest, parser);
312312
}
313+
314+
if (searchRequest.source() != null && searchRequest.source().pipeline() != null) {
315+
searchRequest.pipeline(searchRequest.source().pipeline());
316+
}
313317
// move pointers
314318
from = nextMarker + 1;
315319
}

server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ public static void parseSearchRequest(
210210
searchRequest.routing(request.param("routing"));
211211
searchRequest.preference(request.param("preference"));
212212
searchRequest.indicesOptions(IndicesOptions.fromRequest(request, searchRequest.indicesOptions()));
213-
searchRequest.pipeline(request.param("search_pipeline"));
213+
searchRequest.pipeline(request.param("search_pipeline", searchRequest.source().pipeline()));
214214

215215
checkRestTotalHits(request, searchRequest);
216216
request.paramAsBoolean(INCLUDE_NAMED_QUERIES_SCORE_PARAM, false);

server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,8 @@ public static HighlightBuilder highlight() {
225225

226226
private Map<String, Object> searchPipelineSource = null;
227227

228+
private String searchPipeline;
229+
228230
/**
229231
* Constructs a new search source builder.
230232
*/
@@ -306,6 +308,9 @@ public SearchSourceBuilder(StreamInput in) throws IOException {
306308
derivedFields = in.readList(DerivedField::new);
307309
}
308310
}
311+
if (in.getVersion().onOrAfter(Version.V_3_0_0)) {
312+
searchPipeline = in.readOptionalString();
313+
}
309314
}
310315

311316
@Override
@@ -394,6 +399,9 @@ public void writeTo(StreamOutput out) throws IOException {
394399
out.writeList(derivedFields);
395400
}
396401
}
402+
if (out.getVersion().onOrAfter(Version.V_3_0_0)) {
403+
out.writeOptionalString(searchPipeline);
404+
}
397405
}
398406

399407
/**
@@ -1128,6 +1136,13 @@ public Map<String, Object> searchPipelineSource() {
11281136
return searchPipelineSource;
11291137
}
11301138

1139+
/**
1140+
* @return a search pipeline name defined within the search source (see {@link org.opensearch.search.pipeline.SearchPipelineService})
1141+
*/
1142+
public String pipeline() {
1143+
return searchPipeline;
1144+
}
1145+
11311146
/**
11321147
* Define a search pipeline to process this search request and/or its response. See {@link org.opensearch.search.pipeline.SearchPipelineService}.
11331148
*/
@@ -1136,6 +1151,14 @@ public SearchSourceBuilder searchPipelineSource(Map<String, Object> searchPipeli
11361151
return this;
11371152
}
11381153

1154+
/**
1155+
* Define a search pipeline name to process this search request and/or its response. See {@link org.opensearch.search.pipeline.SearchPipelineService}.
1156+
*/
1157+
public SearchSourceBuilder pipeline(String searchPipeline) {
1158+
this.searchPipeline = searchPipeline;
1159+
return this;
1160+
}
1161+
11391162
/**
11401163
* Rewrites this search source builder into its primitive form. e.g. by
11411164
* rewriting the QueryBuilder. If the builder did not change the identity
@@ -1233,6 +1256,7 @@ private SearchSourceBuilder shallowCopy(
12331256
rewrittenBuilder.pointInTimeBuilder = pointInTimeBuilder;
12341257
rewrittenBuilder.derivedFieldsObject = derivedFieldsObject;
12351258
rewrittenBuilder.derivedFields = derivedFields;
1259+
rewrittenBuilder.searchPipeline = searchPipeline;
12361260
return rewrittenBuilder;
12371261
}
12381262

@@ -1300,6 +1324,8 @@ public void parseXContent(XContentParser parser, boolean checkTrailingTokens) th
13001324
sort(parser.text());
13011325
} else if (PROFILE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
13021326
profile = parser.booleanValue();
1327+
} else if (SEARCH_PIPELINE.match(currentFieldName, parser.getDeprecationHandler())) {
1328+
searchPipeline = parser.text();
13031329
} else {
13041330
throw new ParsingException(
13051331
parser.getTokenLocation(),
@@ -1629,6 +1655,10 @@ public XContentBuilder innerToXContent(XContentBuilder builder, Params params) t
16291655

16301656
}
16311657

1658+
if (searchPipeline != null) {
1659+
builder.field(SEARCH_PIPELINE.getPreferredName(), searchPipeline);
1660+
}
1661+
16321662
return builder;
16331663
}
16341664

@@ -1906,7 +1936,8 @@ public int hashCode() {
19061936
trackTotalHitsUpTo,
19071937
pointInTimeBuilder,
19081938
derivedFieldsObject,
1909-
derivedFields
1939+
derivedFields,
1940+
searchPipeline
19101941
);
19111942
}
19121943

@@ -1951,7 +1982,8 @@ public boolean equals(Object obj) {
19511982
&& Objects.equals(trackTotalHitsUpTo, other.trackTotalHitsUpTo)
19521983
&& Objects.equals(pointInTimeBuilder, other.pointInTimeBuilder)
19531984
&& Objects.equals(derivedFieldsObject, other.derivedFieldsObject)
1954-
&& Objects.equals(derivedFields, other.derivedFields);
1985+
&& Objects.equals(derivedFields, other.derivedFields)
1986+
&& Objects.equals(searchPipeline, other.searchPipeline);
19551987
}
19561988

19571989
@Override

server/src/test/java/org/opensearch/action/search/SearchRequestTests.java

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,27 @@
4343
import org.opensearch.geometry.LinearRing;
4444
import org.opensearch.index.query.GeoShapeQueryBuilder;
4545
import org.opensearch.index.query.QueryBuilders;
46+
import org.opensearch.rest.RestRequest;
47+
import org.opensearch.rest.action.search.RestSearchAction;
4648
import org.opensearch.search.AbstractSearchTestCase;
4749
import org.opensearch.search.Scroll;
4850
import org.opensearch.search.builder.PointInTimeBuilder;
4951
import org.opensearch.search.builder.SearchSourceBuilder;
5052
import org.opensearch.search.rescore.QueryRescorerBuilder;
5153
import org.opensearch.test.OpenSearchTestCase;
5254
import org.opensearch.test.VersionUtils;
55+
import org.opensearch.test.rest.FakeRestRequest;
5356

5457
import java.io.IOException;
5558
import java.util.ArrayList;
5659
import java.util.List;
60+
import java.util.function.IntConsumer;
5761

5862
import static java.util.Collections.emptyMap;
63+
import static org.opensearch.action.search.SearchType.DFS_QUERY_THEN_FETCH;
5964
import static org.opensearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
6065
import static org.hamcrest.Matchers.equalTo;
66+
import static org.mockito.Mockito.mock;
6167

6268
public class SearchRequestTests extends AbstractSearchTestCase {
6369

@@ -222,6 +228,19 @@ public void testCopyConstructor() throws IOException {
222228
assertNotSame(deserializedRequest, searchRequest);
223229
}
224230

231+
public void testParseSearchRequestWithUnsupportedSearchType() throws IOException {
232+
RestRequest restRequest = new FakeRestRequest();
233+
SearchRequest searchRequest = createSearchRequest();
234+
IntConsumer setSize = mock(IntConsumer.class);
235+
restRequest.params().put("search_type", "query_and_fetch");
236+
237+
IllegalArgumentException exception = expectThrows(
238+
IllegalArgumentException.class,
239+
() -> RestSearchAction.parseSearchRequest(searchRequest, restRequest, null, namedWriteableRegistry, setSize)
240+
);
241+
assertEquals("Unsupported search type [query_and_fetch]", exception.getMessage());
242+
}
243+
225244
public void testEqualsAndHashcode() throws IOException {
226245
checkEqualsAndHashCode(createSearchRequest(), SearchRequest::new, this::mutate);
227246
}
@@ -248,10 +267,7 @@ private SearchRequest mutate(SearchRequest searchRequest) {
248267
);
249268
mutators.add(
250269
() -> mutation.searchType(
251-
randomValueOtherThan(
252-
searchRequest.searchType(),
253-
() -> randomFrom(SearchType.DFS_QUERY_THEN_FETCH, SearchType.QUERY_THEN_FETCH)
254-
)
270+
randomValueOtherThan(searchRequest.searchType(), () -> randomFrom(DFS_QUERY_THEN_FETCH, SearchType.QUERY_THEN_FETCH))
255271
)
256272
);
257273
mutators.add(() -> mutation.source(randomValueOtherThan(searchRequest.source(), this::createSearchSourceBuilder)));

server/src/test/java/org/opensearch/search/builder/SearchSourceBuilderTests.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,27 @@ public void testDerivedFieldsParsingAndSerializationObjectType() throws IOExcept
421421
}
422422
}
423423

424+
public void testSearchPipelineParsingAndSerialization() throws IOException {
425+
String restContent = "{ \"query\": { \"match_all\": {} }, \"from\": 0, \"size\": 10, \"search_pipeline\": \"my_pipeline\" }";
426+
String expectedContent = "{\"from\":0,\"size\":10,\"query\":{\"match_all\":{\"boost\":1.0}},\"search_pipeline\":\"my_pipeline\"}";
427+
428+
try (XContentParser parser = createParser(JsonXContent.jsonXContent, restContent)) {
429+
SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.fromXContent(parser);
430+
searchSourceBuilder = rewrite(searchSourceBuilder);
431+
432+
try (BytesStreamOutput output = new BytesStreamOutput()) {
433+
searchSourceBuilder.writeTo(output);
434+
try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry)) {
435+
SearchSourceBuilder deserializedBuilder = new SearchSourceBuilder(in);
436+
String actualContent = deserializedBuilder.toString();
437+
assertEquals(expectedContent, actualContent);
438+
assertEquals(searchSourceBuilder.hashCode(), deserializedBuilder.hashCode());
439+
assertNotSame(searchSourceBuilder, deserializedBuilder);
440+
}
441+
}
442+
}
443+
}
444+
424445
public void testAggsParsing() throws IOException {
425446
{
426447
String restContent = "{\n"

server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,64 @@ public void testInlinePipeline() throws Exception {
969969
}
970970
}
971971

972+
public void testInlineDefinedPipeline() throws Exception {
973+
SearchPipelineService searchPipelineService = createWithProcessors();
974+
975+
SearchPipelineMetadata metadata = new SearchPipelineMetadata(
976+
Map.of(
977+
"p1",
978+
new PipelineConfiguration(
979+
"p1",
980+
new BytesArray(
981+
"{"
982+
+ "\"request_processors\": [{ \"scale_request_size\": { \"scale\" : 2 } }],"
983+
+ "\"response_processors\": [{ \"fixed_score\": { \"score\" : 2 } }]"
984+
+ "}"
985+
),
986+
MediaTypeRegistry.JSON
987+
)
988+
)
989+
990+
);
991+
ClusterState clusterState = ClusterState.builder(new ClusterName("_name")).build();
992+
ClusterState previousState = clusterState;
993+
clusterState = ClusterState.builder(clusterState)
994+
.metadata(Metadata.builder().putCustom(SearchPipelineMetadata.TYPE, metadata))
995+
.build();
996+
searchPipelineService.applyClusterState(new ClusterChangedEvent("", clusterState, previousState));
997+
998+
SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource().size(100).pipeline("p1");
999+
SearchRequest searchRequest = new SearchRequest().source(sourceBuilder);
1000+
searchRequest.pipeline(searchRequest.source().pipeline());
1001+
1002+
// Verify pipeline
1003+
PipelinedRequest pipelinedRequest = syncTransformRequest(
1004+
searchPipelineService.resolvePipeline(searchRequest, indexNameExpressionResolver)
1005+
);
1006+
Pipeline pipeline = pipelinedRequest.getPipeline();
1007+
assertEquals("p1", pipeline.getId());
1008+
assertEquals(1, pipeline.getSearchRequestProcessors().size());
1009+
assertEquals(1, pipeline.getSearchResponseProcessors().size());
1010+
1011+
// Verify that pipeline transforms request
1012+
assertEquals(200, pipelinedRequest.source().size());
1013+
1014+
int size = 10;
1015+
SearchHit[] hits = new SearchHit[size];
1016+
for (int i = 0; i < size; i++) {
1017+
hits[i] = new SearchHit(i, "doc" + i, Collections.emptyMap(), Collections.emptyMap());
1018+
hits[i].score(i);
1019+
}
1020+
SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2, TotalHits.Relation.EQUAL_TO), size);
1021+
SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
1022+
SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null);
1023+
1024+
SearchResponse transformedResponse = syncTransformResponse(pipelinedRequest, searchResponse);
1025+
for (int i = 0; i < size; i++) {
1026+
assertEquals(2.0, transformedResponse.getHits().getHits()[i].getScore(), 0.0001);
1027+
}
1028+
}
1029+
9721030
public void testInfo() {
9731031
SearchPipelineService searchPipelineService = createWithProcessors();
9741032
SearchPipelineInfo info = searchPipelineService.info();

0 commit comments

Comments
 (0)