Skip to content

Commit a0f0b80

Browse files
authored
Add interrupt handle (#493)
This PR adds an `InterruptHandle` that allows interrupting long-running queries from another thread. Internally, this calls the `duckdb_interrupt` function from the C API. The implementation is inspired by [rusqlite][1]. It seems to get the job done, but a couple of thoughts: 1. Testing the interrupts is a bit tricky. I wasn't able to come up with a deterministic approach, so I took inspiration from tests in other client API: simply kick off an expensive query and interrupt it. 2. Unfortunately, the error code returned on interrupt is `ErrorCode::Unknown`, so we have to resort to a match on the error message. Works for now, but I'm wondering if there is a way to do this in a cleaner way, or if the actual error code is simply not exposed from the underlying API. 3. I'm not very familiar with the safety aspects of interfacing with the C API, so a critical pair of eyes is welcome. Previous attempt at adding this feature: #343. [1]: https://docs.rs/rusqlite/latest/rusqlite/struct.InterruptHandle.html
2 parents f1e58e8 + c328df5 commit a0f0b80

File tree

2 files changed

+104
-1
lines changed

2 files changed

+104
-1
lines changed

crates/duckdb/src/inner_connection.rs

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::{
33
mem,
44
os::raw::c_char,
55
ptr, str,
6+
sync::{Arc, Mutex},
67
};
78

89
use super::{ffi, Appender, Config, Connection, Result};
@@ -15,6 +16,7 @@ use crate::{
1516
pub struct InnerConnection {
1617
pub db: ffi::duckdb_database,
1718
pub con: ffi::duckdb_connection,
19+
interrupt: Arc<InterruptHandle>,
1820
owned: bool,
1921
}
2022

@@ -30,7 +32,14 @@ impl InnerConnection {
3032
Some("connect error".to_owned()),
3133
));
3234
}
33-
Ok(Self { db, con, owned })
35+
let interrupt = Arc::new(InterruptHandle::new(con));
36+
37+
Ok(Self {
38+
db,
39+
con,
40+
interrupt,
41+
owned,
42+
})
3443
}
3544

3645
pub fn open_with_flags(c_path: &CStr, config: Config) -> Result<Self> {
@@ -57,6 +66,7 @@ impl InnerConnection {
5766
unsafe {
5867
ffi::duckdb_disconnect(&mut self.con);
5968
self.con = ptr::null_mut();
69+
self.interrupt.clear();
6070

6171
if self.owned {
6272
ffi::duckdb_close(&mut self.db);
@@ -106,6 +116,10 @@ impl InnerConnection {
106116
Ok(Appender::new(conn, c_app))
107117
}
108118

119+
pub fn get_interrupt_handle(&self) -> Arc<InterruptHandle> {
120+
self.interrupt.clone()
121+
}
122+
109123
#[inline]
110124
pub fn is_autocommit(&self) -> bool {
111125
true
@@ -126,3 +140,37 @@ impl Drop for InnerConnection {
126140
}
127141
}
128142
}
143+
144+
/// A handle that allows interrupting long-running queries.
145+
pub struct InterruptHandle {
146+
conn: Mutex<ffi::duckdb_connection>,
147+
}
148+
149+
unsafe impl Send for InterruptHandle {}
150+
unsafe impl Sync for InterruptHandle {}
151+
152+
impl InterruptHandle {
153+
fn new(conn: ffi::duckdb_connection) -> Self {
154+
Self { conn: Mutex::new(conn) }
155+
}
156+
157+
fn clear(&self) {
158+
*(self.conn.lock().unwrap()) = ptr::null_mut();
159+
}
160+
161+
/// Interrupt the query currently running on the connection this handle was
162+
/// obtained from. The interrupt will cause that query to fail with
163+
/// `Error::DuckDBFailure`. If the connection was dropped after obtaining
164+
/// this interrupt handle, calling this method results in a noop.
165+
///
166+
/// See [`crate::Connection::interrupt_handle`] for an example.
167+
pub fn interrupt(&self) {
168+
let db_handle = self.conn.lock().unwrap();
169+
170+
if !db_handle.is_null() {
171+
unsafe {
172+
ffi::duckdb_interrupt(*db_handle);
173+
}
174+
}
175+
}
176+
}

crates/duckdb/src/lib.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ pub use crate::{
7979
config::{AccessMode, Config, DefaultNullOrder, DefaultOrder},
8080
error::Error,
8181
ffi::ErrorCode,
82+
inner_connection::InterruptHandle,
8283
params::{params_from_iter, Params, ParamsFromIter},
8384
row::{AndThenRows, Map, MappedRows, Row, RowIndex, Rows},
8485
statement::Statement,
@@ -532,6 +533,30 @@ impl Connection {
532533
self.db.borrow_mut().appender(self, table, schema)
533534
}
534535

536+
/// Get a handle to interrupt long-running queries.
537+
///
538+
/// ## Example
539+
///
540+
/// ```rust,no_run
541+
/// # use duckdb::{Connection, Result};
542+
/// fn run_query(conn: Connection) -> Result<()> {
543+
/// let interrupt_handle = conn.interrupt_handle();
544+
/// let join_handle = std::thread::spawn(move || { conn.execute("expensive query", []) });
545+
///
546+
/// // Arbitrary wait for query to start
547+
/// std::thread::sleep(std::time::Duration::from_millis(100));
548+
///
549+
/// interrupt_handle.interrupt();
550+
///
551+
/// let query_result = join_handle.join().unwrap();
552+
/// assert!(query_result.is_err());
553+
///
554+
/// Ok(())
555+
/// }
556+
pub fn interrupt_handle(&self) -> std::sync::Arc<InterruptHandle> {
557+
self.db.borrow().get_interrupt_handle()
558+
}
559+
535560
/// Close the DuckDB connection.
536561
///
537562
/// This is functionally equivalent to the `Drop` implementation for
@@ -1338,6 +1363,36 @@ mod test {
13381363
Ok(())
13391364
}
13401365

1366+
#[test]
1367+
fn test_interrupt() -> Result<()> {
1368+
let db = checked_memory_handle();
1369+
let db_interrupt = db.interrupt_handle();
1370+
1371+
let (tx, rx) = std::sync::mpsc::channel();
1372+
std::thread::spawn(move || {
1373+
let mut stmt = db
1374+
.prepare("select count(*) from range(10000000) t1, range(1000000) t2")
1375+
.unwrap();
1376+
tx.send(stmt.execute([])).unwrap();
1377+
});
1378+
1379+
std::thread::sleep(std::time::Duration::from_millis(100));
1380+
db_interrupt.interrupt();
1381+
1382+
let result = rx.recv_timeout(std::time::Duration::from_secs(5)).unwrap();
1383+
assert!(result.is_err_and(|err| err.to_string().contains("INTERRUPT")));
1384+
Ok(())
1385+
}
1386+
1387+
#[test]
1388+
fn test_interrupt_on_dropped_db() {
1389+
let db = checked_memory_handle();
1390+
let db_interrupt = db.interrupt_handle();
1391+
1392+
drop(db);
1393+
db_interrupt.interrupt();
1394+
}
1395+
13411396
#[cfg(feature = "bundled")]
13421397
#[test]
13431398
fn test_version() -> Result<()> {

0 commit comments

Comments
 (0)