Skip to content

Commit 564708b

Browse files
m-hamaro and jacoblee93
authored
community[minor]: Adds support for Cohere Command-R via AWS Bedrock (#5336)
* add-cohere-command-r-formatted-messages * add integration test patterns * fixed typo * add comments * fixed input system prompts check * fixed cohere modelName validation * fixed lint * Small typo fix * Add special message parsing for Cohere --------- Co-authored-by: jacoblee93 <[email protected]>
1 parent 05c4c76 commit 564708b

File tree

3 files changed

+269
-12
lines changed

3 files changed

+269
-12
lines changed

libs/langchain-community/src/chat_models/bedrock/web.ts

+24-4
Original file line numberDiff line numberDiff line change
@@ -208,10 +208,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
208208
this.stopSequences = fields?.stopSequences;
209209
this.modelKwargs = fields?.modelKwargs;
210210
this.streaming = fields?.streaming ?? this.streaming;
211-
this.usesMessagesApi =
212-
this.model.split(".")[0] === "anthropic" &&
213-
!this.model.includes("claude-v2") &&
214-
!this.model.includes("claude-instant-v1");
211+
this.usesMessagesApi = canUseMessagesApi(this.model);
215212
}
216213

217214
async _generate(
@@ -499,6 +496,29 @@ function isChatGenerationChunk(
499496
);
500497
}
501498

499+
function canUseMessagesApi(model: string): boolean {
500+
const modelProviderName = model.split(".")[0];
501+
502+
if (
503+
modelProviderName === "anthropic" &&
504+
!model.includes("claude-v2") &&
505+
!model.includes("claude-instant-v1")
506+
) {
507+
return true;
508+
}
509+
510+
if (modelProviderName === "cohere") {
511+
if (model.includes("command-r-v1")) {
512+
return true;
513+
}
514+
if (model.includes("command-r-plus-v1")) {
515+
return true;
516+
}
517+
}
518+
519+
return false;
520+
}
521+
502522
/**
503523
* @deprecated Use `BedrockChat` instead.
504524
*/

libs/langchain-community/src/chat_models/tests/chatbedrock.int.test.ts

+74-6
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,84 @@ import { BedrockChat } from "../bedrock/web.js";
1919
// );
2020

2121
void testChatModel(
22-
"Test Bedrock chat model: Claude-v2",
22+
"Test Bedrock chat model Generating search queries: Command-r",
23+
"us-east-1",
24+
"cohere.command-r-v1:0",
25+
"Who is more popular: Nsync or Backstreet Boys?",
26+
{
27+
search_queries_only: true,
28+
}
29+
);
30+
31+
void testChatModel(
32+
"Test Bedrock chat model: Command-r",
33+
"us-east-1",
34+
"cohere.command-r-v1:0",
35+
"What is your name?"
36+
);
37+
38+
void testChatModel(
39+
"Test Bedrock chat model: Command-r",
40+
"us-east-1",
41+
"cohere.command-r-v1:0",
42+
"What are the characteristics of the emperor penguin?",
43+
{
44+
documents: [
45+
{ title: "Tall penguins", snippet: "Emperor penguins are the tallest." },
46+
{
47+
title: "Penguin habitats",
48+
snippet: "Emperor penguins only live in Antarctica.",
49+
},
50+
],
51+
}
52+
);
53+
54+
void testChatStreamingModel(
55+
"Test Bedrock chat model streaming: Command-r",
56+
"us-east-1",
57+
"cohere.command-r-v1:0",
58+
"What is your name and something about yourself?"
59+
);
60+
61+
void testChatStreamingModel(
62+
"Test Bedrock chat model streaming: Command-r",
63+
"us-east-1",
64+
"cohere.command-r-v1:0",
65+
"What are the characteristics of the emperor penguin?",
66+
{
67+
documents: [
68+
{ title: "Tall penguins", snippet: "Emperor penguins are the tallest." },
69+
{
70+
title: "Penguin habitats",
71+
snippet: "Emperor penguins only live in Antarctica.",
72+
},
73+
],
74+
}
75+
);
76+
77+
void testChatHandleLLMNewToken(
78+
"Test Bedrock chat model HandleLLMNewToken: Command-r",
79+
"us-east-1",
80+
"cohere.command-r-v1:0",
81+
"What is your name and something about yourself?"
82+
);
83+
84+
void testChatModel(
85+
"Test Bedrock chat model: Mistral-7b-instruct",
2386
"us-east-1",
2487
"mistral.mistral-7b-instruct-v0:2",
2588
"What is your name?"
2689
);
2790

