
Commit 598b1dd

vertex-sdk-bot authored and copybara-github committed
feat: enable streamSendMessage for chat.
PiperOrigin-RevId: 586738022
1 parent 3b346c2 commit 598b1dd

File tree

2 files changed: +84 -27 lines


src/index.ts

+59 -27
@@ -143,7 +143,6 @@ export declare interface StartChatSessionRequest extends StartChatParams {
  * Session for a multiturn chat with the model
  */
 export class ChatSession {
-  // Substitute apiKey for these in Labs
   private project: string;
   private location: string;

@@ -157,7 +156,7 @@ export class ChatSession {
   get history(): Content[] {
     return this.historyInternal;
   }
-
+
   constructor(request: StartChatSessionRequest) {
     this.project = request._vertex_instance.project;
     this.location = request._vertex_instance.location;
@@ -166,48 +165,62 @@ export class ChatSession {
     this._vertex_instance = request._vertex_instance;
   }

-  // TODO: add streamSendMessage that calls streamGenerateContent
   async sendMessage(request: string|
       Array<string|Part>): Promise<GenerateContentResult> {
-    let newParts: Part[] = [];
-
-    if (typeof request === 'string') {
-      newParts = [{text: request}];
-    } else if (Array.isArray(request)) {
-      for (const item of request) {
-        if (typeof item === 'string') {
-          newParts.push({text: item});
-        } else {
-          newParts.push(item);
-        }
-      }
-    };
-
-    const newContent: Content = {role: 'user', parts: newParts};
-
+    const newContent: Content = formulateNewContent(request);
     let generateContentrequest: GenerateContentRequest = {
       contents: this.historyInternal.concat([newContent]),
       safety_settings: this.safety_settings,
       generation_config: this.generation_config,
     };

-    const generateContentResponse =
+    const generateContentResult =
         await this._model_instance.generateContent(generateContentrequest);
-
+    const generateContentResponse = await generateContentResult.response;
     // Only push the latest message to history if the response returned a result
-    if (generateContentResponse.response.candidates.length !== 0) {
+    if (generateContentResponse.candidates.length !== 0) {
       this.historyInternal.push(newContent);
       this.historyInternal.push(
-          generateContentResponse.response.candidates[0].content);
+          generateContentResponse.candidates[0].content);
     } else {
       // TODO: handle promptFeedback in the response
-      throw new Error('Did not get a response from the model');
+      throw new Error('Did not get a candidate from the model');
     }

-    return generateContentResponse;
+    return Promise.resolve({response:generateContentResponse});
   }
-}

+  async streamSendMessage(request: string|
+      Array<string|Part>): Promise<StreamGenerateContentResult> {
+    const newContent: Content = formulateNewContent(request);
+    let generateContentrequest: GenerateContentRequest = {
+      contents: this.historyInternal.concat([newContent]),
+      safety_settings: this.safety_settings,
+      generation_config: this.generation_config,
+    };
+
+    const streamGenerateContentResult =
+        await this._model_instance.streamGenerateContent(generateContentrequest);
+    const streamGenerateContentResponse =
+        await streamGenerateContentResult.response;
+    // Only push the latest message to history if the response returned a result
+    if (streamGenerateContentResponse.candidates.length !== 0) {
+      this.historyInternal.push(newContent);
+      this.historyInternal.push(
+          streamGenerateContentResponse.candidates[0].content);
+    } else {
+      // TODO: handle promptFeedback in the response
+      throw new Error('Did not get a candidate from the model');
+    }
+
+    return Promise.resolve(
+        {
+          response: Promise.resolve(streamGenerateContentResponse),
+          stream: streamGenerateContentResult.stream,
+        }
+    );
+  }
+}

 /**
  * Base class for generative models.
@@ -345,7 +358,6 @@ export class GenerativeModel {
     }
   }

-
   startChat(request: StartChatParams): ChatSession {
     const startChatRequest = {
       history: request.history,
@@ -358,3 +370,23 @@ export class GenerativeModel {
     return new ChatSession(startChatRequest);
   }
 }
+
+function formulateNewContent(request: string|Array<string|Part>): Content {
+
+  let newParts: Part[] = [];
+
+  if (typeof request === 'string') {
+    newParts = [{text: request}];
+  } else if (Array.isArray(request)) {
+    for (const item of request) {
+      if (typeof item === 'string') {
+        newParts.push({text: item});
+      } else {
+        newParts.push(item);
+      }
+    }
+  };
+
+  const newContent: Content = {role: 'user', parts: newParts};
+  return newContent;
+}
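For orientation, a minimal sketch of how the new streamSendMessage method could be consumed from a chat session. Only ChatSession.streamSendMessage, result.stream, and result.response come from the diff above; the VertexAI entry point, preview.getGenerativeModel accessor, project, location, and model name are assumptions for illustration, not part of this commit.

// Sketch only: client setup values below are assumed, not taken from this commit.
import {VertexAI} from '@google-cloud/vertexai';

async function streamingChatDemo() {
  const vertexAI = new VertexAI({project: 'my-project', location: 'us-central1'});
  const model = vertexAI.preview.getGenerativeModel({model: 'gemini-pro'});

  // Start a chat with some prior history, mirroring the test in src/index_test.ts below.
  const chat = model.startChat({
    history: [{role: 'user', parts: [{text: 'How are you doing today?'}]}],
  });

  // streamSendMessage awaits the aggregated response internally, so by the time it
  // resolves, the new user turn and the model turn are already in chat.history.
  const result = await chat.streamSendMessage('Tell me a short story.');

  // The raw stream is passed through as result.stream; whether chunks remain
  // consumable here after the internal aggregation depends on processStream,
  // which is not shown in this diff.
  for await (const chunk of result.stream) {
    // Assumes the candidates/content/parts shape used elsewhere in this file.
    console.log(chunk.candidates[0].content.parts[0].text);
  }

  const aggregated = await result.response;
  console.log(aggregated.candidates[0].content.parts[0].text, chat.history.length);
}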

src/index_test.ts

+25
@@ -273,6 +273,31 @@ describe('ChatSession', () => {
     // sendMessage
   });

+  describe('streamSendMessage', () => {
+    it('returns a StreamGenerateContentResponse and appends to history', async () => {
+      const req = 'How are you doing today?';
+      const expectedResult: StreamGenerateContentResult = {
+        response: Promise.resolve(TEST_MODEL_RESPONSE),
+        stream: testGenerator(),
+      };
+      const chatSession= model.startChat({
+        history: [{role: 'user', parts: [{text: 'How are you doing today?'}]}],
+      });
+      spyOn(StreamFunctions, 'processStream')
+          .and.returnValue(expectedResult);
+      expect(chatSession.history.length).toEqual(1);
+      expect(chatSession.history[0].role).toEqual('user');
+      const result = await chatSession.streamSendMessage(req);
+      const response = await result.response;
+      const expectedResponse = await expectedResult.response;
+      expect(response).toEqual(expectedResponse);
+      expect(chatSession.history.length).toEqual(3);
+      expect(chatSession.history[0].role).toEqual('user');
+      expect(chatSession.history[1].role).toEqual('user');
+      expect(chatSession.history[2].role).toEqual('assistant');
+    });
+  });
+
   describe('imageToBase64', () => {
     let imageBuffer: Buffer;
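The spec above relies on the model, StreamFunctions, TEST_MODEL_RESPONSE, and testGenerator fixtures defined elsewhere in src/index_test.ts and not shown in this diff. For readers without the full file, a rough sketch of what the stream fixtures could look like, with shapes assumed rather than taken from this commit:

// Hypothetical fixture shapes, for illustration only; the real definitions live
// elsewhere in src/index_test.ts.
const TEST_MODEL_RESPONSE = {
  candidates: [
    {index: 0, content: {role: 'assistant', parts: [{text: 'I am doing great!'}]}},
  ],
};

// A fake server stream: an async generator that yields response chunks, which the
// spied-on processStream wraps into a StreamGenerateContentResult.
async function* testGenerator() {
  yield TEST_MODEL_RESPONSE;
}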
