@@ -88,6 +88,92 @@ function formatMessagesForAnthropic(messages: BaseMessage[]): {
88
88
} ;
89
89
}
90
90
91
+ /**
92
+ * format messages for Cohere Command-R and CommandR+ via AWS Bedrock.
93
+ *
94
+ * @param messages messages The base messages to format as a prompt.
95
+ *
96
+ * @returns The formatted prompt for Cohere.
97
+ *
98
+ * `system`: user system prompts. Overrides the default preamble for search query generation. Has no effect on tool use generations.\
99
+ * `message`: (Required) Text input for the model to respond to.\
100
+ * `chatHistory`: A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's message.\
101
+ * The following are required fields.
102
+ * - `role` - The role for the message. Valid values are USER or CHATBOT.\
103
+ * - `message` – Text contents of the message.\
104
+ *
105
+ * The following is example JSON for the chat_history field.\
106
+ * "chat_history": [
107
+ * {"role": "USER", "message": "Who discovered gravity?"},
108
+ * {"role": "CHATBOT", "message": "The man who is widely credited with discovering gravity is Sir Isaac Newton"}]\
109
+ *
110
+ * docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
111
+ */
112
+ function formatMessagesForCohere ( messages : BaseMessage [ ] ) : {
113
+ system ?: string ;
114
+ message : string ;
115
+ chatHistory : Record < string , unknown > [ ] ;
116
+ } {
117
+ const systemMessages = messages . filter (
118
+ ( system ) => system . _getType ( ) === "system"
119
+ ) ;
120
+
121
+ const system = systemMessages
122
+ . filter ( ( m ) => typeof m . content === "string" )
123
+ . map ( ( m ) => m . content )
124
+ . join ( "\n\n" ) ;
125
+
126
+ const conversationMessages = messages . filter (
127
+ ( message ) => message . _getType ( ) !== "system"
128
+ ) ;
129
+
130
+ const questionContent = conversationMessages . slice ( - 1 ) ;
131
+
132
+ if ( ! questionContent . length || questionContent [ 0 ] . _getType ( ) !== "human" ) {
133
+ throw new Error ( "question message content must be a human message." ) ;
134
+ }
135
+
136
+ if ( typeof questionContent [ 0 ] . content !== "string" ) {
137
+ throw new Error ( "question message content must be a string." ) ;
138
+ }
139
+
140
+ const formattedMessage = questionContent [ 0 ] . content ;
141
+
142
+ const formattedChatHistories = conversationMessages
143
+ . slice ( 0 , - 1 )
144
+ . map ( ( message ) => {
145
+ let role ;
146
+ switch ( message . _getType ( ) ) {
147
+ case "human" :
148
+ role = "USER" as const ;
149
+ break ;
150
+ case "ai" :
151
+ role = "CHATBOT" as const ;
152
+ break ;
153
+ case "system" :
154
+ throw new Error ( "chat_history can not include system prompts." ) ;
155
+ default :
156
+ throw new Error (
157
+ `Message type "${ message . _getType ( ) } " is not supported.`
158
+ ) ;
159
+ }
160
+
161
+ if ( typeof message . content !== "string" ) {
162
+ throw new Error ( "message content must be a string." ) ;
163
+ }
164
+ return {
165
+ role,
166
+ message : message . content ,
167
+ } ;
168
+ } ) ;
169
+
170
+ return {
171
+ chatHistory : formattedChatHistories ,
172
+ message : formattedMessage ,
173
+ system,
174
+ } ;
175
+ }
176
+
91
177
/** Bedrock models.
92
178
To authenticate, the AWS client uses the following methods to automatically load credentials:
93
179
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
@@ -221,9 +307,25 @@ export class BedrockLLMInputOutputAdapter {
221
307
inputBody . temperature = temperature ;
222
308
inputBody . stop_sequences = stopSequences ;
223
309
return { ...inputBody , ...modelKwargs } ;
310
+ } else if ( provider === "cohere" ) {
311
+ const {
312
+ system,
313
+ message : formattedMessage ,
314
+ chatHistory : formattedChatHistories ,
315
+ } = formatMessagesForCohere ( messages ) ;
316
+
317
+ if ( system !== undefined && system . length > 0 ) {
318
+ inputBody . preamble = system ;
319
+ }
320
+ inputBody . message = formattedMessage ;
321
+ inputBody . chat_history = formattedChatHistories ;
322
+ inputBody . max_tokens = maxTokens ;
323
+ inputBody . temperature = temperature ;
324
+ inputBody . stop_sequences = stopSequences ;
325
+ return { ...inputBody , ...modelKwargs } ;
224
326
} else {
225
327
throw new Error (
226
- "The messages API is currently only supported by Anthropic"
328
+ "The messages API is currently only supported by Anthropic or Cohere "
227
329
) ;
228
330
}
229
331
}
@@ -298,9 +400,48 @@ export class BedrockLLMInputOutputAdapter {
298
400
} else {
299
401
return undefined ;
300
402
}
403
+ } else if ( provider === "cohere" ) {
404
+ if ( responseBody . event_type === "stream-start" ) {
405
+ return parseMessageCohere ( responseBody . message , true ) ;
406
+ } else if (
407
+ responseBody . event_type === "text-generation" &&
408
+ typeof responseBody ?. text === "string"
409
+ ) {
410
+ return new ChatGenerationChunk ( {
411
+ message : new AIMessageChunk ( {
412
+ content : responseBody . text ,
413
+ } ) ,
414
+ text : responseBody . text ,
415
+ } ) ;
416
+ } else if ( responseBody . event_type === "search-queries-generation" ) {
417
+ return parseMessageCohere ( responseBody ) ;
418
+ } else if (
419
+ responseBody . event_type === "stream-end" &&
420
+ responseBody . response !== undefined &&
421
+ responseBody [ "amazon-bedrock-invocationMetrics" ] !== undefined
422
+ ) {
423
+ return new ChatGenerationChunk ( {
424
+ message : new AIMessageChunk ( { content : "" } ) ,
425
+ text : "" ,
426
+ generationInfo : {
427
+ response : responseBody . response ,
428
+ "amazon-bedrock-invocationMetrics" :
429
+ responseBody [ "amazon-bedrock-invocationMetrics" ] ,
430
+ } ,
431
+ } ) ;
432
+ } else {
433
+ if (
434
+ responseBody . finish_reason === "COMPLETE" ||
435
+ responseBody . finish_reason === "MAX_TOKENS"
436
+ ) {
437
+ return parseMessageCohere ( responseBody ) ;
438
+ } else {
439
+ return undefined ;
440
+ }
441
+ }
301
442
} else {
302
443
throw new Error (
303
- "The messages API is currently only supported by Anthropic."
444
+ "The messages API is currently only supported by Anthropic or Cohere ."
304
445
) ;
305
446
}
306
447
}
@@ -341,3 +482,31 @@ function parseMessage(responseBody: any, asChunk?: boolean): ChatGeneration {
341
482
} ;
342
483
}
343
484
}
485
+
486
+ function parseMessageCohere (
487
+ responseBody : any ,
488
+ asChunk ?: boolean
489
+ ) : ChatGeneration {
490
+ const { text, ...generationInfo } = responseBody ;
491
+ let parsedContent = text ;
492
+ if ( typeof text !== "string" ) {
493
+ parsedContent = "" ;
494
+ }
495
+ if ( asChunk ) {
496
+ return new ChatGenerationChunk ( {
497
+ message : new AIMessageChunk ( {
498
+ content : parsedContent ,
499
+ } ) ,
500
+ text : parsedContent ,
501
+ generationInfo,
502
+ } ) ;
503
+ } else {
504
+ return {
505
+ message : new AIMessage ( {
506
+ content : parsedContent ,
507
+ } ) ,
508
+ text : parsedContent ,
509
+ generationInfo,
510
+ } ;
511
+ }
512
+ }
0 commit comments