Skip to content

Commit daf1669

Browse files
authored
Added support for msearch API to pass search pipeline name (#15923)
* Added support for search pipeline name in multi search API Signed-off-by: Owais <[email protected]> * Updated CHANGELOG Signed-off-by: Owais <[email protected]> * Pulled search pipeline in MultiSearchRequest and updated test Signed-off-by: Owais <[email protected]> * Updated test Signed-off-by: Owais <[email protected]> * Updated SearchRequest with search pipeline from source Signed-off-by: Owais <[email protected]> * Added tests for parseSearchRequest Signed-off-by: Owais <[email protected]> * Guard serialization with version check Signed-off-by: Owais <[email protected]> * Updated version and added another test for serialization Signed-off-by: Owais <[email protected]> --------- Signed-off-by: Owais <[email protected]>
1 parent a42e51d commit daf1669

File tree

7 files changed

+139
-7
lines changed

7 files changed

+139
-7
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1212
- Implement WithFieldName interface in ValuesSourceAggregationBuilder & FieldSortBuilder ([#15916](https://github.com/opensearch-project/OpenSearch/pull/15916))
1313
- Add successfulSearchShardIndices in searchRequestContext ([#15967](https://github.com/opensearch-project/OpenSearch/pull/15967))
1414
- Remove identity-related feature flagged code from the RestController ([#15430](https://github.com/opensearch-project/OpenSearch/pull/15430))
15+
- Add support for msearch API to pass search pipeline name - ([#15923](https://github.com/opensearch-project/OpenSearch/pull/15923))
1516

1617
### Dependencies
1718
- Bump `com.azure:azure-identity` from 1.13.0 to 1.13.2 ([#15578](https://github.com/opensearch-project/OpenSearch/pull/15578))

server/src/main/java/org/opensearch/action/search/MultiSearchRequest.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,10 @@ public static void readMultiLineFormat(
310310
) {
311311
consumer.accept(searchRequest, parser);
312312
}
313+
314+
if (searchRequest.source() != null && searchRequest.source().pipeline() != null) {
315+
searchRequest.pipeline(searchRequest.source().pipeline());
316+
}
313317
// move pointers
314318
from = nextMarker + 1;
315319
}

server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ public static void parseSearchRequest(
210210
searchRequest.routing(request.param("routing"));
211211
searchRequest.preference(request.param("preference"));
212212
searchRequest.indicesOptions(IndicesOptions.fromRequest(request, searchRequest.indicesOptions()));
213-
searchRequest.pipeline(request.param("search_pipeline"));
213+
searchRequest.pipeline(request.param("search_pipeline", searchRequest.source().pipeline()));
214214

215215
checkRestTotalHits(request, searchRequest);
216216
request.paramAsBoolean(INCLUDE_NAMED_QUERIES_SCORE_PARAM, false);

server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ public static HighlightBuilder highlight() {
224224

225225
private Map<String, Object> searchPipelineSource = null;
226226

227+
private String searchPipeline;
228+
227229
/**
228230
* Constructs a new search source builder.
229231
*/
@@ -297,6 +299,9 @@ public SearchSourceBuilder(StreamInput in) throws IOException {
297299
derivedFields = in.readList(DerivedField::new);
298300
}
299301
}
302+
if (in.getVersion().onOrAfter(Version.V_3_0_0)) {
303+
searchPipeline = in.readOptionalString();
304+
}
300305
}
301306

302307
@Override
@@ -377,6 +382,9 @@ public void writeTo(StreamOutput out) throws IOException {
377382
out.writeList(derivedFields);
378383
}
379384
}
385+
if (out.getVersion().onOrAfter(Version.V_3_0_0)) {
386+
out.writeOptionalString(searchPipeline);
387+
}
380388
}
381389

382390
/**
@@ -1111,6 +1119,13 @@ public Map<String, Object> searchPipelineSource() {
11111119
return searchPipelineSource;
11121120
}
11131121

1122+
/**
1123+
* @return a search pipeline name defined within the search source (see {@link org.opensearch.search.pipeline.SearchPipelineService})
1124+
*/
1125+
public String pipeline() {
1126+
return searchPipeline;
1127+
}
1128+
11141129
/**
11151130
* Define a search pipeline to process this search request and/or its response. See {@link org.opensearch.search.pipeline.SearchPipelineService}.
11161131
*/
@@ -1119,6 +1134,14 @@ public SearchSourceBuilder searchPipelineSource(Map<String, Object> searchPipeli
11191134
return this;
11201135
}
11211136

1137+
/**
1138+
* Define a search pipeline name to process this search request and/or its response. See {@link org.opensearch.search.pipeline.SearchPipelineService}.
1139+
*/
1140+
public SearchSourceBuilder pipeline(String searchPipeline) {
1141+
this.searchPipeline = searchPipeline;
1142+
return this;
1143+
}
1144+
11221145
/**
11231146
* Rewrites this search source builder into its primitive form. e.g. by
11241147
* rewriting the QueryBuilder. If the builder did not change the identity
@@ -1216,6 +1239,7 @@ private SearchSourceBuilder shallowCopy(
12161239
rewrittenBuilder.pointInTimeBuilder = pointInTimeBuilder;
12171240
rewrittenBuilder.derivedFieldsObject = derivedFieldsObject;
12181241
rewrittenBuilder.derivedFields = derivedFields;
1242+
rewrittenBuilder.searchPipeline = searchPipeline;
12191243
return rewrittenBuilder;
12201244
}
12211245

@@ -1283,6 +1307,8 @@ public void parseXContent(XContentParser parser, boolean checkTrailingTokens) th
12831307
sort(parser.text());
12841308
} else if (PROFILE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
12851309
profile = parser.booleanValue();
1310+
} else if (SEARCH_PIPELINE.match(currentFieldName, parser.getDeprecationHandler())) {
1311+
searchPipeline = parser.text();
12861312
} else {
12871313
throw new ParsingException(
12881314
parser.getTokenLocation(),
@@ -1612,6 +1638,10 @@ public XContentBuilder innerToXContent(XContentBuilder builder, Params params) t
16121638

16131639
}
16141640

1641+
if (searchPipeline != null) {
1642+
builder.field(SEARCH_PIPELINE.getPreferredName(), searchPipeline);
1643+
}
1644+
16151645
return builder;
16161646
}
16171647

@@ -1889,7 +1919,8 @@ public int hashCode() {
18891919
trackTotalHitsUpTo,
18901920
pointInTimeBuilder,
18911921
derivedFieldsObject,
1892-
derivedFields
1922+
derivedFields,
1923+
searchPipeline
18931924
);
18941925
}
18951926

@@ -1934,7 +1965,8 @@ public boolean equals(Object obj) {
19341965
&& Objects.equals(trackTotalHitsUpTo, other.trackTotalHitsUpTo)
19351966
&& Objects.equals(pointInTimeBuilder, other.pointInTimeBuilder)
19361967
&& Objects.equals(derivedFieldsObject, other.derivedFieldsObject)
1937-
&& Objects.equals(derivedFields, other.derivedFields);
1968+
&& Objects.equals(derivedFields, other.derivedFields)
1969+
&& Objects.equals(searchPipeline, other.searchPipeline);
19381970
}
19391971

19401972
@Override

server/src/test/java/org/opensearch/action/search/SearchRequestTests.java

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
import org.opensearch.geometry.LinearRing;
4343
import org.opensearch.index.query.GeoShapeQueryBuilder;
4444
import org.opensearch.index.query.QueryBuilders;
45+
import org.opensearch.rest.RestRequest;
46+
import org.opensearch.rest.action.search.RestSearchAction;
4547
import org.opensearch.search.AbstractSearchTestCase;
4648
import org.opensearch.search.Scroll;
4749
import org.opensearch.search.builder.PointInTimeBuilder;
@@ -50,14 +52,18 @@
5052
import org.opensearch.search.rescore.QueryRescorerBuilder;
5153
import org.opensearch.test.OpenSearchTestCase;
5254
import org.opensearch.test.VersionUtils;
55+
import org.opensearch.test.rest.FakeRestRequest;
5356

5457
import java.io.IOException;
5558
import java.util.ArrayList;
5659
import java.util.List;
60+
import java.util.function.IntConsumer;
5761

5862
import static java.util.Collections.emptyMap;
63+
import static org.opensearch.action.search.SearchType.DFS_QUERY_THEN_FETCH;
5964
import static org.opensearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
6065
import static org.hamcrest.Matchers.equalTo;
66+
import static org.mockito.Mockito.mock;
6167

6268
public class SearchRequestTests extends AbstractSearchTestCase {
6369

@@ -242,6 +248,19 @@ public void testCopyConstructor() throws IOException {
242248
assertNotSame(deserializedRequest, searchRequest);
243249
}
244250

251+
public void testParseSearchRequestWithUnsupportedSearchType() throws IOException {
252+
RestRequest restRequest = new FakeRestRequest();
253+
SearchRequest searchRequest = createSearchRequest();
254+
IntConsumer setSize = mock(IntConsumer.class);
255+
restRequest.params().put("search_type", "query_and_fetch");
256+
257+
IllegalArgumentException exception = expectThrows(
258+
IllegalArgumentException.class,
259+
() -> RestSearchAction.parseSearchRequest(searchRequest, restRequest, null, namedWriteableRegistry, setSize)
260+
);
261+
assertEquals("Unsupported search type [query_and_fetch]", exception.getMessage());
262+
}
263+
245264
public void testEqualsAndHashcode() throws IOException {
246265
checkEqualsAndHashCode(createSearchRequest(), SearchRequest::new, this::mutate);
247266
}
@@ -268,10 +287,7 @@ private SearchRequest mutate(SearchRequest searchRequest) {
268287
);
269288
mutators.add(
270289
() -> mutation.searchType(
271-
randomValueOtherThan(
272-
searchRequest.searchType(),
273-
() -> randomFrom(SearchType.DFS_QUERY_THEN_FETCH, SearchType.QUERY_THEN_FETCH)
274-
)
290+
randomValueOtherThan(searchRequest.searchType(), () -> randomFrom(DFS_QUERY_THEN_FETCH, SearchType.QUERY_THEN_FETCH))
275291
)
276292
);
277293
mutators.add(() -> mutation.source(randomValueOtherThan(searchRequest.source(), this::createSearchSourceBuilder)));

server/src/test/java/org/opensearch/search/builder/SearchSourceBuilderTests.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,27 @@ public void testDerivedFieldsParsingAndSerializationObjectType() throws IOExcept
421421
}
422422
}
423423

424+
public void testSearchPipelineParsingAndSerialization() throws IOException {
425+
String restContent = "{ \"query\": { \"match_all\": {} }, \"from\": 0, \"size\": 10, \"search_pipeline\": \"my_pipeline\" }";
426+
String expectedContent = "{\"from\":0,\"size\":10,\"query\":{\"match_all\":{\"boost\":1.0}},\"search_pipeline\":\"my_pipeline\"}";
427+
428+
try (XContentParser parser = createParser(JsonXContent.jsonXContent, restContent)) {
429+
SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.fromXContent(parser);
430+
searchSourceBuilder = rewrite(searchSourceBuilder);
431+
432+
try (BytesStreamOutput output = new BytesStreamOutput()) {
433+
searchSourceBuilder.writeTo(output);
434+
try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry)) {
435+
SearchSourceBuilder deserializedBuilder = new SearchSourceBuilder(in);
436+
String actualContent = deserializedBuilder.toString();
437+
assertEquals(expectedContent, actualContent);
438+
assertEquals(searchSourceBuilder.hashCode(), deserializedBuilder.hashCode());
439+
assertNotSame(searchSourceBuilder, deserializedBuilder);
440+
}
441+
}
442+
}
443+
}
444+
424445
public void testAggsParsing() throws IOException {
425446
{
426447
String restContent = "{\n"

server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,64 @@ public void testInlinePipeline() throws Exception {
969969
}
970970
}
971971

972+
public void testInlineDefinedPipeline() throws Exception {
973+
SearchPipelineService searchPipelineService = createWithProcessors();
974+
975+
SearchPipelineMetadata metadata = new SearchPipelineMetadata(
976+
Map.of(
977+
"p1",
978+
new PipelineConfiguration(
979+
"p1",
980+
new BytesArray(
981+
"{"
982+
+ "\"request_processors\": [{ \"scale_request_size\": { \"scale\" : 2 } }],"
983+
+ "\"response_processors\": [{ \"fixed_score\": { \"score\" : 2 } }]"
984+
+ "}"
985+
),
986+
MediaTypeRegistry.JSON
987+
)
988+
)
989+
990+
);
991+
ClusterState clusterState = ClusterState.builder(new ClusterName("_name")).build();
992+
ClusterState previousState = clusterState;
993+
clusterState = ClusterState.builder(clusterState)
994+
.metadata(Metadata.builder().putCustom(SearchPipelineMetadata.TYPE, metadata))
995+
.build();
996+
searchPipelineService.applyClusterState(new ClusterChangedEvent("", clusterState, previousState));
997+
998+
SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource().size(100).pipeline("p1");
999+
SearchRequest searchRequest = new SearchRequest().source(sourceBuilder);
1000+
searchRequest.pipeline(searchRequest.source().pipeline());
1001+
1002+
// Verify pipeline
1003+
PipelinedRequest pipelinedRequest = syncTransformRequest(
1004+
searchPipelineService.resolvePipeline(searchRequest, indexNameExpressionResolver)
1005+
);
1006+
Pipeline pipeline = pipelinedRequest.getPipeline();
1007+
assertEquals("p1", pipeline.getId());
1008+
assertEquals(1, pipeline.getSearchRequestProcessors().size());
1009+
assertEquals(1, pipeline.getSearchResponseProcessors().size());
1010+
1011+
// Verify that pipeline transforms request
1012+
assertEquals(200, pipelinedRequest.source().size());
1013+
1014+
int size = 10;
1015+
SearchHit[] hits = new SearchHit[size];
1016+
for (int i = 0; i < size; i++) {
1017+
hits[i] = new SearchHit(i, "doc" + i, Collections.emptyMap(), Collections.emptyMap());
1018+
hits[i].score(i);
1019+
}
1020+
SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2, TotalHits.Relation.EQUAL_TO), size);
1021+
SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
1022+
SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null);
1023+
1024+
SearchResponse transformedResponse = syncTransformResponse(pipelinedRequest, searchResponse);
1025+
for (int i = 0; i < size; i++) {
1026+
assertEquals(2.0, transformedResponse.getHits().getHits()[i].getScore(), 0.0001);
1027+
}
1028+
}
1029+
9721030
public void testInfo() {
9731031
SearchPipelineService searchPipelineService = createWithProcessors();
9741032
SearchPipelineInfo info = searchPipelineService.info();

0 commit comments

Comments
 (0)