Skip to content

community[minor]: Adds support for Cohere Command-R via AWS Bedrock #5336

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

28 changes: 24 additions & 4 deletions libs/langchain-community/src/chat_models/bedrock/web.ts
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
this.stopSequences = fields?.stopSequences;
this.modelKwargs = fields?.modelKwargs;
this.streaming = fields?.streaming ?? this.streaming;
this.usesMessagesApi =
this.model.split(".")[0] === "anthropic" &&
!this.model.includes("claude-v2") &&
!this.model.includes("claude-instant-v1");
this.usesMessagesApi = canUseMessagesApi(this.model);
}

async _generate(
Expand Down Expand Up @@ -499,6 +496,29 @@ function isChatGenerationChunk(
);
}

/**
 * Determines whether a Bedrock model id should be driven through the
 * messages-style API rather than the legacy text-completion payload.
 *
 * @param model Bedrock model identifier, e.g. "anthropic.claude-3-sonnet-20240229-v1:0".
 * @returns True when the model accepts the messages API request shape.
 */
function canUseMessagesApi(model: string): boolean {
  const provider = model.split(".")[0];

  if (provider === "anthropic") {
    // Claude 3+ uses the messages API; the older claude-v2 and
    // claude-instant-v1 generations still use text completion.
    return (
      !model.includes("claude-v2") && !model.includes("claude-instant-v1")
    );
  }

  if (provider === "cohere") {
    // Command R and Command R+ are the only Cohere chat-style models.
    return (
      model.includes("command-r-v1") || model.includes("command-r-plus-v1")
    );
  }

  return false;
}

