
Commit d92d2b5

Update docs
1 parent 231d9f8 commit d92d2b5


5 files changed: +63 -19 lines changed


docs/WEB_SEARCH.md

Lines changed: 35 additions & 0 deletions
@@ -30,6 +30,28 @@ By default, mistral.rs uses a DuckDuckGo-based search callback. To override this
 - Rust: use `.with_search_callback(...)` on the model builder with an `Arc<dyn Fn(&SearchFunctionParameters) -> anyhow::Result<Vec<SearchResult>> + Send + Sync>`.
 - Python: pass the `search_callback` keyword argument to `Runner`, which should be a function `def search_callback(query: str) -> List[Dict[str, str]]` returning a list of results with keys `"title"`, `"description"`, `"url"`, and `"content"`.
 
+Example in Python:
+```py
+def search_callback(query: str) -> list[dict[str, str]]:
+    # Implement your custom search logic here, returning a list of result dicts
+    return [
+        {
+            "title": "Example Result",
+            "description": "An example description",
+            "url": "https://example.com",
+            "content": "Full text content of the page",
+        },
+        # more results...
+    ]
+
+from mistralrs import Runner, Which, Architecture
+runner = Runner(
+    which=Which.Plain(model_id="YourModel/ID", arch=Architecture.Mistral),
+    enable_search=True,
+    search_callback=search_callback,
+)
+```
+
 ## HTTP server
 **Be sure to add `--enable-search`!**
 
@@ -87,12 +109,25 @@ from mistralrs import (
     WebSearchOptions,
 )
 
+# Define a custom search callback if desired
+def my_search_callback(query: str) -> list[dict[str, str]]:
+    # Fetch or compute search results here
+    return [
+        {
+            "title": "Mistral.rs GitHub",
+            "description": "Official mistral.rs repository",
+            "url": "https://github.com/huggingface/mistral.rs",
+            "content": "mistral.rs is a Rust binding for Mistral models...",
+        },
+    ]
+
 runner = Runner(
     which=Which.Plain(
         model_id="NousResearch/Hermes-3-Llama-3.1-8B",
         arch=Architecture.Llama,
     ),
     enable_search=True,
+    search_callback=my_search_callback,
 )
 
 res = runner.send_chat_completion_request(
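The stub callbacks above return fixed data. A more realistic callback would query an actual search backend and normalize its response into the four required keys; the following is a minimal sketch assuming the `requests` package and a placeholder JSON endpoint (the URL, query parameters, and response fields are assumptions, not part of mistral.rs):

```py
import requests

# Placeholder endpoint; substitute your own search API.
SEARCH_ENDPOINT = "https://search.example.com/api"

def http_search_callback(query: str) -> list[dict[str, str]]:
    # Query the (hypothetical) backend and map its JSON into the four
    # string keys mistral.rs expects: title, description, url, content.
    resp = requests.get(SEARCH_ENDPOINT, params={"q": query}, timeout=10)
    resp.raise_for_status()
    results = []
    for item in resp.json().get("results", []):
        results.append(
            {
                "title": item.get("title", ""),
                "description": item.get("snippet", ""),
                "url": item.get("url", ""),
                "content": item.get("text", ""),
            }
        )
    return results
```

Such a callback is passed exactly like the stubs above, via `search_callback=http_search_callback` on `Runner`.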

examples/python/local_search.py

Lines changed: 15 additions & 9 deletions
@@ -7,25 +7,29 @@
 )
 import os
 
+
 def local_search(query: str):
     results = []
-    for root, _, files in os.walk('.'):
+    for root, _, files in os.walk("."):
         for f in files:
             if query in f:
                 path = os.path.join(root, f)
                 try:
                     content = open(path).read()
                 except Exception:
                     content = ""
-                results.append({
-                    "title": f,
-                    "description": path,
-                    "url": path,
-                    "content": content,
-                })
-    results.sort(key=lambda r: r['title'], reverse=True)
+                results.append(
+                    {
+                        "title": f,
+                        "description": path,
+                        "url": path,
+                        "content": content,
+                    }
+                )
+    results.sort(key=lambda r: r["title"], reverse=True)
     return results
 
+
 runner = Runner(
     which=Which.Plain(
         model_id="NousResearch/Hermes-3-Llama-3.1-8B",
@@ -40,7 +44,9 @@ def local_search(query: str):
         model="mistral",
         messages=[{"role": "user", "content": "Where is Cargo.toml in this repo?"}],
         max_tokens=64,
-        web_search_options=WebSearchOptions(search_description="Local filesystem search"),
+        web_search_options=WebSearchOptions(
+            search_description="Local filesystem search"
+        ),
     )
 )
 print(res.choices[0].message.content)
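Because the callback is a plain Python function, it can be exercised without loading a model; a quick sanity check of `local_search` might look like this (a sketch, assuming it is run from the repository root):

```py
# Call the callback directly and inspect the result dicts it produces.
hits = local_search("Cargo.toml")
for hit in hits[:3]:
    print(hit["url"], "-", hit["title"])
```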

mistralrs-pyo3/mistralrs.pyi

Lines changed: 4 additions & 1 deletion
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from enum import Enum
-from typing import Iterator, Literal, Optional
+from typing import Iterator, Literal, Optional, Callable
 
 class SearchContextSize(Enum):
     Low = "low"
@@ -345,7 +345,9 @@ class Runner:
         paged_attn: bool = False,
         prompt_batchsize: int | None = None,
         seed: int | None = None,
+        enable_search: bool = False,
         search_bert_model: str | None = None,
+        search_callback: Callable[[str], list[dict[str, str]]] | None = None,
         no_bert_model: bool = False,
     ) -> None:
         """
@@ -389,6 +391,7 @@ class Runner:
         - `seed`, used to ensure reproducible random number generation.
         - `enable_search`: Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified below or the default.
         - `search_bert_model`: specify a Hugging Face model ID for a BERT model to assist web searching. Defaults to Snowflake Arctic Embed L.
+        - `search_callback`: Custom Python callable to perform web searches. Should accept a query string and return a list of dicts with keys "title", "description", "url", and "content".
         """
         ...
 
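With `Callable[[str], list[dict[str, str]]]` exposed in the stub, a search callback can be checked against the expected signature by a type checker before being passed to `Runner`; a minimal sketch (the alias name and the function body are illustrative only):

```py
from typing import Callable

# Alias matching the `search_callback` parameter in the stub above.
SearchCallback = Callable[[str], list[dict[str, str]]]

def echo_callback(query: str) -> list[dict[str, str]]:
    # Illustrative only: return the query itself as a single result.
    return [
        {
            "title": query,
            "description": "placeholder result",
            "url": "https://example.com",
            "content": "",
        }
    ]

# mypy/pyright will flag any signature mismatch here.
checked: SearchCallback = echo_callback
```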

mistralrs-quant/kernels/marlin/marlin_kernel.cu

Lines changed: 8 additions & 8 deletions
@@ -86,8 +86,8 @@ dequant<half, ScalarTypeID::kU4B8>(int q) {
   const int HI = 0x00f000f0;
   const int EX = 0x64006400;
   // Guarantee that the `(a & b) | c` operations are LOP3s.
-  int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
-  int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
   // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
   // directly into `SUB` and `ADD`.
   const int SUB = 0x64086408;
@@ -110,9 +110,9 @@ dequant<nv_bfloat16, ScalarTypeID::kU4B8>(int q) {
 
   // Guarantee that the `(a & b) | c` operations are LOP3s.
 
-  int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
   q >>= 4;
-  int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
 
   typename ScalarType<nv_bfloat16>::FragB frag_b;
   static constexpr uint32_t MUL = 0x3F803F80;
@@ -135,8 +135,8 @@ dequant<half, ScalarTypeID::kU4>(int q) {
   const int HI = 0x00f000f0;
   const int EX = 0x64006400;
   // Guarantee that the `(a & b) | c` operations are LOP3s.
-  int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
-  int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
 
   const int SUB = 0x64006400;
   const int MUL = 0x2c002c00;
@@ -158,9 +158,9 @@ dequant<nv_bfloat16, ScalarTypeID::kU4>(int q) {
 
   // Guarantee that the `(a & b) | c` operations are LOP3s.
 
-  int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
   q >>= 4;
-  int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
 
   typename ScalarType<nv_bfloat16>::FragB frag_b;
   static constexpr uint32_t MUL = 0x3F803F80;
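The reflowed calls all pass the same template argument: `(0xf0 & 0xcc) | 0xaa` is the `lop3.b32` lookup-table immediate for the boolean function `(a & b) | c`, built from the usual operand masks a→0xF0, b→0xCC, c→0xAA. A quick check of that identity (illustrative Python, not part of the commit):

```py
# lop3 LUT convention: each operand contributes a fixed 8-bit mask.
A, B, C = 0xF0, 0xCC, 0xAA

lut = (A & B) | C  # 0xEA, the immediate used by the kernel
assert lut == 0xEA

# Bit i of the LUT is f(a, b, c), where (a, b, c) are the bits of i.
for i in range(8):
    a, b, c = (i >> 2) & 1, (i >> 1) & 1, i & 1
    assert ((lut >> i) & 1) == ((a & b) | c)
print(f"LUT 0x{lut:02X} encodes (a & b) | c")
```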

scripts/generate_uqff_card.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@
 )
 file = input("Enter UQFF filename (with extension): ").strip()
 if ";" in file:
-    file = f"\"{file}\""
+    file = f'"{file}"'
 
 quants = input(
     "Enter quantization NAMES used to make that file (single quantization name, OR if multiple, comma delimited): "
