Skip to content

Commit 522118f

Browse files
fix(streaming): accumulate citations (#675)
1 parent 751ecd0 commit 522118f

File tree

2 files changed

+123
-44
lines changed

2 files changed

+123
-44
lines changed

src/lib/BetaMessageStream.ts

Lines changed: 61 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import {
99
type MessageCreateParams as BetaMessageCreateParams,
1010
type MessageCreateParamsBase as BetaMessageCreateParamsBase,
1111
type BetaTextBlock,
12+
type BetaTextCitation,
1213
} from '@anthropic-ai/sdk/resources/beta/messages/messages';
1314
import { type ReadableStream, type Response } from '@anthropic-ai/sdk/_shims/index';
1415
import { Stream } from '@anthropic-ai/sdk/streaming';
@@ -18,6 +19,7 @@ export interface MessageStreamEvents {
1819
connect: () => void;
1920
streamEvent: (event: BetaMessageStreamEvent, snapshot: BetaMessage) => void;
2021
text: (textDelta: string, textSnapshot: string) => void;
22+
citation: (citation: BetaTextCitation, citationsSnapshot: BetaTextCitation[]) => void;
2123
inputJson: (partialJson: string, jsonSnapshot: unknown) => void;
2224
message: (message: BetaMessage) => void;
2325
contentBlock: (content: BetaContentBlock) => void;
@@ -413,12 +415,27 @@ export class BetaMessageStream implements AsyncIterable<BetaMessageStreamEvent>
413415
switch (event.type) {
414416
case 'content_block_delta': {
415417
const content = messageSnapshot.content.at(-1)!;
416-
if (event.delta.type === 'text_delta' && content.type === 'text') {
417-
this._emit('text', event.delta.text, content.text || '');
418-
} else if (event.delta.type === 'input_json_delta' && content.type === 'tool_use') {
419-
if (content.input) {
420-
this._emit('inputJson', event.delta.partial_json, content.input);
418+
switch (event.delta.type) {
419+
case 'text_delta': {
420+
if (content.type === 'text') {
421+
this._emit('text', event.delta.text, content.text || '');
422+
}
423+
break;
421424
}
425+
case 'citations_delta': {
426+
if (content.type === 'text') {
427+
this._emit('citation', event.delta.citation, content.citations ?? []);
428+
}
429+
break;
430+
}
431+
case 'input_json_delta': {
432+
if (content.type === 'tool_use' && content.input) {
433+
this._emit('inputJson', event.delta.partial_json, content.input);
434+
}
435+
break;
436+
}
437+
default:
438+
checkNever(event.delta);
422439
}
423440
break;
424441
}
@@ -505,24 +522,43 @@ export class BetaMessageStream implements AsyncIterable<BetaMessageStreamEvent>
505522
return snapshot;
506523
case 'content_block_delta': {
507524
const snapshotContent = snapshot.content.at(event.index);
508-
if (snapshotContent?.type === 'text' && event.delta.type === 'text_delta') {
509-
snapshotContent.text += event.delta.text;
510-
} else if (snapshotContent?.type === 'tool_use' && event.delta.type === 'input_json_delta') {
511-
// we need to keep track of the raw JSON string as well so that we can
512-
// re-parse it for each delta, for now we just store it as an untyped
513-
// non-enumerable property on the snapshot
514-
let jsonBuf = (snapshotContent as any)[JSON_BUF_PROPERTY] || '';
515-
jsonBuf += event.delta.partial_json;
516-
517-
Object.defineProperty(snapshotContent, JSON_BUF_PROPERTY, {
518-
value: jsonBuf,
519-
enumerable: false,
520-
writable: true,
521-
});
522-
523-
if (jsonBuf) {
524-
snapshotContent.input = partialParse(jsonBuf);
525+
526+
switch (event.delta.type) {
527+
case 'text_delta': {
528+
if (snapshotContent?.type === 'text') {
529+
snapshotContent.text += event.delta.text;
530+
}
531+
break;
525532
}
533+
case 'citations_delta': {
534+
if (snapshotContent?.type === 'text') {
535+
snapshotContent.citations ??= [];
536+
snapshotContent.citations.push(event.delta.citation);
537+
}
538+
break;
539+
}
540+
case 'input_json_delta': {
541+
if (snapshotContent?.type === 'tool_use') {
542+
// we need to keep track of the raw JSON string as well so that we can
543+
// re-parse it for each delta, for now we just store it as an untyped
544+
// non-enumerable property on the snapshot
545+
let jsonBuf = (snapshotContent as any)[JSON_BUF_PROPERTY] || '';
546+
jsonBuf += event.delta.partial_json;
547+
548+
Object.defineProperty(snapshotContent, JSON_BUF_PROPERTY, {
549+
value: jsonBuf,
550+
enumerable: false,
551+
writable: true,
552+
});
553+
554+
if (jsonBuf) {
555+
snapshotContent.input = partialParse(jsonBuf);
556+
}
557+
}
558+
break;
559+
}
560+
default:
561+
checkNever(event.delta);
526562
}
527563
return snapshot;
528564
}
@@ -597,3 +633,6 @@ export class BetaMessageStream implements AsyncIterable<BetaMessageStreamEvent>
597633
return stream.toReadableStream();
598634
}
599635
}
636+
637+
// used to ensure exhaustive case matching without throwing a runtime error
638+
function checkNever(x: never) {}

src/lib/MessageStream.ts

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import {
99
type MessageCreateParams,
1010
type MessageCreateParamsBase,
1111
type TextBlock,
12+
type TextCitation,
1213
} from '@anthropic-ai/sdk/resources/messages';
1314
import { type ReadableStream, type Response } from '@anthropic-ai/sdk/_shims/index';
1415
import { Stream } from '@anthropic-ai/sdk/streaming';
@@ -18,6 +19,7 @@ export interface MessageStreamEvents {
1819
connect: () => void;
1920
streamEvent: (event: MessageStreamEvent, snapshot: Message) => void;
2021
text: (textDelta: string, textSnapshot: string) => void;
22+
citation: (citation: TextCitation, citationsSnapshot: TextCitation[]) => void;
2123
inputJson: (partialJson: string, jsonSnapshot: unknown) => void;
2224
message: (message: Message) => void;
2325
contentBlock: (content: ContentBlock) => void;
@@ -413,12 +415,27 @@ export class MessageStream implements AsyncIterable<MessageStreamEvent> {
413415
switch (event.type) {
414416
case 'content_block_delta': {
415417
const content = messageSnapshot.content.at(-1)!;
416-
if (event.delta.type === 'text_delta' && content.type === 'text') {
417-
this._emit('text', event.delta.text, content.text || '');
418-
} else if (event.delta.type === 'input_json_delta' && content.type === 'tool_use') {
419-
if (content.input) {
420-
this._emit('inputJson', event.delta.partial_json, content.input);
418+
switch (event.delta.type) {
419+
case 'text_delta': {
420+
if (content.type === 'text') {
421+
this._emit('text', event.delta.text, content.text || '');
422+
}
423+
break;
421424
}
425+
case 'citations_delta': {
426+
if (content.type === 'text') {
427+
this._emit('citation', event.delta.citation, content.citations ?? []);
428+
}
429+
break;
430+
}
431+
case 'input_json_delta': {
432+
if (content.type === 'tool_use' && content.input) {
433+
this._emit('inputJson', event.delta.partial_json, content.input);
434+
}
435+
break;
436+
}
437+
default:
438+
checkNever(event.delta);
422439
}
423440
break;
424441
}
@@ -505,25 +522,45 @@ export class MessageStream implements AsyncIterable<MessageStreamEvent> {
505522
return snapshot;
506523
case 'content_block_delta': {
507524
const snapshotContent = snapshot.content.at(event.index);
508-
if (snapshotContent?.type === 'text' && event.delta.type === 'text_delta') {
509-
snapshotContent.text += event.delta.text;
510-
} else if (snapshotContent?.type === 'tool_use' && event.delta.type === 'input_json_delta') {
511-
// we need to keep track of the raw JSON string as well so that we can
512-
// re-parse it for each delta, for now we just store it as an untyped
513-
// non-enumerable property on the snapshot
514-
let jsonBuf = (snapshotContent as any)[JSON_BUF_PROPERTY] || '';
515-
jsonBuf += event.delta.partial_json;
516-
517-
Object.defineProperty(snapshotContent, JSON_BUF_PROPERTY, {
518-
value: jsonBuf,
519-
enumerable: false,
520-
writable: true,
521-
});
522-
523-
if (jsonBuf) {
524-
snapshotContent.input = partialParse(jsonBuf);
525+
526+
switch (event.delta.type) {
527+
case 'text_delta': {
528+
if (snapshotContent?.type === 'text') {
529+
snapshotContent.text += event.delta.text;
530+
}
531+
break;
532+
}
533+
case 'citations_delta': {
534+
if (snapshotContent?.type === 'text') {
535+
snapshotContent.citations ??= [];
536+
snapshotContent.citations.push(event.delta.citation);
537+
}
538+
break;
525539
}
540+
case 'input_json_delta': {
541+
if (snapshotContent?.type === 'tool_use') {
542+
// we need to keep track of the raw JSON string as well so that we can
543+
// re-parse it for each delta, for now we just store it as an untyped
544+
// non-enumerable property on the snapshot
545+
let jsonBuf = (snapshotContent as any)[JSON_BUF_PROPERTY] || '';
546+
jsonBuf += event.delta.partial_json;
547+
548+
Object.defineProperty(snapshotContent, JSON_BUF_PROPERTY, {
549+
value: jsonBuf,
550+
enumerable: false,
551+
writable: true,
552+
});
553+
554+
if (jsonBuf) {
555+
snapshotContent.input = partialParse(jsonBuf);
556+
}
557+
}
558+
break;
559+
}
560+
default:
561+
checkNever(event.delta);
526562
}
563+
527564
return snapshot;
528565
}
529566
case 'content_block_stop':
@@ -597,3 +634,6 @@ export class MessageStream implements AsyncIterable<MessageStreamEvent> {
597634
return stream.toReadableStream();
598635
}
599636
}
637+
638+
// used to ensure exhaustive case matching without throwing a runtime error
639+
function checkNever(x: never) {}

0 commit comments

Comments
 (0)