Skip to content

Commit 6263800

Browse files
sararobcopybara-github
authored andcommitted
feat: add streamGenerateContent method
PiperOrigin-RevId: 586307690
1 parent 686a0be commit 6263800

File tree

5 files changed

+279
-154
lines changed

5 files changed

+279
-154
lines changed

src/index.ts

+46-36
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
/* tslint:disable */
1919
import {GoogleAuth} from 'google-auth-library';
2020

21-
import {emptyGenerator, processNonStream, processStream} from './process_stream';
22-
import {Content, CountTokensRequest, CountTokensResponse, GenerateContentParams, GenerateContentRequest, GenerateContentResult, GenerationConfig, ModelParams, Part, SafetySetting} from './types/content';
21+
import {processNonStream, processStream} from './process_stream';
22+
import {Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest, GenerateContentResult, GenerationConfig, ModelParams, Part, SafetySetting, StreamGenerateContentResult} from './types/content';
2323
import {postRequest} from './util';
2424

2525
// TODO: update this when model names are available
@@ -142,8 +142,8 @@ export class ChatSession {
142142
this._vertex_instance = request._vertex_instance;
143143
}
144144

145-
// TODO: update this to sendMessage / streamSendMessage after generateContent
146-
// is split
145+
// TODO: unbreak this and update to sendMessage / streamSendMessage after
146+
// generateContent is split
147147
async sendMessage(request: string|
148148
Array<string|Part>): Promise<GenerateContentResult> {
149149
let newParts: Part[] = [];
@@ -166,23 +166,15 @@ export class ChatSession {
166166
// successfully?
167167
this.historyInternal.push(newContent);
168168

169-
let generateContentrequest: GenerateContentParams = {
169+
let generateContentrequest: GenerateContentRequest = {
170170
contents: this.historyInternal,
171171
safety_settings: this.safety_settings,
172172
generation_config: this.generation_config,
173-
stream: true,
174173
};
175174

176175
const generateContentResponse =
177176
await this._model_instance.generateContent(generateContentrequest);
178177

179-
// This is currently not iterating over generateContentResponse.stream, it's
180-
// iterating over the list of returned responses
181-
for (const result of generateContentResponse.responses) {
182-
for (const candidate of result.candidates) {
183-
this.historyInternal.push(candidate.content);
184-
}
185-
}
186178
return generateContentResponse;
187179
}
188180
}
@@ -214,7 +206,7 @@ export class GenerativeModel {
214206
* @param request A GenerateContentRequest object with the request contents.
215207
* @return The GenerateContentResponse object with the response candidates.
216208
*/
217-
async generateContent(request: GenerateContentParams):
209+
async generateContent(request: GenerateContentRequest):
218210
Promise<GenerateContentResult> {
219211
const publisherModelEndpoint = `publishers/google/models/${this.model}`;
220212

@@ -230,9 +222,7 @@ export class GenerativeModel {
230222
region: this._vertex_instance.location,
231223
project: this._vertex_instance.project,
232224
resourcePath: publisherModelEndpoint,
233-
// TODO: update when this method is split for streaming / non-streaming
234-
resourceMethod: request.stream ? 'streamGenerateContent' :
235-
'generateContent',
225+
resourceMethod: 'generateContent',
236226
token: await this._vertex_instance.token,
237227
data: generateContentRequest,
238228
apiEndpoint: this._vertex_instance.apiEndpoint,
@@ -246,29 +236,49 @@ export class GenerativeModel {
246236
} catch (e) {
247237
console.log(e);
248238
}
249-
250-
if (!request.stream) {
251-
const result: GenerateContentResult = processNonStream(response);
252-
return Promise.resolve(result);
239+
240+
const result: GenerateContentResult = processNonStream(response);
241+
return Promise.resolve(result);
242+
}
243+
244+
/**
245+
* Make a streamGenerateContent request.
246+
* @param request A GenerateContentRequest object with the request contents.
247+
* @return The GenerateContentResponse object with the response candidates.
248+
*/
249+
async streamGenerateContent(request: GenerateContentRequest):
250+
Promise<StreamGenerateContentResult> {
251+
const publisherModelEndpoint = `publishers/google/models/${this.model}`;
252+
253+
const generateContentRequest: GenerateContentRequest = {
254+
contents: request.contents,
255+
generation_config: request.generation_config ?? this.generation_config,
256+
safety_settings: request.safety_settings ?? this.safety_settings,
253257
}
254258

255-
const streamResult = processStream(response);
256-
// TODO: update chat unit test mock response to reflect logic in stream processing
257-
// then remove the ts-ignore comment and remove request.stream===false
258-
// @ts-ignore
259-
if (request.stream === false && streamResult.stream !== undefined) {
260-
const responses = [];
261-
for await (const resp of streamResult.stream) {
262-
responses.push(resp);
259+
let response;
260+
try {
261+
response = await postRequest({
262+
region: this._vertex_instance.location,
263+
project: this._vertex_instance.project,
264+
resourcePath: publisherModelEndpoint,
265+
resourceMethod: 'streamGenerateContent',
266+
token: await this._vertex_instance.token,
267+
data: generateContentRequest,
268+
apiEndpoint: this._vertex_instance.apiEndpoint,
269+
});
270+
if (response === undefined) {
271+
throw new Error('did not get a valid response.')
263272
}
264-
return {
265-
stream: emptyGenerator(),
266-
responses,
267-
};
268-
} else {
269-
// True or undefined (default true)
270-
return streamResult;
273+
if (!response.ok) {
274+
throw new Error(`${response.status} ${response.statusText}`)
275+
}
276+
} catch (e) {
277+
console.log(e);
271278
}
279+
280+
const streamResult = processStream(response);
281+
return Promise.resolve(streamResult);
272282
}
273283

274284
/**

src/index_test.ts

+66-42
Original file line numberDiff line numberDiff line change
@@ -20,31 +20,40 @@ import 'jasmine';
2020

2121
import {ChatSession, GenerativeModel, StartChatParams, VertexAI} from './index';
2222
import * as StreamFunctions from './process_stream';
23-
import {CountTokensRequest, CountTokensResponse, GenerateContentParams, GenerateContentResult} from './types/content';
24-
import * as PostRequest from './util/post_request';
23+
import {CountTokensRequest, GenerateContentRequest, GenerateContentResponse, GenerateContentResult, StreamGenerateContentResult} from './types/content';
2524

2625
const PROJECT = 'test_project';
2726
const LOCATION = 'test_location';
28-
const MODEL_ID = 'test_model_id';
2927
const TEST_USER_CHAT_MESSAGE =
3028
[{role: 'user', parts: [{text: 'How are you doing today?'}]}];
31-
const TEST_MODEL_RESPONSE = [{
32-
candidates: [
33-
{
34-
index: 1,
35-
content:
36-
{role: 'assistant', parts: [{text: 'I\m doing great! How are you?'}]},
37-
finish_reason: 0,
38-
finish_message: '',
39-
safety_ratings: [{category: 0, threshold: 0}],
40-
},
41-
],
29+
const TEST_CANDIDATES = [
30+
{
31+
index: 1,
32+
content:
33+
{role: 'assistant', parts: [{text: 'I\m doing great! How are you?'}]},
34+
finish_reason: 0,
35+
finish_message: '',
36+
safety_ratings: [{category: 0, threshold: 0}],
37+
},
38+
];
39+
const TEST_MODEL_RESPONSE = {
40+
candidates: TEST_CANDIDATES,
4241
usage_metadata: {prompt_token_count: 0, candidates_token_count: 0}
4342

44-
}];
43+
};
4544

4645
const TEST_ENDPOINT_BASE_PATH = 'test.googleapis.com';
4746

47+
/**
48+
* Returns a generator, used to mock the streamGenerateContent response
49+
*/
50+
export async function*
51+
testGenerator(): AsyncGenerator<GenerateContentResponse> {
52+
yield {
53+
candidates: TEST_CANDIDATES,
54+
};
55+
}
56+
4857
describe('VertexAI', () => {
4958
let vertexai: VertexAI;
5059
let model: GenerativeModel;
@@ -59,21 +68,18 @@ describe('VertexAI', () => {
5968
expect(vertexai).toBeInstanceOf(VertexAI);
6069
});
6170

62-
// TODO: update this test when stream and unary implementation is separated
6371
describe('generateContent', () => {
64-
it('returns a GenerateContentResponse when stream=false', async () => {
65-
const req: GenerateContentParams = {
72+
it('returns a GenerateContentResponse', async () => {
73+
const req: GenerateContentRequest = {
6674
contents: TEST_USER_CHAT_MESSAGE,
67-
stream: false,
6875
};
6976
const expectedResult: GenerateContentResult = {
70-
responses: TEST_MODEL_RESPONSE,
77+
response: TEST_MODEL_RESPONSE,
7178
};
7279
spyOn(StreamFunctions, 'processNonStream').and.returnValue(expectedResult);
7380
const resp = await model.generateContent(req);
7481
expect(resp).toEqual(expectedResult);
7582
});
76-
// TODO: add test from stream=true here
7783
});
7884

7985
describe('generateContent', () => {
@@ -85,12 +91,11 @@ describe('VertexAI', () => {
8591
model: 'gemini-pro'
8692
});
8793

88-
const req: GenerateContentParams = {
94+
const req: GenerateContentRequest = {
8995
contents: TEST_USER_CHAT_MESSAGE,
90-
stream: false,
9196
};
9297
const expectedResult: GenerateContentResult = {
93-
responses: TEST_MODEL_RESPONSE,
98+
response: TEST_MODEL_RESPONSE,
9499
};
95100
const requestSpy = spyOn(global, 'fetch');
96101
spyOn(StreamFunctions,
@@ -110,12 +115,11 @@ describe('VertexAI', () => {
110115
model: 'gemini-pro'
111116
});
112117

113-
const req: GenerateContentParams = {
118+
const req: GenerateContentRequest = {
114119
contents: TEST_USER_CHAT_MESSAGE,
115-
stream: false,
116120
};
117121
const expectedResult: GenerateContentResult = {
118-
responses: TEST_MODEL_RESPONSE,
122+
response: TEST_MODEL_RESPONSE,
119123
};
120124
const requestSpy = spyOn(global, 'fetch');
121125
spyOn(StreamFunctions, 'processNonStream').and.returnValue(expectedResult); await
@@ -125,6 +129,21 @@ describe('VertexAI', () => {
125129
});
126130
});
127131

132+
describe('streamGenerateContent', () => {
133+
it('returns a GenerateContentResponse', async () => {
134+
const req: GenerateContentRequest = {
135+
contents: TEST_USER_CHAT_MESSAGE,
136+
};
137+
const expectedResult: StreamGenerateContentResult = {
138+
response: Promise.resolve(TEST_MODEL_RESPONSE),
139+
stream: testGenerator(),
140+
};
141+
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedResult);
142+
const resp = await model.streamGenerateContent(req);
143+
expect(resp).toEqual(expectedResult);
144+
});
145+
});
146+
128147
describe('startChat', () => {
129148
it('returns a ChatSession', () => {
130149
const req: StartChatParams = {
@@ -174,19 +193,24 @@ describe('ChatSession', () => {
174193
expect(chatSession.history.length).toEqual(1);
175194
});
176195

177-
describe('sendMessage', () => {
178-
it('returns a GenerateContentResponse', async () => {
179-
const req = 'How are you doing today?';
180-
const expectedResult: GenerateContentResult = {
181-
responses: TEST_MODEL_RESPONSE,
182-
stream: StreamFunctions.emptyGenerator(),
183-
};
184-
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedResult);
185-
const resp = await chatSession.sendMessage(req);
186-
expect(resp).toEqual(expectedResult);
187-
expect(chatSession.history.length).toEqual(3);
188-
});
189-
190-
// TODO: add test cases for different content types passed to sendMessage
191-
});
196+
// TODO: update sendMessage after generateContent and streamGenerateContent
197+
// are working
198+
describe(
199+
'sendMessage',
200+
() => {
201+
// it('returns a GenerateContentResponse', async () => {
202+
// const req = 'How are you doing today?';
203+
// const expectedResult: GenerateContentResult = {
204+
// responses: TEST_MODEL_RESPONSE,
205+
// };
206+
// spyOn(StreamFunctions,
207+
// 'processStream').and.returnValue(expectedResult);
208+
// const resp = await chatSession.sendMessage(req);
209+
// expect(resp).toEqual(expectedResult);
210+
// expect(chatSession.history.length).toEqual(3);
211+
// });
212+
213+
// TODO: add test cases for different content types passed to
214+
// sendMessage
215+
});
192216
});

0 commit comments

Comments
 (0)