|
8 | 8 | import static org.junit.Assert.assertEquals;
|
9 | 9 | import static org.junit.Assert.assertThrows;
|
10 | 10 | import static org.mockito.Mockito.when;
|
| 11 | +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_PATH; |
| 12 | +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_TOOL_USE; |
11 | 13 | import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_GEN_INPUT;
|
| 14 | +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER; |
12 | 15 | import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
|
13 | 16 | import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
|
| 17 | +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_PATH; |
| 18 | +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_TOOL_INPUT; |
| 19 | +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_TOOL_NAME; |
| 20 | +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID; |
| 21 | +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID_PATH; |
14 | 22 | import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION;
|
15 | 23 | import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION_INPUT;
|
16 | 24 | import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY;
|
|
21 | 29 | import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT;
|
22 | 30 | import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT_RESPONSE;
|
23 | 31 |
|
| 32 | +import java.util.ArrayList; |
24 | 33 | import java.util.Arrays;
|
25 | 34 | import java.util.Collections;
|
26 | 35 | import java.util.HashMap;
|
@@ -695,6 +704,112 @@ public void testConstructToolParams_PlaceholderConfigInputJson() {
|
695 | 704 | Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
|
696 | 705 | }
|
697 | 706 |
|
| 707 | + @Test |
| 708 | + public void testParseLLMOutputWithToolCallsAndResponse() { |
| 709 | + // Test case 1: Response containing both llm_response_filter and tool_calls_path |
| 710 | + String response1 = "{\"metrics\":{\"latencyMs\":4589},\"output\":{\"message\":{\"content\":[{\"text\":\"Let me try another approach to find recommendation-related spans:\"},{\"toolUse\":{\"input\":{\"index\":\"otel-v1-apm-span-000001\",\"query\":{\"size\":10,\"query\":{\"bool\":{\"should\":[{\"wildcard\":{\"name\":\"*recommend*\"}},{\"exists\":{\"field\":\"span.attributes.app@products_recommended@count\"}}],\"minimum_should_match\":1}}}},\"name\":\"SearchIndexTool\",\"toolUseId\":\"tooluse_df9l5U5pTmeS_NFK4VI_zw\"}}],\"role\":\"assistant\"}},\"stopReason\":\"tool_use\",\"usage\":{\"cacheReadInputTokenCount\":0,\"cacheReadInputTokens\":0,\"cacheWriteInputTokenCount\":0,\"cacheWriteInputTokens\":0,\"inputTokens\":12902,\"outputTokens\":155,\"totalTokens\":13057}}"; |
| 711 | + |
| 712 | + // Test case 2: Response containing only tool_calls_path |
| 713 | + String response2 = "{\"metrics\":{\"latencyMs\":2131},\"output\":{\"message\":{\"content\":[{\"toolUse\":{\"input\":{\"index\":[\"ss4o_logs-2025.04.16\"]},\"name\":\"IndexMappingTool\",\"toolUseId\":\"tooluse_Q9Hkj3YrT3qfcwAIY3Y2WA\"}}],\"role\":\"assistant\"}},\"stopReason\":\"tool_use\",\"usage\":{\"cacheReadInputTokenCount\":0,\"cacheReadInputTokens\":0,\"cacheWriteInputTokenCount\":0,\"cacheWriteInputTokens\":0,\"inputTokens\":3411,\"outputTokens\":70,\"totalTokens\":3481}}"; |
| 714 | + |
| 715 | + // Test case 3: Response containing only llm_response_filter |
| 716 | + String response3 = "{\"metrics\":{\"latencyMs\":4589},\"output\":{\"message\":{\"content\":[{\"text\":\"Let me try another approach to find recommendation-related spans:\"}],\"role\":\"assistant\"}},\"usage\":{\"cacheReadInputTokenCount\":0,\"cacheReadInputTokens\":0,\"cacheWriteInputTokenCount\":0,\"cacheWriteInputTokens\":0,\"inputTokens\":12902,\"outputTokens\":155,\"totalTokens\":13057}}"; |
| 717 | + |
| 718 | + // Set up parameters |
| 719 | + Map<String, String> parameters = new HashMap<>(); |
| 720 | + parameters.put(LLM_RESPONSE_FILTER, "$.choices[0].message.content"); |
| 721 | + parameters.put(TOOL_CALLS_PATH, "$.choices[0].message.tool_calls"); |
| 722 | + parameters.put(LLM_FINISH_REASON_PATH, "$.stopReason"); |
| 723 | + parameters.put(LLM_FINISH_REASON_TOOL_USE, "tool_use"); |
| 724 | + parameters.put(TOOL_CALLS_TOOL_NAME, "name"); |
| 725 | + parameters.put(TOOL_CALLS_TOOL_INPUT, "input"); |
| 726 | + parameters.put(TOOL_CALL_ID_PATH, "toolUseId"); |
| 727 | + |
| 728 | + // Test case 1 |
| 729 | + ModelTensorOutput modelTensorOutput1 = ModelTensorOutput |
| 730 | + .builder() |
| 731 | + .mlModelOutputs( |
| 732 | + List.of( |
| 733 | + ModelTensors |
| 734 | + .builder() |
| 735 | + .mlModelTensors( |
| 736 | + List.of(ModelTensor.builder().name("response").dataAsMap(Map.of("response", response1)).build()) |
| 737 | + ) |
| 738 | + .build() |
| 739 | + ) |
| 740 | + ) |
| 741 | + .build(); |
| 742 | + |
| 743 | + Map<String, String> output1 = AgentUtils.parseLLMOutput( |
| 744 | + parameters, |
| 745 | + modelTensorOutput1, |
| 746 | + null, |
| 747 | + Set.of("SearchIndexTool"), |
| 748 | + new ArrayList<>() |
| 749 | + ); |
| 750 | + |
| 751 | + Assert.assertEquals("", output1.get(THOUGHT)); |
| 752 | + Assert.assertEquals("SearchIndexTool", output1.get(ACTION)); |
| 753 | + Assert.assertTrue(output1.get(ACTION_INPUT).contains("otel-v1-apm-span-000001")); |
| 754 | + Assert.assertEquals("tooluse_df9l5U5pTmeS_NFK4VI_zw", output1.get(TOOL_CALL_ID)); |
| 755 | + |
| 756 | + // Test case 2 |
| 757 | + ModelTensorOutput modelTensorOutput2 = ModelTensorOutput |
| 758 | + .builder() |
| 759 | + .mlModelOutputs( |
| 760 | + List.of( |
| 761 | + ModelTensors |
| 762 | + .builder() |
| 763 | + .mlModelTensors( |
| 764 | + List.of(ModelTensor.builder().name("response").dataAsMap(Map.of("response", response2)).build()) |
| 765 | + ) |
| 766 | + .build() |
| 767 | + ) |
| 768 | + ) |
| 769 | + .build(); |
| 770 | + |
| 771 | + Map<String, String> output2 = AgentUtils.parseLLMOutput( |
| 772 | + parameters, |
| 773 | + modelTensorOutput2, |
| 774 | + null, |
| 775 | + Set.of("IndexMappingTool"), |
| 776 | + new ArrayList<>() |
| 777 | + ); |
| 778 | + |
| 779 | + Assert.assertEquals("", output2.get(THOUGHT)); |
| 780 | + Assert.assertEquals("IndexMappingTool", output2.get(ACTION)); |
| 781 | + Assert.assertTrue(output2.get(ACTION_INPUT).contains("ss4o_logs-2025.04.16")); |
| 782 | + Assert.assertEquals("tooluse_Q9Hkj3YrT3qfcwAIY3Y2WA", output2.get(TOOL_CALL_ID)); |
| 783 | + |
| 784 | + // Test case 3 |
| 785 | + ModelTensorOutput modelTensorOutput3 = ModelTensorOutput |
| 786 | + .builder() |
| 787 | + .mlModelOutputs( |
| 788 | + List.of( |
| 789 | + ModelTensors |
| 790 | + .builder() |
| 791 | + .mlModelTensors( |
| 792 | + List.of(ModelTensor.builder().name("response").dataAsMap(Map.of("response", response3)).build()) |
| 793 | + ) |
| 794 | + .build() |
| 795 | + ) |
| 796 | + ) |
| 797 | + .build(); |
| 798 | + |
| 799 | + Map<String, String> output3 = AgentUtils.parseLLMOutput( |
| 800 | + parameters, |
| 801 | + modelTensorOutput3, |
| 802 | + null, |
| 803 | + Set.of(), |
| 804 | + new ArrayList<>() |
| 805 | + ); |
| 806 | + |
| 807 | + Assert.assertNull(output3.get(ACTION)); |
| 808 | + Assert.assertNull(output3.get(ACTION_INPUT)); |
| 809 | + Assert.assertNull(output3.get(TOOL_CALL_ID)); |
| 810 | + Assert.assertTrue(output3.get(FINAL_ANSWER).contains("Let me try another approach to find recommendation-related spans:")); |
| 811 | + } |
| 812 | + |
698 | 813 | private void verifyConstructToolParams(String question, String actionInput, Consumer<Map<String, String>> verify) {
|
699 | 814 | Map<String, Tool> tools = Map.of("tool1", tool1);
|
700 | 815 | Map<String, MLToolSpec> toolSpecMap = Map
|
|
0 commit comments