Skip to content

Commit e326126

Browse files
author
CognitiveTech
authored
feat: add mistral + chatml prompts (#1426)
1 parent 6191bcd commit e326126

File tree

6 files changed

+107
-5
lines changed

6 files changed

+107
-5
lines changed

fern/docs/pages/recipes/list-llm.mdx

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,31 @@ user: {{ user_message }}
2424
assistant: {{ assistant_message }}
2525
```
2626

27-
And the "`tag`" style looks like this:
27+
The "`tag`" style looks like this:
2828

2929
```text
3030
<|system|>: {{ system_prompt }}
3131
<|user|>: {{ user_message }}
3232
<|assistant|>: {{ assistant_message }}
3333
```
3434

35-
Some LLMs will not understand this prompt style, and will not work (returning nothing).
35+
The "`mistral`" style looks like this:
36+
37+
```text
38+
<s>[INST] You are an AI assistant. [/INST]</s>[INST] Hello, how are you doing? [/INST]
39+
```
40+
41+
The "`chatml`" style looks like this:
42+
```text
43+
<|im_start|>system
44+
{{ system_prompt }}<|im_end|>
45+
<|im_start|>user
46+
{{ user_message }}<|im_end|>
47+
<|im_start|>assistant
48+
{{ assistant_message }}
49+
```
50+
51+
Some LLMs will not understand these prompt styles, and will not work (returning nothing).
3652
You can try to change the prompt style to `default` (or `tag`) in the settings, and it will
3753
change the way the messages are formatted to be passed to the LLM.
3854

private_gpt/components/llm/prompt_helper.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,51 @@ def _completion_to_prompt(self, completion: str) -> str:
123123
)
124124

125125

126+
class MistralPromptStyle(AbstractPromptStyle):
    """Render chat messages in Mistral's ``<s>[INST] ... [/INST]`` format.

    Both system and user turns are wrapped in ``[INST]`` tags; a ``</s>``
    token is emitted immediately before each user turn.  Messages with any
    other role (e.g. assistant) are not included in the prompt.
    """

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        # Accumulate fragments and join once at the end.
        fragments = ["<s>"]
        for msg in messages:
            kind = msg.role.lower()
            text = (msg.content or "").strip()
            if kind == "system":
                fragments.append(f"[INST] {text} [/INST]")
            elif kind == "user":
                # Close the previous sequence before opening the user turn.
                fragments.append("</s>")
                fragments.append(f"[INST] {text} [/INST]")
        return "".join(fragments)

    def _completion_to_prompt(self, completion: str) -> str:
        # A bare completion is treated as a single user message.
        single_turn = [ChatMessage(content=completion, role=MessageRole.USER)]
        return self._messages_to_prompt(single_turn)
145+
146+
147+
class ChatMLPromptStyle(AbstractPromptStyle):
    """Render chat messages in the ChatML ``<|im_start|>role`` format.

    The prompt always opens with ``<|im_start|>system`` and always closes
    with ``<|im_start|>assistant`` so the model continues as the assistant.
    Messages with roles other than system/user are not included.
    """

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        pieces = ["<|im_start|>system\n"]
        for msg in messages:
            kind = msg.role.lower()
            text = (msg.content or "").strip()
            if kind == "system":
                # System content goes directly after the opening system tag.
                pieces.append(text)
            elif kind == "user":
                # Terminate the previous section, then open the user turn.
                pieces.append("<|im_end|>\n<|im_start|>user\n")
                pieces.append(f"{text}<|im_end|>\n")
        # Leave the prompt open at the assistant turn for generation.
        pieces.append("<|im_start|>assistant\n")
        return "".join(pieces)

    def _completion_to_prompt(self, completion: str) -> str:
        return self._messages_to_prompt(
            [ChatMessage(content=completion, role=MessageRole.USER)]
        )
167+
168+
126169
def get_prompt_style(
127-
prompt_style: Literal["default", "llama2", "tag"] | None
170+
prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
128171
) -> AbstractPromptStyle:
129172
"""Get the prompt style to use from the given string.
130173
@@ -137,4 +180,8 @@ def get_prompt_style(
137180
return Llama2PromptStyle()
138181
elif prompt_style == "tag":
139182
return TagPromptStyle()
183+
elif prompt_style == "mistral":
184+
return MistralPromptStyle()
185+
elif prompt_style == "chatml":
186+
return ChatMLPromptStyle()
140187
raise ValueError(f"Unknown prompt_style='{prompt_style}'")

private_gpt/settings/settings.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,14 @@ class LocalSettings(BaseModel):
110110
embedding_hf_model_name: str = Field(
111111
description="Name of the HuggingFace model to use for embeddings"
112112
)
113-
prompt_style: Literal["default", "llama2", "tag"] = Field(
113+
# Which prompt-formatting style the chat engine uses when talking to the LLM.
prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] = Field(
    "llama2",
    description=(
        "The prompt style to use for the chat engine. "
        "If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
        "If `llama2` - use the llama2 prompt style from the llama_index. Based on `<s>`, `[INST]` and `<<SYS>>`.\n"
        "If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`. \n"
        "If `mistral` - use the `mistral` prompt style. It should look like `<s>[INST] {System Prompt} [/INST]</s>[INST] {User Instructions} [/INST]`.\n"
        "If `chatml` - use the `chatml` prompt style. It should look like `<|im_start|>role\\n{message}<|im_end|>`.\n"
        "`llama2` is the historic behaviour. `default` might work better with your custom models."
    ),
)

