Skip to content

Commit 63ce032

Browse files
alx13 authored and copybara-github committed
fix: processing of streams, including UTF
PiperOrigin-RevId: 591922402
1 parent 449c7a2 commit 63ce032

File tree

3 files changed

+170
-143
lines changed

3 files changed

+170
-143
lines changed

src/index.ts

+4-5
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@ export class VertexAI {
4444

4545
/**
4646
* @constructor
47-
* @param{VertexInit} init - {@link VertexInit}
47+
* @param{VertexInit} init - {@link VertexInit}
4848
* assign authentication related information to instantiate a Vertex AI client.
4949
*/
5050
constructor(init: VertexInit) {
@@ -106,7 +106,7 @@ export class VertexAI_Internal {
106106
return tokenPromise;
107107
}
108108

109-
/**
109+
/**
110110
* @param {ModelParams} modelParams - {@link ModelParams} Parameters to specify the generative model.
111111
* @return {GenerativeModel} Instance of the GenerativeModel class. {@link GenerativeModel}
112112
*/
@@ -158,7 +158,6 @@ export declare interface StartChatSessionRequest extends StartChatParams {
158158
export class ChatSession {
159159
private project: string;
160160
private location: string;
161-
private _send_stream_promise: Promise<void> = Promise.resolve();
162161

163162
private historyInternal: Content[];
164163
private _vertex_instance: VertexAI_Internal;
@@ -214,7 +213,7 @@ export class ChatSession {
214213

215214
return Promise.resolve({response: generateContentResponse});
216215
}
217-
216+
218217
async appendHistory(
219218
streamGenerateContentResultPromise: Promise<StreamGenerateContentResult>,
220219
newContent: Content,
@@ -255,7 +254,7 @@ export class ChatSession {
255254
this._model_instance.generateContentStream(
256255
generateContentrequest);
257256

258-
this._send_stream_promise = this.appendHistory(streamGenerateContentResultPromise, newContent);
257+
await this.appendHistory(streamGenerateContentResultPromise, newContent);
259258
return streamGenerateContentResultPromise;
260259
}
261260
}

src/process_stream.ts

+94-92
Original file line number | Diff line number | Diff line change
@@ -15,18 +15,22 @@
1515
* limitations under the License.
1616
*/
1717

18-
import {CitationSource, GenerateContentCandidate, GenerateContentResponse, GenerateContentResult, StreamGenerateContentResult,} from './types/content';
18+
import {
19+
CitationSource,
20+
GenerateContentCandidate,
21+
GenerateContentResponse,
22+
GenerateContentResult,
23+
StreamGenerateContentResult,
24+
} from './types/content';
1925

20-
// eslint-disable-next-line no-useless-escape
21-
const responseLineRE = /^data\: (.*)\r\n/;
26+
const responseLineRE = /^data: (.*)(?:\n\n|\r\r|\r\n\r\n)/;
2227

23-
// TODO: set a better type for `reader`. Setting it to
24-
// `ReadableStreamDefaultReader` results in an error (diagnostic code 2304)
2528
async function* generateResponseSequence(
26-
reader2: any
29+
stream: ReadableStream<GenerateContentResponse>
2730
): AsyncGenerator<GenerateContentResponse> {
31+
const reader = stream.getReader();
2832
while (true) {
29-
const {value, done} = await reader2.read();
33+
const {value, done} = await reader.read();
3034
if (done) {
3135
break;
3236
}
@@ -35,55 +39,91 @@ async function* generateResponseSequence(
3539
}
3640

3741
/**
38-
* Reads a raw stream from the fetch response and joins incomplete
42+
* Process a response.body stream from the backend and return an
43+
* iterator that provides one complete GenerateContentResponse at a time
44+
* and a promise that resolves with a single aggregated
45+
* GenerateContentResponse.
46+
*
47+
* @param response - Response from a fetch call
48+
*/
49+
export function processStream(
50+
response: Response | undefined
51+
): StreamGenerateContentResult {
52+
if (response === undefined) {
53+
throw new Error('Error processing stream because response === undefined');
54+
}
55+
if (!response.body) {
56+
throw new Error('Error processing stream because response.body not found');
57+
}
58+
const inputStream = response.body!.pipeThrough(
59+
new TextDecoderStream('utf8', {fatal: true})
60+
);
61+
const responseStream =
62+
getResponseStream<GenerateContentResponse>(inputStream);
63+
const [stream1, stream2] = responseStream.tee();
64+
return {
65+
stream: generateResponseSequence(stream1),
66+
response: getResponsePromise(stream2),
67+
};
68+
}
69+
70+
async function getResponsePromise(
71+
stream: ReadableStream<GenerateContentResponse>
72+
): Promise<GenerateContentResponse> {
73+
const allResponses: GenerateContentResponse[] = [];
74+
const reader = stream.getReader();
75+
// eslint-disable-next-line no-constant-condition
76+
while (true) {
77+
const {done, value} = await reader.read();
78+
if (done) {
79+
return aggregateResponses(allResponses);
80+
}
81+
allResponses.push(value);
82+
}
83+
}
84+
85+
/**
86+
* Reads a raw stream from the fetch response and join incomplete
3987
* chunks, returning a new stream that provides a single complete
4088
* GenerateContentResponse in each iteration.
4189
*/
42-
function readFromReader(
43-
reader: ReadableStreamDefaultReader
44-
): ReadableStream<GenerateContentResponse> {
45-
let currentText = '';
46-
const stream = new ReadableStream<GenerateContentResponse>({
90+
export function getResponseStream<T>(
91+
inputStream: ReadableStream<string>
92+
): ReadableStream<T> {
93+
const reader = inputStream.getReader();
94+
const stream = new ReadableStream<T>({
4795
start(controller) {
96+
let currentText = '';
4897
return pump();
4998
function pump(): Promise<(() => Promise<void>) | undefined> {
50-
let streamReader;
51-
try {
52-
streamReader = reader.read().then(({value, done}) => {
53-
if (done) {
54-
controller.close();
99+
return reader.read().then(({value, done}) => {
100+
if (done) {
101+
if (currentText.trim()) {
102+
controller.error(new Error('Failed to parse stream'));
55103
return;
56104
}
57-
const chunk = new TextDecoder().decode(value);
58-
currentText += chunk;
59-
const match = currentText.match(responseLineRE);
60-
if (match) {
61-
let parsedResponse: GenerateContentResponse;
62-
try {
63-
parsedResponse = JSON.parse(
64-
match[1]
65-
) as GenerateContentResponse;
66-
} catch (e) {
67-
throw new Error(`Error parsing JSON response: "${match[1]}"`);
68-
}
69-
currentText = '';
70-
if ('candidates' in parsedResponse) {
71-
controller.enqueue(parsedResponse);
72-
} else {
73-
console.warn(
74-
`No candidates in this response: ${parsedResponse}`
75-
);
76-
controller.enqueue({
77-
candidates: [],
78-
});
79-
}
105+
controller.close();
106+
return;
107+
}
108+
109+
currentText += value;
110+
let match = currentText.match(responseLineRE);
111+
let parsedResponse: T;
112+
while (match) {
113+
try {
114+
parsedResponse = JSON.parse(match[1]) as T;
115+
} catch (e) {
116+
controller.error(
117+
new Error(`Error parsing JSON response: "${match[1]}"`)
118+
);
119+
return;
80120
}
81-
return pump();
82-
});
83-
} catch (e) {
84-
throw new Error(`Error reading from stream ${e}.`);
85-
}
86-
return streamReader;
121+
controller.enqueue(parsedResponse);
122+
currentText = currentText.substring(match[0].length);
123+
match = currentText.match(responseLineRE);
124+
}
125+
return pump();
126+
});
87127
}
88128
},
89129
});
@@ -121,20 +161,21 @@ function aggregateResponses(
121161
} as GenerateContentCandidate;
122162
}
123163
if (response.candidates[i].citationMetadata) {
124-
if (!aggregatedResponse.candidates[i]
125-
.citationMetadata?.citationSources) {
164+
if (
165+
!aggregatedResponse.candidates[i].citationMetadata?.citationSources
166+
) {
126167
aggregatedResponse.candidates[i].citationMetadata = {
127168
citationSources: [] as CitationSource[],
128169
};
129170
}
130171

131-
132-
let existingMetadata = response.candidates[i].citationMetadata ?? {};
172+
const existingMetadata = response.candidates[i].citationMetadata ?? {};
133173

134174
if (aggregatedResponse.candidates[i].citationMetadata) {
135175
aggregatedResponse.candidates[i].citationMetadata!.citationSources =
136-
aggregatedResponse.candidates[i]
137-
.citationMetadata!.citationSources.concat(existingMetadata);
176+
aggregatedResponse.candidates[
177+
i
178+
].citationMetadata!.citationSources.concat(existingMetadata);
138179
}
139180
}
140181
aggregatedResponse.candidates[i].finishReason =
@@ -157,45 +198,6 @@ function aggregateResponses(
157198
return aggregatedResponse;
158199
}
159200

160-
// TODO: improve error handling throughout stream processing
161-
/**
162-
* Processes model responses from streamGenerateContent
163-
*/
164-
export function processStream(
165-
response: Response | undefined
166-
): StreamGenerateContentResult {
167-
if (response === undefined) {
168-
throw new Error('Error processing stream because response === undefined');
169-
}
170-
if (!response.body) {
171-
throw new Error('Error processing stream because response.body not found');
172-
}
173-
const reader = response.body.getReader();
174-
const responseStream = readFromReader(reader);
175-
const [stream1, stream2] = responseStream.tee();
176-
const reader1 = stream1.getReader();
177-
const reader2 = stream2.getReader();
178-
const allResponses: GenerateContentResponse[] = [];
179-
const responsePromise = new Promise<GenerateContentResponse>(
180-
// eslint-disable-next-line no-async-promise-executor
181-
async resolve => {
182-
// eslint-disable-next-line no-constant-condition
183-
while (true) {
184-
const {value, done} = await reader1.read();
185-
if (done) {
186-
resolve(aggregateResponses(allResponses));
187-
return;
188-
}
189-
allResponses.push(value);
190-
}
191-
}
192-
);
193-
return {
194-
response: responsePromise,
195-
stream: generateResponseSequence(reader2),
196-
};
197-
}
198-
199201
/**
200202
* Process model responses from generateContent
201203
*/

0 commit comments

Comments
 (0)