/**
* @deprecated Use `BedrockChat` instead.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,84 @@ import { BedrockChat } from "../bedrock/web.js";
// );

void testChatModel(
"Test Bedrock chat model: Claude-v2",
"Test Bedrock chat model Generating search queries: Command-r",
"us-east-1",
"cohere.command-r-v1:0",
"Who is more popular: Nsync or Backstreet Boys?",
{
search_queries_only: true,
}
);

void testChatModel(
"Test Bedrock chat model: Command-r",
"us-east-1",
"cohere.command-r-v1:0",
"What is your name?"
);

void testChatModel(
"Test Bedrock chat model: Command-r",
"us-east-1",
"cohere.command-r-v1:0",
"What are the characteristics of the emperor penguin?",
{
documents: [
{ title: "Tall penguins", snippet: "Emperor penguins are the tallest." },
{
title: "Penguin habitats",
snippet: "Emperor penguins only live in Antarctica.",
},
],
}
);

void testChatStreamingModel(
"Test Bedrock chat model streaming: Command-r",
"us-east-1",
"cohere.command-r-v1:0",
"What is your name and something about yourself?"
);

void testChatStreamingModel(
"Test Bedrock chat model streaming: Command-r",
"us-east-1",
"cohere.command-r-v1:0",
"What are the characteristics of the emperor penguin?",
{
documents: [
{ title: "Tall penguins", snippet: "Emperor penguins are the tallest." },
{
title: "Penguin habitats",
snippet: "Emperor penguins only live in Antarctica.",
},
],
}
);

void testChatHandleLLMNewToken(
"Test Bedrock chat model HandleLLMNewToken: Command-r",
"us-east-1",
"cohere.command-r-v1:0",
"What is your name and something about yourself?"
);

void testChatModel(
"Test Bedrock chat model: Mistral-7b-instruct",
"us-east-1",
"mistral.mistral-7b-instruct-v0:2",
"What is your name?"
);

void testChatStreamingModel(
"Test Bedrock chat model streaming: Claude-v2",
"Test Bedrock chat model streaming: Mistral-7b-instruct",
"us-east-1",
"mistral.mistral-7b-instruct-v0:2",
"What is your name and something about yourself?"
);

void testChatHandleLLMNewToken(
"Test Bedrock chat model HandleLLMNewToken: Claude-v2",
"Test Bedrock chat model HandleLLMNewToken: Mistral-7b-instruct",
"us-east-1",
"mistral.mistral-7b-instruct-v0:2",
"What is your name and something about yourself?"
Expand All @@ -59,6 +122,7 @@ void testChatHandleLLMNewToken(
"anthropic.claude-3-sonnet-20240229-v1:0",
"What is your name and something about yourself?"
);

// void testChatHandleLLMNewToken(
// "Test Bedrock chat model HandleLLMNewToken: Llama2 13B v1",
// "us-east-1",
Expand All @@ -77,13 +141,14 @@ async function testChatModel(
title: string,
defaultRegion: string,
model: string,
message: string
message: string,
modelKwargs?: Record<string, unknown>
) {
test(title, async () => {
const region = process.env.BEDROCK_AWS_REGION ?? defaultRegion;

const bedrock = new BedrockChat({
maxTokens: 20,
maxTokens: 200,
region,
model,
maxRetries: 0,
Expand All @@ -92,6 +157,7 @@ async function testChatModel(
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
},
modelKwargs,
});

const res = await bedrock.invoke([new HumanMessage(message)]);
Expand All @@ -109,7 +175,8 @@ async function testChatStreamingModel(
title: string,
defaultRegion: string,
model: string,
message: string
message: string,
modelKwargs?: Record<string, unknown>
) {
test(title, async () => {
const region = process.env.BEDROCK_AWS_REGION ?? defaultRegion;
Expand All @@ -124,6 +191,7 @@ async function testChatStreamingModel(
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
},
modelKwargs,
});

const stream = await bedrock.stream([
Expand Down
173 changes: 171 additions & 2 deletions libs/langchain-community/src/utils/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,92 @@ function formatMessagesForAnthropic(messages: BaseMessage[]): {
};
}

/**
 * Formats LangChain messages into the request shape expected by Cohere
 * Command-R / Command-R+ on AWS Bedrock.
 *
 * @param messages The base messages to format as a prompt.
 *
 * @returns The formatted prompt fields for Cohere:
 * - `system`: concatenated system prompts (overrides the default preamble
 *   for search query generation; no effect on tool use generations).
 * - `message`: (Required) text input for the model to respond to — taken
 *   from the final message, which must be a human message with string content.
 * - `chatHistory`: prior user/model turns giving conversational context.
 *   Each entry has `role` ("USER" or "CHATBOT") and `message` (text contents).
 *
 * Example chat_history JSON:
 * "chat_history": [
 *   {"role": "USER", "message": "Who discovered gravity?"},
 *   {"role": "CHATBOT", "message": "The man who is widely credited with discovering gravity is Sir Isaac Newton"}]
 *
 * docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
 *
 * @throws If the last message is missing, not a human message, or any
 *   message content is not a plain string.
 */
function formatMessagesForCohere(messages: BaseMessage[]): {
  system?: string;
  message: string;
  chatHistory: Record<string, unknown>[];
} {
  // Merge every string-valued system prompt into a single preamble;
  // non-string system content is silently dropped, matching prior behavior.
  const system = messages
    .filter((m) => m._getType() === "system")
    .filter((m) => typeof m.content === "string")
    .map((m) => m.content)
    .join("\n\n");

  const conversation = messages.filter((m) => m._getType() !== "system");

  // The trailing message becomes the `message` field and must be human text.
  const lastMessage = conversation[conversation.length - 1];
  if (lastMessage === undefined || lastMessage._getType() !== "human") {
    throw new Error("question message content must be a human message.");
  }
  if (typeof lastMessage.content !== "string") {
    throw new Error("question message content must be a string.");
  }

  // Everything before the trailing message becomes chat_history turns.
  const chatHistory = conversation.slice(0, -1).map((m) => {
    const type = m._getType();
    let role: "USER" | "CHATBOT";
    if (type === "human") {
      role = "USER";
    } else if (type === "ai") {
      role = "CHATBOT";
    } else if (type === "system") {
      // Unreachable after the filter above; kept as a defensive guard.
      throw new Error("chat_history can not include system prompts.");
    } else {
      throw new Error(`Message type "${type}" is not supported.`);
    }

    if (typeof m.content !== "string") {
      throw new Error("message content must be a string.");
    }
    return {
      role,
      message: m.content,
    };
  });

  return {
    chatHistory,
    message: lastMessage.content,
    system,
  };
}

