Skip to content

Commit e481a15

Browse files
committed
Add substrait roundtrip option in sqllogictests
1 parent 21248fb commit e481a15

File tree

7 files changed

+233
-6
lines changed

7 files changed

+233
-6
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/sqllogictest/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ chrono = { workspace = true, optional = true }
4545
clap = { version = "4.5.39", features = ["derive", "env"] }
4646
datafusion = { workspace = true, default-features = true, features = ["avro"] }
4747
datafusion-spark = { workspace = true, default-features = true }
48+
datafusion-substrait = { path = "../substrait" }
4849
futures = { workspace = true }
4950
half = { workspace = true, default-features = true }
5051
indicatif = "0.17"

datafusion/sqllogictest/bin/sqllogictests.rs

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ use datafusion::common::utils::get_available_parallelism;
2121
use datafusion::common::{exec_err, DataFusionError, Result};
2222
use datafusion_sqllogictest::{
2323
df_value_validator, read_dir_recursive, setup_scratch_dir, should_skip_file,
24-
should_skip_record, value_normalizer, DataFusion, Filter, TestContext,
24+
should_skip_record, value_normalizer, DataFusion, DataFusionSubstraitRoundTrip,
25+
Filter, TestContext,
2526
};
2627
use futures::stream::StreamExt;
2728
use indicatif::{
@@ -102,6 +103,11 @@ async fn run_tests() -> Result<()> {
102103
// to stdout and return OK so they can continue listing other tests.
103104
return Ok(());
104105
}
106+
if options.substrait_round_trip && (options.postgres_runner || options.complete) {
107+
let msg = "--substrait-round-trip option is not supported with --postgres-runner or --complete";
108+
return Err(DataFusionError::External(msg.into()));
109+
}
110+
105111
options.warn_on_ignored();
106112

107113
#[cfg(feature = "postgres")]
@@ -138,8 +144,22 @@ async fn run_tests() -> Result<()> {
138144
let filters = options.filters.clone();
139145

140146
SpawnedTask::spawn(async move {
141-
match (options.postgres_runner, options.complete) {
142-
(false, false) => {
147+
match (
148+
options.postgres_runner,
149+
options.complete,
150+
options.substrait_round_trip,
151+
) {
152+
(_, _, true) => {
153+
run_test_file_substrait_round_trip(
154+
test_file,
155+
validator,
156+
m_clone,
157+
m_style_clone,
158+
filters.as_ref(),
159+
)
160+
.await?
161+
}
162+
(false, false, _) => {
143163
run_test_file(
144164
test_file,
145165
validator,
@@ -149,11 +169,11 @@ async fn run_tests() -> Result<()> {
149169
)
150170
.await?
151171
}
152-
(false, true) => {
172+
(false, true, _) => {
153173
run_complete_file(test_file, validator, m_clone, m_style_clone)
154174
.await?
155175
}
156-
(true, false) => {
176+
(true, false, _) => {
157177
run_test_file_with_postgres(
158178
test_file,
159179
validator,
@@ -163,7 +183,7 @@ async fn run_tests() -> Result<()> {
163183
)
164184
.await?
165185
}
166-
(true, true) => {
186+
(true, true, _) => {
167187
run_complete_file_with_postgres(
168188
test_file,
169189
validator,
@@ -210,6 +230,45 @@ async fn run_tests() -> Result<()> {
210230
}
211231
}
212232

233+
async fn run_test_file_substrait_round_trip(
234+
test_file: TestFile,
235+
validator: Validator,
236+
mp: MultiProgress,
237+
mp_style: ProgressStyle,
238+
filters: &[Filter],
239+
) -> Result<()> {
240+
let TestFile {
241+
path,
242+
relative_path,
243+
} = test_file;
244+
let Some(test_ctx) = TestContext::try_new_for_test_file(&relative_path).await else {
245+
info!("Skipping: {}", path.display());
246+
return Ok(());
247+
};
248+
setup_scratch_dir(&relative_path)?;
249+
250+
let count: u64 = get_record_count(&path, "DatafusionSubstraitRoundTrip".to_string());
251+
let pb = mp.add(ProgressBar::new(count));
252+
253+
pb.set_style(mp_style);
254+
pb.set_message(format!("{:?}", &relative_path));
255+
256+
let mut runner = sqllogictest::Runner::new(|| async {
257+
Ok(DataFusionSubstraitRoundTrip::new(
258+
test_ctx.session_ctx().clone(),
259+
relative_path.clone(),
260+
pb.clone(),
261+
))
262+
});
263+
runner.add_label("DatafusionSubstraitRoundTrip");
264+
runner.with_column_validator(strict_column_validator);
265+
runner.with_normalizer(value_normalizer);
266+
runner.with_validator(validator);
267+
let res = run_file_in_runner(path, runner, filters).await;
268+
pb.finish_and_clear();
269+
res
270+
}
271+
213272
async fn run_test_file(
214273
test_file: TestFile,
215274
validator: Validator,
@@ -578,6 +637,12 @@ struct Options {
578637
)]
579638
postgres_runner: bool,
580639

640+
#[clap(
641+
long,
642+
help = "Before executing each query, convert its logical plan to Substrait and from Substrait back to its logical plan"
643+
)]
644+
substrait_round_trip: bool,
645+
581646
#[clap(long, env = "INCLUDE_SQLITE", help = "Include sqlite files")]
582647
include_sqlite: bool,
583648

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
mod runner;
2+
3+
pub use runner::*;
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
use std::{path::PathBuf, time::Duration};
20+
21+
use crate::engines::datafusion_engine::Result;
22+
use crate::engines::output::{DFColumnType, DFOutput};
23+
use crate::{convert_batches, convert_schema_to_types, DFSqlLogicTestError};
24+
use arrow::record_batch::RecordBatch;
25+
use async_trait::async_trait;
26+
use datafusion::logical_expr::LogicalPlan;
27+
use datafusion::physical_plan::common::collect;
28+
use datafusion::physical_plan::execute_stream;
29+
use datafusion::prelude::SessionContext;
30+
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
31+
use datafusion_substrait::logical_plan::producer::to_substrait_plan;
32+
use indicatif::ProgressBar;
33+
use log::Level::{Debug, Info};
34+
use log::{debug, log_enabled, warn};
35+
use sqllogictest::DBOutput;
36+
use tokio::time::Instant;
37+
38+
pub struct DataFusionSubstraitRoundTrip {
39+
ctx: SessionContext,
40+
relative_path: PathBuf,
41+
pb: ProgressBar,
42+
}
43+
44+
impl DataFusionSubstraitRoundTrip {
45+
pub fn new(ctx: SessionContext, relative_path: PathBuf, pb: ProgressBar) -> Self {
46+
Self {
47+
ctx,
48+
relative_path,
49+
pb,
50+
}
51+
}
52+
53+
fn update_slow_count(&self) {
54+
let msg = self.pb.message();
55+
let split: Vec<&str> = msg.split(" ").collect();
56+
let mut current_count = 0;
57+
58+
if split.len() > 2 {
59+
// third match will be current slow count
60+
current_count = split[2].parse::<i32>().unwrap();
61+
}
62+
63+
current_count += 1;
64+
65+
self.pb
66+
.set_message(format!("{} - {} took > 500 ms", split[0], current_count));
67+
}
68+
}
69+
70+
#[async_trait]
71+
impl sqllogictest::AsyncDB for DataFusionSubstraitRoundTrip {
72+
type Error = DFSqlLogicTestError;
73+
type ColumnType = DFColumnType;
74+
75+
async fn run(&mut self, sql: &str) -> Result<DFOutput> {
76+
if log_enabled!(Debug) {
77+
debug!(
78+
"[{}] Running query: \"{}\"",
79+
self.relative_path.display(),
80+
sql
81+
);
82+
}
83+
84+
let start = Instant::now();
85+
let result = run_query_substrait_round_trip(&self.ctx, sql).await;
86+
let duration = start.elapsed();
87+
88+
if duration.gt(&Duration::from_millis(500)) {
89+
self.update_slow_count();
90+
}
91+
92+
self.pb.inc(1);
93+
94+
if log_enabled!(Info) && duration.gt(&Duration::from_secs(2)) {
95+
warn!(
96+
"[{}] Running query took more than 2 sec ({duration:?}): \"{sql}\"",
97+
self.relative_path.display()
98+
);
99+
}
100+
101+
result
102+
}
103+
104+
/// Engine name of current database.
105+
fn engine_name(&self) -> &str {
106+
"DataFusionSubstraitRoundTrip"
107+
}
108+
109+
/// [`DataFusion`] calls this function to perform sleep.
110+
///
111+
/// The default implementation is `std::thread::sleep`, which is universal to any async runtime
112+
/// but would block the current thread. If you are running in tokio runtime, you should override
113+
/// this by `tokio::time::sleep`.
114+
async fn sleep(dur: Duration) {
115+
tokio::time::sleep(dur).await;
116+
}
117+
118+
async fn shutdown(&mut self) {}
119+
}
120+
121+
async fn run_query_substrait_round_trip(
122+
ctx: &SessionContext,
123+
sql: impl Into<String>,
124+
) -> Result<DFOutput> {
125+
let df = ctx.sql(sql.into().as_str()).await?;
126+
let task_ctx = Arc::new(df.task_ctx());
127+
128+
let state = ctx.state();
129+
let round_tripped_plan = match df.logical_plan() {
130+
// Substrait does not handle these plans
131+
LogicalPlan::Ddl(_)
132+
| LogicalPlan::Explain(_)
133+
| LogicalPlan::Dml(_)
134+
| LogicalPlan::Copy(_)
135+
| LogicalPlan::Statement(_) => df.logical_plan().clone(),
136+
// For any other plan, convert to Substrait
137+
logical_plan => {
138+
let plan = to_substrait_plan(logical_plan, &state)?;
139+
from_substrait_plan(&state, &plan).await?
140+
}
141+
};
142+
143+
let physical_plan = state.create_physical_plan(&round_tripped_plan).await?;
144+
let stream = execute_stream(physical_plan, task_ctx)?;
145+
let types = convert_schema_to_types(stream.schema().fields());
146+
let results: Vec<RecordBatch> = collect(stream).await?;
147+
let rows = convert_batches(results, false)?;
148+
149+
if rows.is_empty() && types.is_empty() {
150+
Ok(DBOutput::StatementComplete(0))
151+
} else {
152+
Ok(DBOutput::Rows { types, rows })
153+
}
154+
}

datafusion/sqllogictest/src/engines/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
/// Implementation of sqllogictest for datafusion.
1919
mod conversion;
2020
mod datafusion_engine;
21+
mod datafusion_substrait_roundtrip_engine;
2122
mod output;
2223

2324
pub use datafusion_engine::convert_batches;
2425
pub use datafusion_engine::convert_schema_to_types;
2526
pub use datafusion_engine::DFSqlLogicTestError;
2627
pub use datafusion_engine::DataFusion;
28+
pub use datafusion_substrait_roundtrip_engine::DataFusionSubstraitRoundTrip;
2729
pub use output::DFColumnType;
2830
pub use output::DFOutput;
2931

datafusion/sqllogictest/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ pub use engines::DFColumnType;
3434
pub use engines::DFOutput;
3535
pub use engines::DFSqlLogicTestError;
3636
pub use engines::DataFusion;
37+
pub use engines::DataFusionSubstraitRoundTrip;
3738

3839
#[cfg(feature = "postgres")]
3940
pub use engines::Postgres;

0 commit comments

Comments
 (0)