Skip to content

Commit 5ade755

Browse files
happy-qiao authored and copybara-github committed
feat: Support RAG in public preview
PiperOrigin-RevId: 635991192
1 parent b8d4af1 commit 5ade755

File tree

5 files changed

+211
-30
lines changed

5 files changed

+211
-30
lines changed

src/functions/generate_content.ts

+10-6
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ import {
4343
formatContentRequest,
4444
validateGenerateContentRequest,
4545
validateGenerationConfig,
46+
hasVertexRagStore,
47+
getApiVersion,
4648
} from './pre_fetch_processing';
4749

4850
export async function generateContent(
@@ -75,12 +77,13 @@ export async function generateContent(
7577
};
7678
const response: Response | undefined = await postRequest({
7779
region: location,
78-
resourcePath: resourcePath,
80+
resourcePath,
7981
resourceMethod: constants.GENERATE_CONTENT_METHOD,
8082
token: await token,
8183
data: generateContentRequest,
82-
apiEndpoint: apiEndpoint,
83-
requestOptions: requestOptions,
84+
apiEndpoint,
85+
requestOptions,
86+
apiVersion: getApiVersion(request),
8487
}).catch(e => {
8588
throw new GoogleGenerativeAIError('exception posting request to model', e);
8689
});
@@ -126,12 +129,13 @@ export async function generateContentStream(
126129
};
127130
const response = await postRequest({
128131
region: location,
129-
resourcePath: resourcePath,
132+
resourcePath,
130133
resourceMethod: constants.STREAMING_GENERATE_CONTENT_METHOD,
131134
token: await token,
132135
data: generateContentRequest,
133-
apiEndpoint: apiEndpoint,
134-
requestOptions: requestOptions,
136+
apiEndpoint,
137+
requestOptions,
138+
apiVersion: getApiVersion(request),
135139
}).catch(e => {
136140
throw new GoogleGenerativeAIError('exception posting request', e);
137141
});

src/functions/pre_fetch_processing.ts

+37
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@
1818
import {
1919
GenerateContentRequest,
2020
GenerationConfig,
21+
RetrievalTool,
2122
SafetySetting,
23+
Tool,
2224
} from '../types/content';
25+
import {ClientError} from '../types/errors';
2326
import * as constants from '../util/constants';
2427

2528
export function formatContentRequest(
@@ -55,6 +58,12 @@ export function validateGenerateContentRequest(
5558
}
5659
}
5760
}
61+
62+
if (hasVertexAISearch(request) && hasVertexRagStore(request)) {
63+
throw new ClientError(
64+
'Found both vertexAiSearch and vertexRagStore field are set in tool. Either set vertexAiSearch or vertexRagStore.'
65+
);
66+
}
5867
}
5968

6069
export function validateGenerationConfig(
@@ -67,3 +76,31 @@ export function validateGenerationConfig(
6776
}
6877
return generationConfig;
6978
}
79+
80+
/**
 * Selects the API version to use when sending a generate-content request.
 *
 * @param request - The request that is about to be posted.
 * @returns `'v1beta1'` when `hasVertexRagStore(request)` reports a Vertex RAG
 *     store on any tool, otherwise `'v1'`.
 */
export function getApiVersion(
  request: GenerateContentRequest
): 'v1' | 'v1beta1' {
  if (hasVertexRagStore(request)) {
    return 'v1beta1';
  }
  return 'v1';
}
85+
86+
/**
 * Reports whether any tool in the request has a truthy
 * `retrieval.vertexRagStore` field.
 *
 * Tools without a `retrieval` field are ignored. A missing `tools` array, or
 * a null/undefined entry inside it, is treated as "no RAG store" instead of
 * throwing.
 *
 * @param request - The generate-content request to inspect.
 * @returns `true` if at least one tool sets `retrieval.vertexRagStore`,
 *     otherwise `false`.
 */