/** Bedrock models.
To authenticate, the AWS client uses the following methods to automatically load credentials:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
Expand Down Expand Up @@ -221,9 +307,25 @@ export class BedrockLLMInputOutputAdapter {
inputBody.temperature = temperature;
inputBody.stop_sequences = stopSequences;
return { ...inputBody, ...modelKwargs };
} else if (provider === "cohere") {
const {
system,
message: formattedMessage,
chatHistory: formattedChatHistories,
} = formatMessagesForCohere(messages);

if (system !== undefined && system.length > 0) {
inputBody.preamble = system;
}
inputBody.message = formattedMessage;
inputBody.chat_history = formattedChatHistories;
inputBody.max_tokens = maxTokens;
inputBody.temperature = temperature;
inputBody.stop_sequences = stopSequences;
return { ...inputBody, ...modelKwargs };
} else {
throw new Error(
"The messages API is currently only supported by Anthropic"
"The messages API is currently only supported by Anthropic or Cohere"
);
}
}
Expand Down Expand Up @@ -298,9 +400,48 @@ export class BedrockLLMInputOutputAdapter {
} else {
return undefined;
}
} else if (provider === "cohere") {
if (responseBody.event_type === "stream-start") {
return parseMessageCohere(responseBody.message, true);
} else if (
responseBody.event_type === "text-generation" &&
typeof responseBody?.text === "string"
) {
return new ChatGenerationChunk({
message: new AIMessageChunk({
content: responseBody.text,
}),
text: responseBody.text,
});
} else if (responseBody.event_type === "search-queries-generation") {
return parseMessageCohere(responseBody);
} else if (
responseBody.event_type === "stream-end" &&
responseBody.response !== undefined &&
responseBody["amazon-bedrock-invocationMetrics"] !== undefined
) {
return new ChatGenerationChunk({
message: new AIMessageChunk({ content: "" }),
text: "",
generationInfo: {
response: responseBody.response,
"amazon-bedrock-invocationMetrics":
responseBody["amazon-bedrock-invocationMetrics"],
},
});
} else {
if (
responseBody.finish_reason === "COMPLETE" ||
responseBody.finish_reason === "MAX_TOKENS"
) {
return parseMessageCohere(responseBody);
} else {
return undefined;
}
}
} else {
throw new Error(
"The messages API is currently only supported by Anthropic."
"The messages API is currently only supported by Anthropic or Cohere."
);
}
}
Expand Down Expand Up @@ -341,3 +482,31 @@ function parseMessage(responseBody: any, asChunk?: boolean): ChatGeneration {
};
}
}

/**
 * Converts a raw Cohere response body into a ChatGeneration.
 *
 * The `text` field becomes the generation text (coerced to "" when it is
 * not a string); every other field of the response body is passed through
 * untouched as `generationInfo`.
 *
 * @param responseBody Parsed JSON body from the Bedrock Cohere response.
 * @param asChunk When true, wrap the result in streaming chunk types.
 * @returns A ChatGenerationChunk when `asChunk` is set, else a ChatGeneration.
 */
function parseMessageCohere(
  responseBody: any,
  asChunk?: boolean
): ChatGeneration {
  const { text, ...generationInfo } = responseBody;
  const content = typeof text === "string" ? text : "";

  if (asChunk) {
    return new ChatGenerationChunk({
      message: new AIMessageChunk({ content }),
      text: content,
      generationInfo,
    });
  }

  return {
    message: new AIMessage({ content }),
    text: content,
    generationInfo,
  };
}
Loading