Skip to content

Commit e573ce6

Browse files
yyyu-googlecopybara-github
authored andcommitted
fix: fix bug in safetyRatings handling, fix incomplete content interfaces, and add unit test for stream response handling
PiperOrigin-RevId: 616867948
1 parent 75e70f0 commit e573ce6

File tree

4 files changed

+577
-10
lines changed

4 files changed

+577
-10
lines changed

src/functions/post_fetch_processing.ts

+16-8
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ async function getResponsePromise(
117117
* GenerateContentResponse in each iteration.
118118
* @ignore
119119
*/
120-
export function getResponseStream(
120+
function getResponseStream(
121121
inputStream: ReadableStream<string>
122122
): ReadableStream<unknown> {
123123
const reader = inputStream.getReader();
@@ -170,8 +170,9 @@ export function getResponseStream(
170170
* Aggregates an array of `GenerateContentResponse`s into a single
171171
* GenerateContentResponse.
172172
* @ignore
173+
* @VisibleForTesting
173174
*/
174-
function aggregateResponses(
175+
export function aggregateResponses(
175176
responses: GenerateContentResponse[]
176177
): GenerateContentResponse {
177178
const lastResponse = responses[responses.length - 1];
@@ -216,12 +217,19 @@ function aggregateResponses(
216217
].citationMetadata!.citationSources.concat(existingMetadata);
217218
}
218219
}
219-
aggregatedResponse.candidates[i].finishReason =
220-
response.candidates[i].finishReason;
221-
aggregatedResponse.candidates[i].finishMessage =
222-
response.candidates[i].finishMessage;
223-
aggregatedResponse.candidates[i].safetyRatings =
224-
response.candidates[i].safetyRatings;
220+
const finishResonOfChunk = response.candidates[i].finishReason;
221+
if (finishResonOfChunk) {
222+
aggregatedResponse.candidates[i].finishReason =
223+
response.candidates[i].finishReason;
224+
}
225+
const finishMessageOfChunk = response.candidates[i].finishMessage;
226+
if (finishMessageOfChunk) {
227+
aggregatedResponse.candidates[i].finishMessage = finishMessageOfChunk;
228+
}
229+
const safetyRatingsOfChunk = response.candidates[i].safetyRatings;
230+
if (safetyRatingsOfChunk) {
231+
aggregatedResponse.candidates[i].safetyRatings = safetyRatingsOfChunk;
232+
}
225233
if ('parts' in response.candidates[i].content) {
226234
for (const part of response.candidates[i].content.parts) {
227235
if (part.text) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* https://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
import {
19+
AGGREGATED_RESPONSE_STREAM_RESPONSE_CHUNKS_1,
20+
STREAM_RESPONSE_CHUNKS_1,
21+
} from './test_data';
22+
import {aggregateResponses} from '../post_fetch_processing';
23+
24+
describe('aggregateResponses', () => {
25+
it('grounding metadata in muliple chunks for multiple chandidates, should aggregate accordingly', () => {
26+
const actualResult = aggregateResponses(STREAM_RESPONSE_CHUNKS_1);
27+
28+
expect(JSON.stringify(actualResult)).toEqual(
29+
JSON.stringify(AGGREGATED_RESPONSE_STREAM_RESPONSE_CHUNKS_1)
30+
);
31+
});
32+
});

0 commit comments

Comments
 (0)