export function hasVertexRagStore(request: GenerateContentRequest): boolean {
  return (request?.tools ?? []).some(tool =>
    // Optional chaining guards against null entries and non-retrieval tools.
    Boolean(
      (tool as RetrievalTool | null | undefined)?.retrieval?.vertexRagStore
    )
  );
}
96+
97+
/**
 * Reports whether any tool in the request has a truthy
 * `retrieval.vertexAiSearch` field.
 *
 * Tools without a `retrieval` field are ignored. A missing `tools` array, or
 * a null/undefined entry inside it, is treated as "no Vertex AI Search"
 * instead of throwing.
 *
 * @param request - The generate-content request to inspect.
 * @returns `true` if at least one tool sets `retrieval.vertexAiSearch`,
 *     otherwise `false`.
 */
export function hasVertexAISearch(request: GenerateContentRequest): boolean {
  return (request?.tools ?? []).some(tool =>
    // Optional chaining guards against null entries and non-retrieval tools.
    Boolean(
      (tool as RetrievalTool | null | undefined)?.retrieval?.vertexAiSearch
    )
  );
}

src/functions/test/functions_test.ts

+49-15
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ const TEST_USER_CHAT_MESSAGE = [
4747
{role: constants.USER_ROLE, parts: [{text: TEST_CHAT_MESSAGE_TEXT}]},
4848
];
4949