scripts/extract_openapi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import json
33
import sys
4+
45
import yaml
56
from uvicorn.importer import import_from_string
67

settings.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ qdrant:
5151
path: local_data/private_gpt/qdrant
5252

5353
local:
54-
prompt_style: "llama2"
54+
prompt_style: "mistral"
5555
llm_hf_repo_id: TheBloke/Mistral-7B-Instruct-v0.2-GGUF
5656
llm_hf_model_file: mistral-7b-instruct-v0.2.Q4_K_M.gguf
5757
embedding_hf_model_name: BAAI/bge-small-en-v1.5

tests/test_prompt_helper.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
from llama_index.llms import ChatMessage, MessageRole
33

44
from private_gpt.components.llm.prompt_helper import (
5+
ChatMLPromptStyle,
56
DefaultPromptStyle,
67
Llama2PromptStyle,
8+
MistralPromptStyle,
79
TagPromptStyle,
810
get_prompt_style,
911
)
@@ -15,6 +17,8 @@
1517
("default", DefaultPromptStyle),
1618
("llama2", Llama2PromptStyle),
1719
("tag", TagPromptStyle),
20+
("mistral", MistralPromptStyle),
21+
("chatml", ChatMLPromptStyle),
1822
],
1923
)
2024
def test_get_prompt_style_success(prompt_style, expected_prompt_style):
@@ -62,6 +66,39 @@ def test_tag_prompt_style_format_with_system_prompt():
6266
assert prompt_style.messages_to_prompt(messages) == expected_prompt
6367

6468

69+
def test_mistral_prompt_style_format():
    """System and user turns are [INST]-wrapped with a </s> separator between them."""
    style = MistralPromptStyle()
    conversation = [
        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]

    expected = (
        "<s>[INST] You are an AI assistant. [/INST]</s>"
        "[INST] Hello, how are you doing? [/INST]"
    )
    assert style.messages_to_prompt(conversation) == expected
82+
83+
84+
def test_chatml_prompt_style_format():
    """Prompt opens with the system section and ends open at the assistant turn."""
    style = ChatMLPromptStyle()
    conversation = [
        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]

    expected = (
        "<|im_start|>system\n"
        "You are an AI assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        "Hello, how are you doing?<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    assert style.messages_to_prompt(conversation) == expected
100+
101+
65102
def test_llama2_prompt_style_format():
66103
prompt_style = Llama2PromptStyle()
67104
messages = [

0 commit comments

Comments
 (0)