Skip to content

Commit c53d346

Browse files
authored
Better tool call detection logic (#1424)
1 parent d6c227e commit c53d346

File tree

1 file changed

+57
-7
lines changed
  • mistralrs-core/src/tools

1 file changed

+57
-7
lines changed

mistralrs-core/src/tools/mod.rs

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@ use candle_core::Result;
55
use regex::Regex;
66
pub use request::*;
77
pub use response::*;
8-
use serde_json::Value;
9-
use std::{
10-
collections::HashMap,
11-
sync::{Arc, OnceLock},
12-
};
8+
use serde::de::{self, Deserializer, MapAccess, Visitor};
9+
use serde_json::{Map, Value};
10+
use std::fmt;
11+
use std::sync::{Arc, OnceLock};
1312
use uuid::Uuid;
1413

1514
use crate::Pipeline;
@@ -87,8 +86,57 @@ pub struct ToolCallingMatcher {
8786
pub struct CalledFunctionParameters {
8887
#[serde(alias = "function")]
8988
pub name: String,
90-
#[serde(alias = "arguments")]
91-
pub parameters: HashMap<String, Value>,
89+
#[serde(alias = "arguments", deserialize_with = "flexible_args")]
90+
pub parameters: Value,
91+
}
92+
93+
// Accept either `{...}` **or** a `"stringified { ... }"`
94+
fn flexible_args<'de, D>(d: D) -> std::result::Result<Value, D::Error>
95+
where
96+
D: Deserializer<'de>,
97+
{
98+
struct ArgVisitor;
99+
100+
impl<'de> Visitor<'de> for ArgVisitor {
101+
type Value = Value;
102+
103+
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
104+
f.write_str("an object or a JSON-encoded string containing an object")
105+
}
106+
107+
// Case 1 – the good case: already a JSON object
108+
fn visit_map<M>(self, mut m: M) -> std::result::Result<Self::Value, M::Error>
109+
where
110+
M: MapAccess<'de>,
111+
{
112+
let mut map = Map::new();
113+
while let Some((k, v)) = m.next_entry()? {
114+
map.insert(k, v);
115+
}
116+
Ok(Value::Object(map))
117+
}
118+
119+
// Case 2 – got a *string*; try parsing it as JSON
120+
fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
121+
where
122+
E: de::Error,
123+
{
124+
serde_json::from_str(s).map_err(|e| E::custom(format!("inner JSON error: {e}")))
125+
}
126+
}
127+
128+
d.deserialize_any(ArgVisitor)
129+
}
130+
131+
/// Fixup potentially broken JSON
132+
/// 1) allow/handle arguments as maps in quotations
133+
fn fix_broken_json(raw: &str) -> anyhow::Result<String> {
134+
// 1) Delete the opening quote that shouldn’t be there
135+
let tmp = raw.replacen(r#""arguments":"{"#, r#""arguments":{"#, 1);
136+
// 2) Delete the closing quote that matches it
137+
let fixed = tmp.replacen(r#"}"}"#, r#"}}"#, 1);
138+
139+
Ok(fixed)
92140
}
93141

94142
impl ToolCallingMatcher {
@@ -111,6 +159,7 @@ impl ToolCallingMatcher {
111159
return Ok((false, false));
112160
}
113161
let message_prefix = process_model_specific_message(message_prefix)?;
162+
let message_prefix = fix_broken_json(&message_prefix).unwrap();
114163

115164
// Check if the prefix could be a JSON serialization of any of the following types.
116165
Ok([
@@ -138,6 +187,7 @@ impl ToolCallingMatcher {
138187
return Ok(Vec::new());
139188
}
140189
let message = process_model_specific_message(message)?;
190+
let message = fix_broken_json(&message).unwrap();
141191

142192
if let Ok(deser) = serde_json::from_str::<CalledFunctionParameters>(&message) {
143193
let id = format!("call-{}", Uuid::new_v4());

0 commit comments

Comments
 (0)