Skip to content

Commit ecb6907

Browse files
authored
Propely handle consecutive searches (#1421)
* Update extraction tool reinjection * Looped
1 parent 57d6e12 commit ecb6907

File tree

1 file changed

+161
-114
lines changed

1 file changed

+161
-114
lines changed

mistralrs-core/src/engine/search_request.rs

Lines changed: 161 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,9 @@ use tracing::{level_filters::LevelFilter, Dispatch};
99
use crate::{
1010
get_mut_arcmutex,
1111
request::SearchContextSize,
12-
search::{
13-
self, search_tool_called, ExtractFunctionParameters, SearchFunctionParameters,
14-
SearchResult, EXTRACT_TOOL_NAME, SEARCH_TOOL_NAME,
15-
},
12+
search::{self, ExtractFunctionParameters, SearchFunctionParameters, SearchResult},
1613
MessageContent, NormalRequest, RequestMessage, Response, ResponseOk, ToolCallResponse,
17-
WebSearchOptions,
14+
ToolChoice, WebSearchOptions,
1815
};
1916

2017
use super::Engine;
@@ -24,7 +21,7 @@ async fn do_search(
2421
mut second_request: NormalRequest,
2522
tool_calls: &ToolCallResponse,
2623
web_search_options: &WebSearchOptions,
27-
) {
24+
) -> NormalRequest {
2825
let messages = match &mut second_request.messages {
2926
RequestMessage::Chat { messages, .. } | RequestMessage::VisionChat { messages, .. } => {
3027
messages
@@ -212,18 +209,20 @@ async fn do_search(
212209
messages.push(message);
213210
}
214211

212+
// Allow the assistant to invoke tools again on the next turn
213+
second_request.tool_choice = Some(ToolChoice::Auto);
215214
// Recursion is enabled here!
216215
second_request.web_search_options = Some(web_search_options.clone());
217216

218-
this.add_request(second_request).await;
217+
second_request
219218
}
220219

221220
async fn do_extraction(
222221
this: Arc<Engine>,
223222
mut second_request: NormalRequest,
224223
tool_calls: &ToolCallResponse,
225224
web_search_options: &WebSearchOptions,
226-
) {
225+
) -> NormalRequest {
227226
let messages = match &mut second_request.messages {
228227
RequestMessage::Chat { messages, .. } | RequestMessage::VisionChat { messages, .. } => {
229228
messages
@@ -306,144 +305,192 @@ async fn do_extraction(
306305
message.insert("role".to_string(), Either::Left("tool".to_string()));
307306
message.insert(
308307
"content".to_string(),
309-
Either::Left(format!("{{\"output\": \"{tool_result}\"}}")),
308+
Either::Left(
309+
// Format the tool output JSON and append the search tool description for context
310+
format!(
311+
"{{\"output\": \"{}\"}}\n\n{}\n\n{}",
312+
tool_result,
313+
search::SEARCH_DESCRIPTION,
314+
search::EXTRACT_DESCRIPTION,
315+
),
316+
),
310317
);
311318
messages.push(message);
312319
}
313320

321+
// Allow the assistant to invoke tools again on the next turn
322+
second_request.tool_choice = Some(ToolChoice::Auto);
314323
// Recursion is enabled here!
315324
second_request.web_search_options = Some(web_search_options.clone());
316325

317-
this.add_request(second_request).await;
326+
second_request
318327
}
319328

320-
/// The strategy is:
321-
/// - Send the first request to allow a tool call
322-
/// - If no, tool call, early return
323-
/// - Proceed to `do_search`
324-
/// 1) Execute search
325-
/// 2) Rank by relevance
326-
/// - Send final tool call, which is allowed to have web search for repeated queries.
329+
/// Drive one or more web-search / extraction rounds without recursion.
330+
///
331+
/// Strategy:
332+
/// 1. Send a “probe” request that may call the search/extract tools.
333+
/// 2. If such a tool is called, run it (`do_search` / `do_extraction`) to
334+
/// mutate the conversational context and build the next request.
335+
/// 3. Repeat until no further tool call is made.
336+
/// 4. Forward every user-visible reply **except** the first, which is just the
337+
/// probe that discovers whether a tool call is needed.
327338
pub(super) async fn search_request(this: Arc<Engine>, request: NormalRequest) {
339+
// We entered this function only when web_search_options is Some(_)
328340
let Some(web_search_options) = request.web_search_options.clone() else {
329341
unreachable!()
330342
};
331-
let mut first_request = request.clone();
332-
// Actually add the search tools here
333-
first_request
343+
344+
// The sender that ultimately delivers data back to the caller.
345+
let user_sender = request.response.clone();
346+
let is_streaming = request.is_streaming;
347+
348+
// ---------------------------------------------------------------------
349+
// Build the *first* request (the “probe”).
350+
// ---------------------------------------------------------------------
351+
let mut probe = request.clone();
352+
probe
334353
.tools
335354
.get_or_insert_with(Vec::new)
336355
.extend(search::get_search_tools(&web_search_options).unwrap());
356+
probe.tool_choice = Some(ToolChoice::Auto);
357+
// Prevent accidental infinite recursion on the probe itself.
358+
probe.web_search_options = None;
337359

338-
let mut second_request = first_request.clone();
339-
first_request.web_search_options = None;
340-
second_request.web_search_options = None;
360+
// The conversation context that the user *will* see.
361+
let mut visible_req = probe.clone();
362+
visible_req.response = user_sender.clone();
341363

364+
// We'll drive everything inside a single spawned task.
342365
let this_clone = this.clone();
366+
let handle = tokio::spawn(async move {
367+
// `current` is what we actually dispatch each loop.
368+
// The very first time that is the hidden probe.
369+
let mut current = probe;
370+
// Forward results to the user after the first loop.
371+
let mut forward_to_user = false;
372+
373+
loop {
374+
// Each dispatch gets its own one-shot channel so we can peek at
375+
// the response before (optionally) forwarding it.
376+
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
377+
current.response = sender;
378+
379+
// Kick the request into the engine.
380+
this_clone.add_request(current).await;
381+
382+
// ----------------------- NON-STREAMING ------------------------
383+
if !is_streaming {
384+
let ResponseOk::Done(done) = receiver.recv().await.unwrap().as_result().unwrap()
385+
else {
386+
unreachable!();
387+
};
343388

344-
if !request.is_streaming {
345-
let handle = tokio::spawn(async move {
346-
let (new_sender, mut first_receiver) = tokio::sync::mpsc::channel(1);
347-
second_request.response = new_sender;
348-
std::mem::swap(&mut first_request.response, &mut second_request.response);
349-
350-
this_clone.add_request(first_request).await;
351-
let ResponseOk::Done(done) = first_receiver.recv().await.unwrap().as_result().unwrap()
352-
else {
353-
unreachable!()
354-
};
355-
356-
let tool_calls = match &done.choices[0].message.tool_calls {
357-
Some(tool_calls)
358-
if tool_calls.len() == 1
359-
&& search_tool_called(&tool_calls[0].function.name) =>
360-
{
361-
&tool_calls[0]
362-
}
363-
None => {
364-
second_request
365-
.response
366-
.send(Response::Done(done))
389+
// Forward to the caller once the probe is out of the way.
390+
if forward_to_user {
391+
user_sender
392+
.send(Response::Done(done.clone()))
367393
.await
368394
.unwrap();
369-
return;
370395
}
371-
Some(_) => {
372-
second_request
373-
.response
374-
.send(Response::Done(done))
375-
.await
376-
.unwrap();
377-
return;
396+
397+
// Did the assistant ask to run a tool?
398+
let tc_opt = match &done.choices[0].message.tool_calls {
399+
Some(calls)
400+
if calls.len() == 1
401+
&& search::search_tool_called(&calls[0].function.name) =>
402+
{
403+
Some(&calls[0])
404+
}
405+
_ => None,
406+
};
407+
408+
// No tool call? We are finished.
409+
if tc_opt.is_none() {
410+
break;
378411
}
379-
};
380412

381-
if tool_calls.function.name == SEARCH_TOOL_NAME {
382-
do_search(this_clone, second_request, tool_calls, &web_search_options).await;
383-
} else if tool_calls.function.name == EXTRACT_TOOL_NAME {
384-
do_extraction(this_clone, second_request, tool_calls, &web_search_options).await;
385-
} else {
386-
unreachable!()
413+
// Tool requested → build the next turn.
414+
let tc = tc_opt.unwrap();
415+
let next_visible = if tc.function.name == search::SEARCH_TOOL_NAME {
416+
do_search(this_clone.clone(), visible_req, tc, &web_search_options).await
417+
} else {
418+
do_extraction(this_clone.clone(), visible_req, tc, &web_search_options).await
419+
};
420+
421+
// The fresh request becomes both the user-visible context and
422+
// the next `current` we will dispatch.
423+
visible_req = next_visible.clone();
424+
visible_req.response = user_sender.clone();
425+
current = visible_req.clone();
426+
forward_to_user = true;
387427
}
388-
});
389-
get_mut_arcmutex!(this.handles).push(handle);
390-
} else {
391-
let handle = tokio::spawn(async move {
392-
let (new_sender, mut first_receiver) = tokio::sync::mpsc::channel(1);
393-
second_request.response = new_sender;
394-
std::mem::swap(&mut first_request.response, &mut second_request.response);
395-
396-
this_clone.add_request(first_request).await;
397-
let ResponseOk::Chunk(done) = first_receiver.recv().await.unwrap().as_result().unwrap()
428+
// ------------------------- STREAMING -------------------------
398429
else {
399-
unreachable!()
400-
};
401-
second_request
402-
.response
403-
.send(Response::Chunk(done.clone()))
404-
.await
405-
.unwrap();
430+
// We need the *last* chunk to see whether a tool was called.
431+
let mut last_choice = None;
432+
433+
while let Some(resp) = receiver.recv().await {
434+
match resp.as_result().unwrap() {
435+
ResponseOk::Chunk(chunk) => {
436+
// Forward every content‑bearing chunk immediately, but
437+
// *suppress* the ones that initiate a tool call. This ensures
438+
// the user sees the assistant’s streamed text from the very
439+
// first probe turn while still hiding the internal
440+
// search/extract trigger.
441+
let first_choice = &chunk.choices[0];
442+
if first_choice.delta.tool_calls.is_none() {
443+
user_sender
444+
.send(Response::Chunk(chunk.clone()))
445+
.await
446+
.unwrap();
447+
}
448+
449+
last_choice = Some(first_choice.clone());
450+
451+
// Stop once the model marks completion.
452+
if last_choice
453+
.as_ref()
454+
.and_then(|c| c.finish_reason.as_ref())
455+
.is_some()
456+
{
457+
break;
458+
}
459+
}
460+
_ => unreachable!(),
461+
}
462+
}
406463

407-
let mut choice = done.choices[0].clone();
464+
let Some(choice) = last_choice else { break };
408465

409-
while choice.finish_reason.is_none() {
410-
let ResponseOk::Chunk(done) =
411-
first_receiver.recv().await.unwrap().as_result().unwrap()
412-
else {
413-
unreachable!()
466+
let tc_opt = match &choice.delta.tool_calls {
467+
Some(calls)
468+
if calls.len() == 1
469+
&& search::search_tool_called(&calls[0].function.name) =>
470+
{
471+
Some(&calls[0])
472+
}
473+
_ => None,
414474
};
415-
second_request
416-
.response
417-
.send(Response::Chunk(done.clone()))
418-
.await
419-
.unwrap();
420-
421-
choice = done.choices[0].clone();
422-
}
423475

424-
let tool_calls = match &choice.delta.tool_calls {
425-
Some(tool_calls)
426-
if tool_calls.len() == 1
427-
&& search_tool_called(&tool_calls[0].function.name) =>
428-
{
429-
&tool_calls[0]
476+
if tc_opt.is_none() {
477+
break; // No more tool calls → done.
430478
}
431-
None => {
432-
return;
433-
}
434-
Some(_) => {
435-
return;
436-
}
437-
};
438479

439-
if tool_calls.function.name == SEARCH_TOOL_NAME {
440-
do_search(this_clone, second_request, tool_calls, &web_search_options).await;
441-
} else if tool_calls.function.name == EXTRACT_TOOL_NAME {
442-
do_extraction(this_clone, second_request, tool_calls, &web_search_options).await;
443-
} else {
444-
unreachable!()
480+
let tc = tc_opt.unwrap();
481+
let next_visible = if tc.function.name == search::SEARCH_TOOL_NAME {
482+
do_search(this_clone.clone(), visible_req, tc, &web_search_options).await
483+
} else {
484+
do_extraction(this_clone.clone(), visible_req, tc, &web_search_options).await
485+
};
486+
487+
visible_req = next_visible.clone();
488+
visible_req.response = user_sender.clone();
489+
current = visible_req.clone();
490+
forward_to_user = true;
445491
}
446-
});
447-
get_mut_arcmutex!(this.handles).push(handle);
448-
}
492+
}
493+
});
494+
495+
get_mut_arcmutex!(this.handles).push(handle);
449496
}

0 commit comments

Comments
 (0)