Skip to content

[Backport 2.x] (query assist) remove caching agent id #1734

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions server/routes/query_assist/routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import { isResponseError } from '../../../../../src/core/server/opensearch/client/errors';
import { ERROR_DETAILS, QUERY_ASSIST_API } from '../../../common/constants/query_assist';
import { generateFieldContext } from '../../common/helpers/query_assist/generate_field_context';
import { getAgentIdByConfig, requestWithRetryAgentSearch } from './utils/agents';
import { getAgentIdByConfig, getAgentIdAndRequest } from './utils/agents';
import { AGENT_CONFIGS } from './utils/constants';

export function registerQueryAssistRoutes(router: IRouter) {
Expand All @@ -28,7 +28,7 @@
context,
request,
response
): Promise<IOpenSearchDashboardsResponse<any | ResponseError>> => {

Check warning on line 31 in server/routes/query_assist/routes.ts

View workflow job for this annotation

GitHub Actions / Lint

Unexpected any. Specify a different type
const client = context.core.opensearch.client.asCurrentUser;
try {
// if the call does not throw any error, then the agent is properly configured
Expand All @@ -54,10 +54,10 @@
context,
request,
response
): Promise<IOpenSearchDashboardsResponse<any | ResponseError>> => {

Check warning on line 57 in server/routes/query_assist/routes.ts

View workflow job for this annotation

GitHub Actions / Lint

Unexpected any. Specify a different type
const client = context.core.opensearch.client.asCurrentUser;
try {
const pplRequest = await requestWithRetryAgentSearch({
const pplRequest = await getAgentIdAndRequest({
client,
configName: AGENT_CONFIGS.PPL_AGENT,
body: {
Expand Down Expand Up @@ -110,7 +110,7 @@
context,
request,
response
): Promise<IOpenSearchDashboardsResponse<any | ResponseError>> => {

Check warning on line 113 in server/routes/query_assist/routes.ts

View workflow job for this annotation

GitHub Actions / Lint

Unexpected any. Specify a different type
const client = context.core.opensearch.client.asCurrentUser;
const { index, question, query, response: _response, isError } = request.body;
const queryResponse = JSON.stringify(_response);
Expand All @@ -118,7 +118,7 @@

try {
if (!isError) {
summaryRequest = await requestWithRetryAgentSearch({
summaryRequest = await getAgentIdAndRequest({
client,
configName: AGENT_CONFIGS.RESPONSE_SUMMARY_AGENT,
body: {
Expand All @@ -131,7 +131,7 @@
client.search({ index, size: 1 }),
]);
const fields = generateFieldContext(mappings, sampleDoc);
summaryRequest = await requestWithRetryAgentSearch({
summaryRequest = await getAgentIdAndRequest({
client,
configName: AGENT_CONFIGS.ERROR_SUMMARY_AGENT,
body: {
Expand Down
53 changes: 3 additions & 50 deletions server/routes/query_assist/utils/__tests__/agents.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { ApiResponse } from '@opensearch-project/opensearch';
import { ResponseError } from '@opensearch-project/opensearch/lib/errors';
import { CoreRouteHandlerContext } from '../../../../../../../src/core/server/core_route_handler_context';
import { coreMock, httpServerMock } from '../../../../../../../src/core/server/mocks';
import { agentIdMap, getAgentIdByConfig, requestWithRetryAgentSearch } from '../agents';
import { getAgentIdByConfig, getAgentIdAndRequest } from '../agents';

describe('Agents helper functions', () => {
const coreContext = new CoreRouteHandlerContext(
Expand Down Expand Up @@ -75,27 +75,7 @@ describe('Agents helper functions', () => {
);
});

it('requests with valid agent id', async () => {
agentIdMap.test_agent = 'test-id';
mockedTransport.mockResolvedValueOnce({
body: { inference_results: [{ output: [{ result: 'test response' }] }] },
});
const response = await requestWithRetryAgentSearch({
client,
configName: 'test_agent',
shouldRetryAgentSearch: true,
body: { parameters: { param1: 'value1' } },
});
expect(mockedTransport).toBeCalledWith(
expect.objectContaining({
path: '/_plugins/_ml/agents/test-id/_execute',
}),
expect.anything()
);
expect(response.body.inference_results[0].output[0].result).toEqual('test response');
});

it('searches for agent id if id is undefined', async () => {
it('searches for agent id and sends request', async () => {
mockedTransport
.mockResolvedValueOnce({
body: {
Expand All @@ -106,36 +86,9 @@ describe('Agents helper functions', () => {
.mockResolvedValueOnce({
body: { inference_results: [{ output: [{ result: 'test response' }] }] },
});
const response = await requestWithRetryAgentSearch({
const response = await getAgentIdAndRequest({
client,
configName: 'new_agent',
shouldRetryAgentSearch: true,
body: { parameters: { param1: 'value1' } },
});
expect(mockedTransport).toBeCalledWith(
expect.objectContaining({ path: '/_plugins/_ml/agents/new-id/_execute' }),
expect.anything()
);
expect(response.body.inference_results[0].output[0].result).toEqual('test response');
});

it('searches for agent id if id is not found', async () => {
agentIdMap.test_agent = 'non-exist-agent';
mockedTransport
.mockRejectedValueOnce({ statusCode: 404, body: {}, headers: {} })
.mockResolvedValueOnce({
body: {
type: 'agent',
configuration: { agent_id: 'new-id' },
},
})
.mockResolvedValueOnce({
body: { inference_results: [{ output: [{ result: 'test response' }] }] },
});
const response = await requestWithRetryAgentSearch({
client,
configName: 'test_agent',
shouldRetryAgentSearch: true,
body: { parameters: { param1: 'value1' } },
});
expect(mockedTransport).toBeCalledWith(
Expand Down
40 changes: 12 additions & 28 deletions server/routes/query_assist/utils/agents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
*/

import { ApiResponse } from '@opensearch-project/opensearch/.';
import { RequestBody } from '@opensearch-project/opensearch/lib/Transport';
import { RequestBody, TransportRequestPromise } from '@opensearch-project/opensearch/lib/Transport';
import { OpenSearchClient } from '../../../../../../src/core/server';
import { isResponseError } from '../../../../../../src/core/server/opensearch/client/errors';
import { ML_COMMONS_API_PREFIX } from '../../../../common/constants/query_assist';

const AGENT_REQUEST_OPTIONS = {
Expand All @@ -27,8 +26,6 @@ type AgentResponse = ApiResponse<{
}>;
}>;

export const agentIdMap: Record<string, string> = {};

export const getAgentIdByConfig = async (
opensearchClient: OpenSearchClient,
configName: string
Expand All @@ -49,32 +46,19 @@ export const getAgentIdByConfig = async (
}
};

export const requestWithRetryAgentSearch = async (options: {
export const getAgentIdAndRequest = async (options: {
client: OpenSearchClient;
configName: string;
shouldRetryAgentSearch?: boolean;
body: RequestBody;
}): Promise<AgentResponse> => {
const { client, configName, shouldRetryAgentSearch = true, body } = options;
let retry = shouldRetryAgentSearch;
if (!agentIdMap[configName]) {
agentIdMap[configName] = await getAgentIdByConfig(client, configName);
retry = false;
}
return client.transport
.request(
{
method: 'POST',
path: `${ML_COMMONS_API_PREFIX}/agents/${agentIdMap[configName]}/_execute`,
body,
},
AGENT_REQUEST_OPTIONS
)
.catch(async (error) => {
if (retry && isResponseError(error) && error.statusCode === 404) {
agentIdMap[configName] = await getAgentIdByConfig(client, configName);
return requestWithRetryAgentSearch({ ...options, shouldRetryAgentSearch: false });
}
return Promise.reject(error);
}) as Promise<AgentResponse>;
const { client, configName, body } = options;
const agentId = await getAgentIdByConfig(client, configName);
return client.transport.request(
{
method: 'POST',
path: `${ML_COMMONS_API_PREFIX}/agents/${agentId}/_execute`,
body,
},
AGENT_REQUEST_OPTIONS
) as TransportRequestPromise<AgentResponse>;
};
Loading