mistralai[minor]: Add llms entrypoint, update chat model integration #5603

Merged 13 commits on May 31, 2024. Showing changes from 11 commits.
154 changes: 154 additions & 0 deletions docs/core_docs/docs/integrations/llms/mistral.ipynb
@@ -0,0 +1,154 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MistralAI\n",
"\n",
"Here's how you can initialize an `MistralAI` LLM instance:\n",
"\n",
"```{=mdx}\n",
"import IntegrationInstallTooltip from \"@mdx_components/integration_install_tooltip.mdx\";\n",
"import Npm2Yarn from \"@theme/Npm2Yarn\";\n",
"\n",
"<IntegrationInstallTooltip></IntegrationInstallTooltip>\n",
"\n",
"<Npm2Yarn>\n",
" @langchain/mistralai\n",
"</Npm2Yarn>\n",
"```\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"console.log('hello world');\n",
"```\n",
"This will output 'hello world' to the console.\n"
]
}
],
"source": [
"import { MistralAI } from \"@langchain/mistralai\";\n",
"\n",
"const model = new MistralAI({\n",
" model: \"codestral-latest\", // Defaults to \"codestral-latest\" if no model provided.\n",
" temperature: 0,\n",
" apiKey: \"YOUR-API-KEY\", // In Node.js defaults to process.env.MISTRAL_API_KEY\n",
"});\n",
"const res = await model.invoke(\n",
" \"You can print 'hello world' to the console in javascript like this:\\n```javascript\"\n",
");\n",
"console.log(res);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the Mistral LLM is a completions model, they also allow you to insert a `suffix` to the prompt. Suffixes can be passed via the call options when invoking a model like so:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"console.log('hello world');\n",
"```\n"
]
}
],
"source": [
"const res = await model.invoke(\n",
" \"You can print 'hello world' to the console in javascript like this:\\n```javascript\", {\n",
" suffix: \"```\"\n",
" }\n",
");\n",
"console.log(res);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As seen in the first example, the model generated the requested `console.log('hello world')` code snippet, but also included extra unwanted text. By adding a suffix, we can constrain the model to only complete the prompt up to the suffix (in this case, three backticks). This allows us to easily parse the completion and extract only the desired response without the suffix using a custom output parser."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"console.log('hello world');\n",
"\n"
]
}
],
"source": [
"import { MistralAI } from \"@langchain/mistralai\";\n",
"\n",
"const model = new MistralAI({\n",
" model: \"codestral-latest\",\n",
" temperature: 0,\n",
" apiKey: \"YOUR-API-KEY\",\n",
"});\n",
"\n",
"const suffix = \"```\";\n",
"\n",
"const customOutputParser = (input: string) => {\n",
" if (input.includes(suffix)) {\n",
" return input.split(suffix)[0];\n",
" }\n",
" throw new Error(\"Input does not contain suffix.\")\n",
"};\n",
"\n",
"const res = await model.invoke(\n",
" \"You can print 'hello world' to the console in javascript like this:\\n```javascript\", {\n",
" suffix,\n",
" }\n",
");\n",
"\n",
"console.log(customOutputParser(res));"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "TypeScript",
"language": "typescript",
"name": "tslab"
},
"language_info": {
"codemirror_mode": {
"mode": "typescript",
"name": "javascript",
"typescript": true
},
"file_extension": ".ts",
"mimetype": "text/typescript",
"name": "typescript",
"version": "3.7.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion examples/src/models/chat/chat_mistralai_tools.ts
@@ -33,7 +33,7 @@ class CalculatorTool extends StructuredTool {

Hey there! I've reviewed the code and noticed a change in the model initialization, updating the model property from "mistral-large" to "mistral-large-latest". I've flagged this for your review as it also accesses an environment variable using process.env.MISTRAL_API_KEY. Keep up the great work!

const model = new ChatMistralAI({
apiKey: process.env.MISTRAL_API_KEY,
model: "mistral-large",
model: "mistral-large-latest",
});

// Bind the tool to the model
2 changes: 1 addition & 1 deletion examples/src/models/chat/chat_mistralai_wsa.ts
@@ -14,7 +14,7 @@ const calculatorSchema = z

Hey there! I've noticed that the recent change in the chat_mistralai_wsa.ts file updates the model property to "mistral-large-latest". I've flagged this for your review as it also accesses an environment variable via process.env.MISTRAL_API_KEY. Keep up the great work!

const model = new ChatMistralAI({
apiKey: process.env.MISTRAL_API_KEY,
model: "mistral-large",
model: "mistral-large-latest",
});

// Pass the schema and tool name to the withStructuredOutput method
2 changes: 1 addition & 1 deletion examples/src/models/chat/chat_mistralai_wsa_json.ts
@@ -21,7 +21,7 @@ const calculatorJsonSchema = {

Hey there! I've reviewed the code and noticed that the recent change in the chat_mistralai_wsa_json.ts file explicitly accesses an environment variable via process.env. I've flagged this for your review to ensure it aligns with the intended functionality. Let me know if you need further assistance!

const model = new ChatMistralAI({
apiKey: process.env.MISTRAL_API_KEY,
model: "mistral-large",
model: "mistral-large-latest",
});

// Pass the schema and tool name to the withStructuredOutput method
2 changes: 1 addition & 1 deletion libs/langchain-mistralai/package.json
@@ -41,7 +41,7 @@
"license": "MIT",
"dependencies": {
"@langchain/core": ">0.1.56 <0.3.0",
"@mistralai/mistralai": "^0.1.3",
"@mistralai/mistralai": "^0.4.0",
"uuid": "^9.0.0",
"zod": "^3.22.4",
"zod-to-json-schema": "^3.22.4"
90 changes: 44 additions & 46 deletions libs/langchain-mistralai/src/chat_models.ts
@@ -3,10 +3,11 @@ import {
ChatCompletionResponse,
Function as MistralAIFunction,
ToolCalls as MistralAIToolCalls,
- ToolChoice as MistralAIToolChoice,
ResponseFormat,
ChatCompletionResponseChunk,
- ToolType,
+ ChatRequest,
+ Tool as MistralAITool,
+ Message as MistralAIMessage,
} from "@mistralai/mistralai";
import {
MessageType,
@@ -44,7 +45,6 @@ import {
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { NewTokenIndices } from "@langchain/core/callbacks/base";
import { StructuredTool, StructuredToolInterface } from "@langchain/core/tools";
- import { convertToOpenAITool } from "@langchain/core/utils/function_calling";
import { z } from "zod";
import {
type BaseLLMOutputParser,
@@ -70,40 +70,15 @@ interface TokenUsage {
totalTokens?: number;
}

- type MistralAIInputMessage = {
- role: string;
- name?: string;
- content: string | string[];
- tool_calls?: MistralAIToolCalls[];
- };
+ export type MistralAIToolChoice = "auto" | "any" | "none";

type MistralAIToolInput = { type: string; function: MistralAIFunction };

- type MistralAIChatCompletionOptions = {
- model: string;
- messages: Array<{
- role: string;
- name?: string;
- content: string | string[];
- tool_calls?: MistralAIToolCalls[];
- }>;
- tools?: Array<MistralAIToolInput>;
- temperature?: number;
- maxTokens?: number;
- topP?: number;
- randomSeed?: number;
- safeMode?: boolean;
- safePrompt?: boolean;
- toolChoice?: MistralAIToolChoice;
- responseFormat?: ResponseFormat;
- };

interface MistralAICallOptions
extends Omit<BaseLanguageModelCallOptions, "stop"> {
response_format?: {
type: "text" | "json_object";
};
- tools: StructuredToolInterface[] | MistralAIToolInput[];
+ tools: StructuredToolInterface[] | MistralAIToolInput[] | MistralAITool[];
Member Author: Need to include the old type for back-compat

tool_choice?: MistralAIToolChoice;
}
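A minimal sketch of what the widened union preserves (shapes inferred from this diff; the tool and prompt below are invented for illustration, not taken from the PR):

```typescript
import { ChatMistralAI } from "@langchain/mistralai";

const model = new ChatMistralAI({ model: "mistral-large-latest" });

// Legacy MistralAIToolInput-style object: `type` infers as a plain string
// rather than the SDK Tool type's stricter shape.
const legacyTool = {
  type: "function",
  function: {
    name: "calculator",
    description: "Evaluates a simple arithmetic expression.",
    parameters: {
      type: "object",
      properties: { expression: { type: "string" } },
      required: ["expression"],
    },
  },
};

// Existing callers keep compiling because MistralAIToolInput[] remains in the
// union; new callers can pass the SDK's Tool type (or StructuredTools) instead.
const res = await model.invoke("What is 2 + 2?", { tools: [legacyTool] });
```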

@@ -178,7 +153,7 @@ export interface ChatMistralAIInput extends BaseChatModelParams {

function convertMessagesToMistralMessages(
messages: Array<BaseMessage>
- ): Array<MistralAIInputMessage> {
+ ): Array<MistralAIMessage> {
const getRole = (role: MessageType) => {
switch (role) {
case "human":
@@ -212,7 +187,7 @@ function convertMessagesToMistralMessages(
const getTools = (message: BaseMessage): MistralAIToolCalls[] | undefined => {
if (isAIMessage(message) && !!message.tool_calls?.length) {
return message.tool_calls
- .map((toolCall) => ({ ...toolCall, id: "null" }))
+ .map((toolCall) => ({ ...toolCall, id: toolCall.id }))
.map(convertLangChainToolCallToOpenAI) as MistralAIToolCalls[];
}
if (!message.additional_kwargs.tool_calls?.length) {
@@ -221,8 +196,8 @@ function convertMessagesToMistralMessages(
const toolCalls: Omit<OpenAIToolCall, "index">[] =
message.additional_kwargs.tool_calls;
return toolCalls?.map((toolCall) => ({
id: "null",
type: "function" as ToolType.function,
id: toolCall.id,
type: "function",
function: toolCall.function,
}));
};
@@ -235,7 +210,7 @@ function convertMessagesToMistralMessages(
content,
tool_calls: toolCalls,
};
- });
+ }) as MistralAIMessage[];
Member Author: Their sdk types don't play well with their API... casting allows us to use this type which doesn't throw errors from their API

}
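A small self-contained illustration of the casting pattern the comment describes; the literal-union mismatch below is one plausible cause, assumed rather than taken from the SDK's actual declarations:

```typescript
// Hypothetical SDK-style type with narrow string literals for `role`.
type SdkMessage = { role: "system" | "user" | "assistant" | "tool"; content: string };

const turns: Array<[string, string]> = [
  ["user", "Hello"],
  ["assistant", "Hi there"],
];

// TypeScript infers { role: string; content: string }[] here, which is not
// assignable to SdkMessage[] even though every runtime value is valid...
const built = turns.map(([role, content]) => ({ role, content }));

// ...so the integration casts once at the boundary, mirroring
// `}) as MistralAIMessage[]` in the diff above.
const messages = built as SdkMessage[];
```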

function mistralAIResponseToChatMessage(
@@ -270,7 +245,10 @@ function mistralAIResponseToChatMessage(
tool_calls: toolCalls,
invalid_tool_calls: invalidToolCalls,
additional_kwargs: {
- tool_calls: rawToolCalls,
+ tool_calls: rawToolCalls.map((toolCall) => ({
+ ...toolCall,
+ type: "function",
+ })),
},
});
}
@@ -350,8 +328,18 @@ function _convertDeltaToMessageChunk(delta: {

function _convertStructuredToolToMistralTool(
tools: StructuredToolInterface[]
- ): MistralAIToolInput[] {
- return tools.map((tool) => convertToOpenAITool(tool) as MistralAIToolInput);
+ ): MistralAITool[] {
+ return tools.map((tool) => {
Member Author: New tool type from mistral is not exactly compatible with OAI

+ const description = tool.description ?? `Tool: ${tool.name}`;
+ return {
+ type: "function",
+ function: {
+ name: tool.name,
+ description,
+ parameters: zodToJsonSchema(tool.schema),
+ },
+ };
+ });
}
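A hedged sketch of the incompatibility: OpenAI-style tools may omit `description`, while the Mistral `Tool` type (as used throughout this diff) expects one, which is why the conversion above builds the object explicitly instead of reusing `convertToOpenAITool`. The type aliases here are illustrative, not the real declarations:

```typescript
type OpenAIToolish = {
  type: "function";
  function: { name: string; description?: string; parameters: object };
};
type MistralToolish = {
  type: "function";
  function: { name: string; description: string; parameters: object };
};

const fromOpenAI: OpenAIToolish = {
  type: "function",
  function: { name: "add", parameters: { type: "object" } }, // no description
};

// A direct assignment would fail to type-check because `description` may be
// missing, so a fallback is supplied, as in _convertStructuredToolToMistralTool:
const forMistral: MistralToolish = {
  type: "function",
  function: {
    name: fromOpenAI.function.name,
    description: fromOpenAI.function.description ?? `Tool: ${fromOpenAI.function.name}`,
    parameters: fromOpenAI.function.parameters,
  },
};
```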

/**
@@ -439,17 +427,27 @@ export class ChatMistralAI<
*/
invocationParams(
options?: this["ParsedCallOptions"]
- ): Omit<MistralAIChatCompletionOptions, "messages"> {
+ ): Omit<ChatRequest, "messages"> {
const { response_format, tools, tool_choice } = options ?? {};
- const mistralAITools = tools
+ const mistralAITools: Array<MistralAITool> | undefined = tools
?.map((tool) => {
if ("lc_namespace" in tool) {
return _convertStructuredToolToMistralTool([tool]);
}
- return tool;
+ if (!tool.function.description) {
+ return {
+ type: "function",
+ function: {
+ name: tool.function.name,
+ description: `Tool: ${tool.function.name}`,
+ parameters: tool.function.parameters,
+ },
+ } as MistralAITool;
+ }
+ return tool as MistralAITool;
})
.flat();
- const params: Omit<MistralAIChatCompletionOptions, "messages"> = {
+ const params: Omit<ChatRequest, "messages"> = {
model: this.model,
tools: mistralAITools,
temperature: this.temperature,
@@ -484,21 +482,21 @@

/**
* Calls the MistralAI API with retry logic in case of failures.
- * @param {MistralAIChatCompletionOptions} input The input to send to the MistralAI API.
+ * @param {ChatRequest} input The input to send to the MistralAI API.
* @returns {Promise<MistralAIChatCompletionResult | AsyncGenerator<MistralAIChatCompletionResult>>} The response from the MistralAI API.
*/
async completionWithRetry(
- input: MistralAIChatCompletionOptions,
+ input: ChatRequest,
streaming: true
): Promise<AsyncGenerator<ChatCompletionResponseChunk>>;

async completionWithRetry(
- input: MistralAIChatCompletionOptions,
+ input: ChatRequest,
streaming: false
): Promise<ChatCompletionResponse>;

async completionWithRetry(
- input: MistralAIChatCompletionOptions,
+ input: ChatRequest,
streaming: boolean
): Promise<
ChatCompletionResponse | AsyncGenerator<ChatCompletionResponseChunk>
1 change: 1 addition & 0 deletions libs/langchain-mistralai/src/index.ts
@@ -1,2 +1,3 @@
export * from "./chat_models.js";
export * from "./embeddings.js";
export * from "./llms.js";