50+
const CONTENTS = [
51+
{
52+
role: 'user',
53+
parts: [{text: 'What is the weater like in Boston?'}],
54+
},
55+
];
56+
5057
const TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE = [
5158
{
5259
role: constants.USER_ROLE,
@@ -208,6 +215,12 @@ const TEST_TOOLS_WITH_FUNCTION_DECLARATION: Tool[] = [
208215
},
209216
];
210217

218+
const TEST_TOOLS_WITH_RAG: Tool[] = [
219+
{
220+
retrieval: {vertexRagStore: {ragResources: [{ragCorpus: 'ragCorpus'}]}},
221+
},
222+
];
223+
211224
const fetchResponseObj = {
212225
status: 200,
213226
statusText: 'OK',
@@ -520,9 +533,7 @@ describe('generateContent', () => {
520533

521534
it('returns a FunctionCall when passed a FunctionDeclaration', async () => {
522535
const req: GenerateContentRequest = {
523-
contents: [
524-
{role: 'user', parts: [{text: 'What is the weater like in Boston?'}]},
525-
],
536+
contents: CONTENTS,
526537
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
527538
};
528539
const expectedResult: GenerateContentResult = {
@@ -559,12 +570,7 @@ describe('generateContent', () => {
559570

560571
it('returns a empty FunctionCall list when response contains invalid data', async () => {
561572
const req: GenerateContentRequest = {
562-
contents: [
563-
{
564-
role: 'user',
565-
parts: [{text: 'What is the weater like in Boston?'}],
566-
},
567-
],
573+
contents: CONTENTS,
568574
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
569575
};
570576
const expectedResult: GenerateContentResult = {
@@ -594,12 +600,7 @@ describe('generateContent', () => {
594600

595601
it('returns empty candidates when response is empty', async () => {
596602
const req: GenerateContentRequest = {
597-
contents: [
598-
{
599-
role: 'user',
600-
parts: [{text: 'What is the weater like in Boston?'}],
601-
},
602-
],
603+
contents: CONTENTS,
603604
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
604605
};
605606
fetchSpy.and.resolveTo(new Response(JSON.stringify({}), fetchResponseObj));
@@ -613,6 +614,39 @@ describe('generateContent', () => {
613614
);
614615
expect(actualResult.response.candidates).not.toBeDefined();
615616
});
617+
618+
it('should use v1 apiVersion', async () => {
619+
const request: GenerateContentRequest = {
620+
contents: CONTENTS,
621+
};
622+
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE));
623+
await generateContent(
624+
TEST_LOCATION,
625+
TEST_RESOURCE_PATH,
626+
TEST_TOKEN_PROMISE,
627+
request,
628+
TEST_API_ENDPOINT
629+
);
630+
const vertexEndpoint = fetchSpy.calls.allArgs()[0][0];
631+
expect(vertexEndpoint).toContain('/v1/');
632+
});
633+
634+
it('should use v1beta1 apiVersion when set RAG in tools', async () => {
635+
const request: GenerateContentRequest = {
636+
contents: CONTENTS,
637+
tools: TEST_TOOLS_WITH_RAG,
638+
};
639+
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE));
640+
await generateContent(
641+
TEST_LOCATION,
642+
TEST_RESOURCE_PATH,
643+
TEST_TOKEN_PROMISE,
644+
request,
645+
TEST_API_ENDPOINT
646+
);
647+
const vertexEndpoint = fetchSpy.calls.allArgs()[0][0];
648+
expect(vertexEndpoint).toContain('/v1beta1/');
649+
});
616650
});
617651

618652
describe('generateContentStream', () => {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import {Tool} from '../../types/content';
2+
import {
3+
getApiVersion,
4+
validateGenerateContentRequest,
5+
} from '../pre_fetch_processing';
6+
7+
/** Retrieval tool that sets only `vertexAiSearch`. */
const TOOL1 = {
  retrieval: {
    vertexAiSearch: {datastore: 'datastore'},
  },
} as Tool;

/** Retrieval tool that sets only `vertexRagStore`. */
const TOOL2 = {
  retrieval: {
    vertexRagStore: {
      ragResources: [{ragCorpus: 'ragCorpus'}],
    },
  },
} as Tool;

/** Retrieval tool that sets both `vertexAiSearch` and `vertexRagStore`. */
const TOOL3 = {
  retrieval: {
    vertexAiSearch: {datastore: 'datastore'},
    vertexRagStore: {
      ragResources: [{ragCorpus: 'ragCorpus'}],
    },
  },
} as Tool;

/** Message the tests below expect when both data sources are configured. */
const VALID_TOOL_ERROR_MESSAGE =
  '[VertexAI.ClientError]: Found both vertexAiSearch and vertexRagStore field are set in tool. Either set vertexAiSearch or vertexRagStore.';
20+
21+
// Checks that validateGenerateContentRequest rejects requests that configure
// both vertexAiSearch and vertexRagStore, and accepts either one alone.
describe('validateTools', () => {
  it('should pass validation when set tool correctly', () => {
    // Each single-source tool must validate on its own.
    for (const tool of [TOOL1, TOOL2]) {
      const validate = () =>
        validateGenerateContentRequest({tools: [tool], contents: []});
      expect(validate).not.toThrow();
    }
  });

  it('should throw error when set VertexAiSearch and VertexRagStore in two tools in request', () => {
    const validate = () =>
      validateGenerateContentRequest({tools: [TOOL1, TOOL2], contents: []});
    expect(validate).toThrowError(VALID_TOOL_ERROR_MESSAGE);
  });

  it('should throw error when set VertexAiSearch and VertexRagStore in a single tool in request', () => {
    const validate = () =>
      validateGenerateContentRequest({tools: [TOOL3], contents: []});
    expect(validate).toThrowError(VALID_TOOL_ERROR_MESSAGE);
  });
});
43+
44+
// Checks which API version getApiVersion picks for a given set of tools.
describe('getApiVersion', () => {
  it('should return v1', () => {
    const version = getApiVersion({contents: [], tools: [TOOL1]});
    expect(version).toEqual('v1');
  });

  it('should return v1beta1', () => {
    // A RAG store tool selects v1beta1, alone or alongside another tool.
    const toolSets = [[TOOL2], [TOOL1, TOOL2]];
    for (const tools of toolSets) {
      expect(getApiVersion({contents: [], tools})).toEqual('v1beta1');
    }
  });
});

src/types/content.ts

+60-9
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ export declare interface VertexInit {
5151
export declare interface GenerateContentRequest extends BaseModelParams {
5252
/** Array of {@link Content}.*/
5353
contents: Content[];
54-
/** Optional. The user provided system instructions for the model.
54+
/**
55+
* Optional. The user provided system instructions for the model.
5556
* Note: only text should be used in parts of {@link Content}
5657
*/
5758
systemInstruction?: string | Content;
@@ -110,7 +111,8 @@ export declare interface GetGenerativeModelParams extends ModelParams {
110111
tools?: Tool[];
111112
/** Optional. The request options to use for generation. */
112113
requestOptions?: RequestOptions;
113-
/** Optional. The user provided system instructions for the model.
114+
/**
115+
* Optional. The user provided system instructions for the model.
114116
* Note: only text should be used in parts of {@link Content}
115117
*/
116118
systemInstruction?: string | Content;
@@ -138,7 +140,8 @@ export declare interface BaseModelParams {
138140
generationConfig?: GenerationConfig;
139141
/** Optional. Array of {@link Tool}. */
140142
tools?: Tool[];
141-
/** Optional. The user provided system instructions for the model.
143+
/**
144+
* Optional. The user provided system instructions for the model.
142145
* Note: only text should be used in parts of {@link Content}
143146
*/
144147
systemInstruction?: string | Content;
@@ -563,12 +566,20 @@ export declare interface CitationMetadata {
563566
* date).
564567
*/
565568
export declare interface GoogleDate {
566-
/** Year of the date. Must be from 1 to 9999, or 0 to specify a date without a year. */
569+
/**
570+
* Year of the date. Must be from 1 to 9999, or 0 to specify a date without a
571+
* year.
572+
*/
567573
year?: number;
568-
/** Month of the date. Must be from 1 to 12, or 0 to specify a year without a monthi and day. */
574+
/**
575+
* Month of the date. Must be from 1 to 12, or 0 to specify a year without a
576+
* month and day.
577+
*/
569578
month?: number;
570-
/** Day of the date. Must be from 1 to 31 and valid for the year and month.
571-
* or 0 to specify a year by itself or a year and month where the day isn't significant
579+
/**
580+
* Day of the date. Must be from 1 to 31 and valid for the year and month,
581+
* or 0 to specify a year by itself or a year and month where the day isn't
582+
* significant
572583
*/
573584
day?: number;
574585
}
@@ -763,6 +774,40 @@ export declare interface RetrievalTool {
763774
retrieval?: Retrieval;
764775
}
765776

777+
/**
 * Settings for a retrieval data source powered by Vertex RAG store.
 */
export declare interface VertexRagStore {
  /**
   * Optional. The corpora to retrieve from. Currently limited to a single
   * corpus, or multiple files from a single corpus; support for multiple
   * corpora may be opened up in the future.
   */
  ragResources?: RagResource[];

  /** Optional. How many top-k results to return from the selected corpora. */
  similarityTopK?: number;

  /**
   * Optional. When set, only results whose vector distance is smaller than
   * this threshold are returned.
   */
  vectorDistanceThreshold?: number;
}
791+
792+
/**
 * A Vertex RAG Store resource to retrieve from: a corpus and, optionally, a
 * selection of files under it.
 */
export declare interface RagResource {
  /**
   * Optional. Resource name of the Vertex RAG Store corpus.
   *
   * @example
   * `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`
   */
  ragCorpus?: string;

  /**
   * Optional. IDs of the files under the ragCorpora to select for retrieval.
   */
  ragFileIds?: string[];
}
810+
766811
/**
767812
* Defines a retrieval tool that model can call to access external knowledge.
768813
*/
@@ -786,6 +831,10 @@ export declare interface Retrieval {
786831
* VertexAISearch}.
787832
*/
788833
vertexAiSearch?: VertexAISearch;
834+
835+
/** Optional. Set to use data source powered by Vertex RAG store. */
836+
vertexRagStore?: VertexRagStore;
837+
789838
/**
790839
* Optional. Disable using the result from this tool in detecting grounding
791840
* attribution. This does not affect how the result is given to the model for
@@ -896,7 +945,8 @@ export declare interface StartChatParams {
896945
tools?: Tool[];
897946
/** Optional. The base Vertex AI endpoint to use for the request. */
898947
apiEndpoint?: string;
899-
/** Optional. The user provided system instructions for the model.
948+
/**
949+
* Optional. The user provided system instructions for the model.
900950
* Note: only text should be used in parts of {@link Content}
901951
*/
902952
systemInstruction?: string | Content;
@@ -916,7 +966,8 @@ export declare interface StartChatSessionRequest extends StartChatParams {
916966
publisherModelEndpoint: string;
917967
/** The resource path to use for the request. */
918968
resourcePath: string;
919-
/** Optional. The user provided system instructions for the model.
969+
/**
970+
* Optional. The user provided system instructions for the model.
920971
* Note: only text should be used in parts of {@link Content}
921972
*/
922973
systemInstruction?: string | Content;

0 commit comments

Comments
 (0)