Skip to content

Upgraded mistralai dep to >= 1 #2025

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions giskard/llm/client/mistral.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional, Sequence

import os
from dataclasses import asdict
from logging import warning

Expand All @@ -9,18 +10,17 @@
from .base import ChatMessage

try:
from mistralai.client import MistralClient as _MistralClient
from mistralai.models.chat_completion import ChatMessage as MistralChatMessage
from mistralai import Mistral
except ImportError as err:
raise LLMImportError(
flavor="llm", msg="To use Mistral models, please install the `mistralai` package with `pip install mistralai`"
) from err


class MistralClient(LLMClient):
def __init__(self, model: str = "mistral-large-latest", client: _MistralClient = None):
def __init__(self, model: str = "mistral-large-latest", client: Mistral = None):
self.model = model
self._client = client or _MistralClient()
self._client = client or Mistral(api_key=os.getenv("MISTRAL_API_KEY", ""))

def complete(
self,
Expand All @@ -43,9 +43,9 @@ def complete(
extra_params["response_format"] = {"type": "json_object"}

try:
completion = self._client.chat(
completion = self._client.chat.complete(
model=self.model,
messages=[MistralChatMessage(**asdict(m)) for m in messages],
messages=[asdict(m) for m in messages],
temperature=temperature,
max_tokens=max_tokens,
**extra_params,
Expand Down
35 changes: 24 additions & 11 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ dev = [
"pytest-asyncio>=0.21.1",
"pydantic>=2",
"avidtools",
"mistralai>=0.1.8, <1",
"mistralai>=1",
"boto3>=1.34.88",
"scikit-learn==1.4.2",
]
Expand Down
20 changes: 9 additions & 11 deletions tests/llm/test_llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import pydantic
import pytest
from google.generativeai.types import ContentDict
from mistralai.models.chat_completion import ChatCompletionResponse, ChatCompletionResponseChoice
from mistralai.models.chat_completion import ChatMessage as MistralChatMessage
from mistralai.models.chat_completion import FinishReason, UsageInfo
from mistralai.models import ChatCompletionChoice, ChatCompletionResponse, UsageInfo
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
Expand Down Expand Up @@ -41,10 +39,10 @@
created=1630000000,
model="mistral-large",
choices=[
ChatCompletionResponseChoice(
ChatCompletionChoice(
index=0,
message=MistralChatMessage(role="assistant", content="This is a test!", name=None, tool_calls=None),
finish_reason=FinishReason.stop,
message={"role": "assistant", "content": "This is a test!"},
finish_reason="stop",
)
],
usage=UsageInfo(prompt_tokens=9, total_tokens=89, completion_tokens=80),
Expand All @@ -70,18 +68,18 @@ def test_llm_complete_message():
@pytest.mark.skipif(not PYDANTIC_V2, reason="Mistral raise an error with pydantic < 2")
def test_mistral_client():
client = Mock()
client.chat.return_value = DEMO_MISTRAL_RESPONSE
client.chat.complete.return_value = DEMO_MISTRAL_RESPONSE

from giskard.llm.client.mistral import MistralClient

res = MistralClient(model="mistral-large", client=client).complete(
[ChatMessage(role="user", content="Hello")], temperature=0.11, max_tokens=12
)

client.chat.assert_called_once()
assert client.chat.call_args[1]["messages"] == [MistralChatMessage(role="user", content="Hello")]
assert client.chat.call_args[1]["temperature"] == 0.11
assert client.chat.call_args[1]["max_tokens"] == 12
client.chat.complete.assert_called_once()
assert client.chat.complete.call_args[1]["messages"] == [{"role": "user", "content": "Hello"}]
assert client.chat.complete.call_args[1]["temperature"] == 0.11
assert client.chat.complete.call_args[1]["max_tokens"] == 12

assert isinstance(res, ChatMessage)
assert res.content == "This is a test!"
Expand Down
Loading