Skip to content

Commit 1df6f83

Browse files
pyek-botakolarkunnu
authored andcommitted
[BUG] Agent framework: Fix SearchIndexTool to parse special floating point values and NaN (opensearch-project#3754)
* fix: support null/NaN values returned from document search Signed-off-by: Pavan Yekbote <[email protected]> * spotless Signed-off-by: Pavan Yekbote <[email protected]> * chore: add document for NaN test case Signed-off-by: Pavan Yekbote <[email protected]> * fix: use static variable for gson instance Signed-off-by: Pavan Yekbote <[email protected]> --------- Signed-off-by: Pavan Yekbote <[email protected]> Signed-off-by: Abdul Muneer Kolarkunnu <[email protected]>
1 parent 8dc7b48 commit 1df6f83

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,12 @@
2727
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction;
2828
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
2929
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
30-
import org.opensearch.ml.common.utils.StringUtils;
3130
import org.opensearch.search.SearchHit;
3231
import org.opensearch.search.builder.SearchSourceBuilder;
3332
import org.opensearch.transport.client.Client;
3433

34+
import com.google.gson.Gson;
35+
import com.google.gson.GsonBuilder;
3536
import com.google.gson.JsonElement;
3637
import com.google.gson.JsonObject;
3738

@@ -62,6 +63,8 @@ public class SearchIndexTool implements Tool {
6263
+ "Invalid value: \\n{\\\"match\\\":{\\\"population_description\\\":\\\"seattle 2023 population\\\"}}\\nThe value is invalid because the match not wrapped by \\\"query\\\".\","
6364
+ "\"additionalProperties\":false}},\"required\":[\"index\",\"query\"],\"additionalProperties\":false}";
6465

66+
private static final Gson GSON = new GsonBuilder().serializeSpecialFloatingPointValues().create();
67+
6568
private String name = TYPE;
6669
private Map<String, Object> attributes;
6770
private String description = DEFAULT_DESCRIPTION;
@@ -114,7 +117,7 @@ private static Map<String, Object> processResponse(SearchHit hit) {
114117
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
115118
try {
116119
String input = parameters.get(INPUT_FIELD);
117-
JsonObject jsonObject = StringUtils.gson.fromJson(input, JsonObject.class);
120+
JsonObject jsonObject = GSON.fromJson(input, JsonObject.class);
118121
String index = Optional.ofNullable(jsonObject).map(x -> x.get(INDEX_FIELD)).map(JsonElement::getAsString).orElse(null);
119122
String query = Optional.ofNullable(jsonObject).map(x -> x.get(QUERY_FIELD)).map(JsonElement::toString).orElse(null);
120123
if (index == null || query == null) {
@@ -131,7 +134,7 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
131134
for (SearchHit hit : hits) {
132135
String doc = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> {
133136
Map<String, Object> docContent = processResponse(hit);
134-
return StringUtils.gson.toJson(docContent);
137+
return GSON.toJson(docContent);
135138
});
136139
contextBuilder.append(doc).append("\n");
137140
}

ml-algorithms/src/test/resources/org/opensearch/ml/engine/tools/retrieval_tool_search_response.json

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@
2929
"_source": {
3030
"passage_text": "the price of the api is 2$ per invocation"
3131
}
32+
},
33+
{
34+
"_index": "hybrid-index",
35+
"_id": "3",
36+
"_score": 0.9,
37+
"_source": {
38+
"passage_text": null
39+
}
3240
}
3341
]
3442
}

0 commit comments

Comments
 (0)