2891
void testChatStreamingModel(
29-
"Test Bedrock chat model streaming: Claude-v2",
92+
"Test Bedrock chat model streaming: Mistral-7b-instruct",
3093
"us-east-1",
3194
"mistral.mistral-7b-instruct-v0:2",
3295
"What is your name and something about yourself?"
3396
);
3497

3598
void testChatHandleLLMNewToken(
36-
"Test Bedrock chat model HandleLLMNewToken: Claude-v2",
99+
"Test Bedrock chat model HandleLLMNewToken: Mistral-7b-instruct",
37100
"us-east-1",
38101
"mistral.mistral-7b-instruct-v0:2",
39102
"What is your name and something about yourself?"
@@ -59,6 +122,7 @@ void testChatHandleLLMNewToken(
59122
"anthropic.claude-3-sonnet-20240229-v1:0",
60123
"What is your name and something about yourself?"
61124
);
125+
62126
// void testChatHandleLLMNewToken(
63127
// "Test Bedrock chat model HandleLLMNewToken: Llama2 13B v1",
64128
// "us-east-1",
@@ -77,13 +141,14 @@ async function testChatModel(
77141
title: string,
78142
defaultRegion: string,
79143
model: string,
80-
message: string
144+
message: string,
145+
modelKwargs?: Record<string, unknown>
81146
) {
82147
test(title, async () => {
83148
const region = process.env.BEDROCK_AWS_REGION ?? defaultRegion;
84149

85150
const bedrock = new BedrockChat({
86-
maxTokens: 20,
151+
maxTokens: 200,
87152
region,
88153
model,
89154
maxRetries: 0,
@@ -92,6 +157,7 @@ async function testChatModel(
92157
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
93158
sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
94159
},
160+
modelKwargs,
95161
});
96162

97163
const res = await bedrock.invoke([new HumanMessage(message)]);
@@ -109,7 +175,8 @@ async function testChatStreamingModel(
109175
title: string,
110176
defaultRegion: string,
111177
model: string,
112-
message: string
178+
message: string,
179+
modelKwargs?: Record<string, unknown>
113180
) {
114181
test(title, async () => {
115182
const region = process.env.BEDROCK_AWS_REGION ?? defaultRegion;
@@ -124,6 +191,7 @@ async function testChatStreamingModel(
124191
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
125192
sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
126193
},
194+
modelKwargs,
127195
});
128196

129197
const stream = await bedrock.stream([

libs/langchain-community/src/utils/bedrock.ts

+171-2
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,92 @@ function formatMessagesForAnthropic(messages: BaseMessage[]): {
8888
};
8989
}
9090

91+
/**
92+
* format messages for Cohere Command-R and CommandR+ via AWS Bedrock.
93+
*
94+
* @param messages messages The base messages to format as a prompt.
95+
*
96+
* @returns The formatted prompt for Cohere.
97+
*
98+
* `system`: user system prompts. Overrides the default preamble for search query generation. Has no effect on tool use generations.\
99+
* `message`: (Required) Text input for the model to respond to.\
100+
* `chatHistory`: A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's message.\
101+
* The following are required fields.
102+
* - `role` - The role for the message. Valid values are USER or CHATBOT.\
103+
* - `message` – Text contents of the message.\
104+
*
105+
* The following is example JSON for the chat_history field.\
106+
* "chat_history": [
107+
* {"role": "USER", "message": "Who discovered gravity?"},
108+
* {"role": "CHATBOT", "message": "The man who is widely credited with discovering gravity is Sir Isaac Newton"}]\
109+
*
110+
* docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
111+
*/
112+
function formatMessagesForCohere(messages: BaseMessage[]): {
113+
system?: string;
114+
message: string;
115+
chatHistory: Record<string, unknown>[];
116+
} {
117+
const systemMessages = messages.filter(
118+
(system) => system._getType() === "system"
119+
);
120+
121+
const system = systemMessages
122+
.filter((m) => typeof m.content === "string")
123+
.map((m) => m.content)
124+
.join("\n\n");
125+
126+
const conversationMessages = messages.filter(
127+
(message) => message._getType() !== "system"
128+
);
129+
130+
const questionContent = conversationMessages.slice(-1);
131+
132+
if (!questionContent.length || questionContent[0]._getType() !== "human") {
133+
throw new Error("question message content must be a human message.");
134+
}
135+
136+
if (typeof questionContent[0].content !== "string") {
137+
throw new Error("question message content must be a string.");
138+
}
139+
140+
const formattedMessage = questionContent[0].content;
141+
142+
const formattedChatHistories = conversationMessages
143+
.slice(0, -1)
144+
.map((message) => {
145+
let role;
146+
switch (message._getType()) {
147+
case "human":
148+
role = "USER" as const;
149+
break;
150+
case "ai":
151+
role = "CHATBOT" as const;
152+
break;
153+
case "system":
154+
throw new Error("chat_history can not include system prompts.");
155+
default:
156+
throw new Error(
157+
`Message type "${message._getType()}" is not supported.`
158+
);
159+
}
160+
161+
if (typeof message.content !== "string") {
162+
throw new Error("message content must be a string.");
163+
}
164+
return {
165+
role,
166+
message: message.content,
167+
};
168+
});
169+
170+
return {
171+
chatHistory: formattedChatHistories,
172+
message: formattedMessage,
173+
system,
174+
};
175+
}
176+
91177
/** Bedrock models.
92178
To authenticate, the AWS client uses the following methods to automatically load credentials:
93179
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
@@ -221,9 +307,25 @@ export class BedrockLLMInputOutputAdapter {
221307
inputBody.temperature = temperature;
222308
inputBody.stop_sequences = stopSequences;
223309
return { ...inputBody, ...modelKwargs };
310+
} else if (provider === "cohere") {
311+
const {
312+
system,
313+
message: formattedMessage,
314+
chatHistory: formattedChatHistories,
315+
} = formatMessagesForCohere(messages);
316+
317+
if (system !== undefined && system.length > 0) {
318+
inputBody.preamble = system;
319+
}
320+
inputBody.message = formattedMessage;
321+
inputBody.chat_history = formattedChatHistories;
322+
inputBody.max_tokens = maxTokens;
323+
inputBody.temperature = temperature;
324+
inputBody.stop_sequences = stopSequences;
325+
return { ...inputBody, ...modelKwargs };
224326
} else {
225327
throw new Error(
226-
"The messages API is currently only supported by Anthropic"
328+
"The messages API is currently only supported by Anthropic or Cohere"
227329
);
228330
}
229331
}
@@ -298,9 +400,48 @@ export class BedrockLLMInputOutputAdapter {
298400
} else {
299401
return undefined;
300402
}
403+
} else if (provider === "cohere") {
404+
if (responseBody.event_type === "stream-start") {
405+
return parseMessageCohere(responseBody.message, true);
406+
} else if (
407+
responseBody.event_type === "text-generation" &&
408+
typeof responseBody?.text === "string"
409+
) {
410+
return new ChatGenerationChunk({
411+
message: new AIMessageChunk({
412+
content: responseBody.text,
413+
}),
414+
text: responseBody.text,
415+
});
416+
} else if (responseBody.event_type === "search-queries-generation") {
417+
return parseMessageCohere(responseBody);
418+
} else if (
419+
responseBody.event_type === "stream-end" &&
420+
responseBody.response !== undefined &&
421+
responseBody["amazon-bedrock-invocationMetrics"] !== undefined
422+
) {
423+
return new ChatGenerationChunk({
424+
message: new AIMessageChunk({ content: "" }),
425+
text: "",
426+
generationInfo: {
427+
response: responseBody.response,
428+
"amazon-bedrock-invocationMetrics":
429+
responseBody["amazon-bedrock-invocationMetrics"],
430+
},
431+
});
432+
} else {
433+
if (
434+
responseBody.finish_reason === "COMPLETE" ||
435+
responseBody.finish_reason === "MAX_TOKENS"
436+
) {
437+
return parseMessageCohere(responseBody);
438+
} else {
439+
return undefined;
440+
}
441+
}
301442
} else {
302443
throw new Error(
303-
"The messages API is currently only supported by Anthropic."
444+
"The messages API is currently only supported by Anthropic or Cohere."
304445
);
305446
}
306447
}
@@ -341,3 +482,31 @@ function parseMessage(responseBody: any, asChunk?: boolean): ChatGeneration {
341482
};
342483
}
343484
}
485+
486+
function parseMessageCohere(
487+
responseBody: any,
488+
asChunk?: boolean
489+
): ChatGeneration {
490+
const { text, ...generationInfo } = responseBody;
491+
let parsedContent = text;
492+
if (typeof text !== "string") {
493+
parsedContent = "";
494+
}
495+
if (asChunk) {
496+
return new ChatGenerationChunk({
497+
message: new AIMessageChunk({
498+
content: parsedContent,
499+
}),
500+
text: parsedContent,
501+
generationInfo,
502+
});
503+
} else {
504+
return {
505+
message: new AIMessage({
506+
content: parsedContent,
507+
}),
508+
text: parsedContent,
509+
generationInfo,
510+
};
511+
}
512+
}

0 commit comments

Comments
 (0)