Skip to content

Commit f5b7674

Browse files
committed
wip
1 parent 22dc728 commit f5b7674

File tree

1 file changed

+115
-0
lines changed

1 file changed

+115
-0
lines changed

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java

+115
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,17 @@
88
import static org.junit.Assert.assertEquals;
99
import static org.junit.Assert.assertThrows;
1010
import static org.mockito.Mockito.when;
11+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_PATH;
12+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_TOOL_USE;
1113
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_GEN_INPUT;
14+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER;
1215
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
1316
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
17+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_PATH;
18+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_TOOL_INPUT;
19+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_TOOL_NAME;
20+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID;
21+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID_PATH;
1422
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION;
1523
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION_INPUT;
1624
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY;
@@ -21,6 +29,7 @@
2129
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT;
2230
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT_RESPONSE;
2331

32+
import java.util.ArrayList;
2433
import java.util.Arrays;
2534
import java.util.Collections;
2635
import java.util.HashMap;
@@ -695,6 +704,112 @@ public void testConstructToolParams_PlaceholderConfigInputJson() {
695704
Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
696705
}
697706

707+
@Test
708+
public void testParseLLMOutputWithToolCallsAndResponse() {
709+
// Test case 1: Response containing both llm_response_filter and tool_calls_path
710+
String response1 = "{\"metrics\":{\"latencyMs\":4589},\"output\":{\"message\":{\"content\":[{\"text\":\"Let me try another approach to find recommendation-related spans:\"},{\"toolUse\":{\"input\":{\"index\":\"otel-v1-apm-span-000001\",\"query\":{\"size\":10,\"query\":{\"bool\":{\"should\":[{\"wildcard\":{\"name\":\"*recommend*\"}},{\"exists\":{\"field\":\"span.attributes.app@products_recommended@count\"}}],\"minimum_should_match\":1}}}},\"name\":\"SearchIndexTool\",\"toolUseId\":\"tooluse_df9l5U5pTmeS_NFK4VI_zw\"}}],\"role\":\"assistant\"}},\"stopReason\":\"tool_use\",\"usage\":{\"cacheReadInputTokenCount\":0,\"cacheReadInputTokens\":0,\"cacheWriteInputTokenCount\":0,\"cacheWriteInputTokens\":0,\"inputTokens\":12902,\"outputTokens\":155,\"totalTokens\":13057}}";
711+
712+
// Test case 2: Response containing only tool_calls_path
713+
String response2 = "{\"metrics\":{\"latencyMs\":2131},\"output\":{\"message\":{\"content\":[{\"toolUse\":{\"input\":{\"index\":[\"ss4o_logs-2025.04.16\"]},\"name\":\"IndexMappingTool\",\"toolUseId\":\"tooluse_Q9Hkj3YrT3qfcwAIY3Y2WA\"}}],\"role\":\"assistant\"}},\"stopReason\":\"tool_use\",\"usage\":{\"cacheReadInputTokenCount\":0,\"cacheReadInputTokens\":0,\"cacheWriteInputTokenCount\":0,\"cacheWriteInputTokens\":0,\"inputTokens\":3411,\"outputTokens\":70,\"totalTokens\":3481}}";
714+
715+
// Test case 3: Response containing only llm_response_filter
716+
String response3 = "{\"metrics\":{\"latencyMs\":4589},\"output\":{\"message\":{\"content\":[{\"text\":\"Let me try another approach to find recommendation-related spans:\"}],\"role\":\"assistant\"}},\"usage\":{\"cacheReadInputTokenCount\":0,\"cacheReadInputTokens\":0,\"cacheWriteInputTokenCount\":0,\"cacheWriteInputTokens\":0,\"inputTokens\":12902,\"outputTokens\":155,\"totalTokens\":13057}}";
717+
718+
// Set up parameters
719+
Map<String, String> parameters = new HashMap<>();
720+
parameters.put(LLM_RESPONSE_FILTER, "$.choices[0].message.content");
721+
parameters.put(TOOL_CALLS_PATH, "$.choices[0].message.tool_calls");
722+
parameters.put(LLM_FINISH_REASON_PATH, "$.stopReason");
723+
parameters.put(LLM_FINISH_REASON_TOOL_USE, "tool_use");
724+
parameters.put(TOOL_CALLS_TOOL_NAME, "name");
725+
parameters.put(TOOL_CALLS_TOOL_INPUT, "input");
726+
parameters.put(TOOL_CALL_ID_PATH, "toolUseId");
727+
728+
// Test case 1
729+
ModelTensorOutput modelTensorOutput1 = ModelTensorOutput
730+
.builder()
731+
.mlModelOutputs(
732+
List.of(
733+
ModelTensors
734+
.builder()
735+
.mlModelTensors(
736+
List.of(ModelTensor.builder().name("response").dataAsMap(Map.of("response", response1)).build())
737+
)
738+
.build()
739+
)
740+
)
741+
.build();
742+
743+
Map<String, String> output1 = AgentUtils.parseLLMOutput(
744+
parameters,
745+
modelTensorOutput1,
746+
null,
747+
Set.of("SearchIndexTool"),
748+
new ArrayList<>()
749+
);
750+
751+
Assert.assertEquals("", output1.get(THOUGHT));
752+
Assert.assertEquals("SearchIndexTool", output1.get(ACTION));
753+
Assert.assertTrue(output1.get(ACTION_INPUT).contains("otel-v1-apm-span-000001"));
754+
Assert.assertEquals("tooluse_df9l5U5pTmeS_NFK4VI_zw", output1.get(TOOL_CALL_ID));
755+
756+
// Test case 2
757+
ModelTensorOutput modelTensorOutput2 = ModelTensorOutput
758+
.builder()
759+
.mlModelOutputs(
760+
List.of(
761+
ModelTensors
762+
.builder()
763+
.mlModelTensors(
764+
List.of(ModelTensor.builder().name("response").dataAsMap(Map.of("response", response2)).build())
765+
)
766+
.build()
767+
)
768+
)
769+
.build();
770+
771+
Map<String, String> output2 = AgentUtils.parseLLMOutput(
772+
parameters,
773+
modelTensorOutput2,
774+
null,
775+
Set.of("IndexMappingTool"),
776+
new ArrayList<>()
777+
);
778+
779+
Assert.assertEquals("", output2.get(THOUGHT));
780+
Assert.assertEquals("IndexMappingTool", output2.get(ACTION));
781+
Assert.assertTrue(output2.get(ACTION_INPUT).contains("ss4o_logs-2025.04.16"));
782+
Assert.assertEquals("tooluse_Q9Hkj3YrT3qfcwAIY3Y2WA", output2.get(TOOL_CALL_ID));
783+
784+
// Test case 3
785+
ModelTensorOutput modelTensorOutput3 = ModelTensorOutput
786+
.builder()
787+
.mlModelOutputs(
788+
List.of(
789+
ModelTensors
790+
.builder()
791+
.mlModelTensors(
792+
List.of(ModelTensor.builder().name("response").dataAsMap(Map.of("response", response3)).build())
793+
)
794+
.build()
795+
)
796+
)
797+
.build();
798+
799+
Map<String, String> output3 = AgentUtils.parseLLMOutput(
800+
parameters,
801+
modelTensorOutput3,
802+
null,
803+
Set.of(),
804+
new ArrayList<>()
805+
);
806+
807+
Assert.assertNull(output3.get(ACTION));
808+
Assert.assertNull(output3.get(ACTION_INPUT));
809+
Assert.assertNull(output3.get(TOOL_CALL_ID));
810+
Assert.assertTrue(output3.get(FINAL_ANSWER).contains("Let me try another approach to find recommendation-related spans:"));
811+
}
812+
698813
private void verifyConstructToolParams(String question, String actionInput, Consumer<Map<String, String>> verify) {
699814
Map<String, Tool> tools = Map.of("tool1", tool1);
700815
Map<String, MLToolSpec> toolSpecMap = Map

0 commit comments

Comments
 (0)