Skip to content

Commit 506d047

Browse files
authored
Support appending large Arrow record batches (#530)
Based on the test from #271 by @plaflamme, this PR enhances/fixes `append_record_batch` to automatically chunk the passed record batch into smaller pieces (using Arrow's zero-copy slicing) that fit within DuckDB's vector size limit.
2 parents a0f0b80 + 211bc3d commit 506d047

File tree

1 file changed

+53
-14
lines changed

1 file changed

+53
-14
lines changed

crates/duckdb/src/appender/arrow.rs

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ use crate::{
66
Error,
77
};
88
use arrow::record_batch::RecordBatch;
9-
use ffi::duckdb_append_data_chunk;
9+
use ffi::{duckdb_append_data_chunk, duckdb_vector_size};
1010

1111
impl Appender<'_> {
12-
/// Append one record_batch
12+
/// Append one record batch
1313
///
1414
/// ## Example
1515
///
@@ -28,27 +28,43 @@ impl Appender<'_> {
2828
/// Will return `Err` if the number of appended columns does not match the table schema
2929
#[inline]
3030
pub fn append_record_batch(&mut self, record_batch: RecordBatch) -> Result<()> {
31-
let schema = record_batch.schema();
32-
let mut logical_type: Vec<LogicalTypeHandle> = vec![];
33-
for field in schema.fields() {
34-
let logical_t = to_duckdb_logical_type(field.data_type())
35-
.map_err(|_op| Error::ArrowTypeToDuckdbType(field.to_string(), field.data_type().clone()))?;
36-
logical_type.push(logical_t);
37-
}
31+
let logical_types: Vec<LogicalTypeHandle> = record_batch
32+
.schema()
33+
.fields()
34+
.iter()
35+
.map(|field| {
36+
to_duckdb_logical_type(field.data_type())
37+
.map_err(|_op| Error::ArrowTypeToDuckdbType(field.to_string(), field.data_type().clone()))
38+
})
39+
.collect::<Result<Vec<_>, _>>()?;
40+
41+
let vector_size = unsafe { duckdb_vector_size() } as usize;
42+
let num_rows = record_batch.num_rows();
43+
44+
// Process record batch in chunks that fit within DuckDB's vector size
45+
let mut offset = 0;
46+
while offset < num_rows {
47+
let slice_len = std::cmp::min(vector_size, num_rows - offset);
48+
let slice = record_batch.slice(offset, slice_len);
3849

39-
let mut data_chunk = DataChunkHandle::new(&logical_type);
40-
record_batch_to_duckdb_data_chunk(&record_batch, &mut data_chunk).map_err(|_op| Error::AppendError)?;
50+
let mut data_chunk = DataChunkHandle::new(&logical_types);
51+
record_batch_to_duckdb_data_chunk(&slice, &mut data_chunk).map_err(|_op| Error::AppendError)?;
4152

42-
let rc = unsafe { duckdb_append_data_chunk(self.app, data_chunk.get_ptr()) };
43-
result_from_duckdb_appender(rc, &mut self.app)
53+
let rc = unsafe { duckdb_append_data_chunk(self.app, data_chunk.get_ptr()) };
54+
result_from_duckdb_appender(rc, &mut self.app)?;
55+
56+
offset += slice_len;
57+
}
58+
59+
Ok(())
4460
}
4561
}
4662

4763
#[cfg(test)]
4864
mod test {
4965
use crate::{Connection, Result};
5066
use arrow::{
51-
array::{Int8Array, StringArray},
67+
array::{Int32Array, Int8Array, StringArray},
5268
datatypes::{DataType, Field, Schema},
5369
record_batch::RecordBatch,
5470
};
@@ -80,4 +96,27 @@ mod test {
8096
assert_eq!(rbs.iter().map(|op| op.num_rows()).sum::<usize>(), 5);
8197
Ok(())
8298
}
99+
100+
#[test]
101+
fn test_append_record_batch_large() -> Result<()> {
102+
let record_count = usize::pow(2, 16) + 1;
103+
let db = Connection::open_in_memory()?;
104+
db.execute_batch("CREATE TABLE foo(id INT)")?;
105+
{
106+
let id_array = Int32Array::from((0..record_count as i32).collect::<Vec<_>>());
107+
let schema = Schema::new(vec![Field::new("id", DataType::Int32, true)]);
108+
let record_batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap();
109+
let mut app = db.appender("foo")?;
110+
app.append_record_batch(record_batch)?;
111+
}
112+
let count: usize = db.query_row("SELECT COUNT(*) FROM foo", [], |row| row.get(0))?;
113+
assert_eq!(count, record_count);
114+
115+
// Verify the data is correct
116+
let sum: i64 = db.query_row("SELECT SUM(id) FROM foo", [], |row| row.get(0))?;
117+
let expected_sum: i64 = (0..record_count as i64).sum();
118+
assert_eq!(sum, expected_sum);
119+
120+
Ok(())
121+
}
83122
}

0 commit comments

Comments
 (0)