Skip to content

Commit 7c0cf6b

Browse files
committed
decode dictionary statistics_batch before evaluating
1 parent 10d78d4 commit 7c0cf6b

File tree

1 file changed

+48
-12
lines changed

1 file changed

+48
-12
lines changed

datafusion/physical-optimizer/src/pruning.rs

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,15 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
//! [`PruningPredicate`] to apply filter [`Expr`] to prune "containers"
19-
//! based on statistics (e.g. Parquet Row Groups)
20-
//!
21-
//! [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html
22-
use std::collections::HashSet;
23-
use std::sync::Arc;
24-
25-
use arrow::array::AsArray;
18+
use arrow::compute::{cast, CastOptions};
2619
use arrow::{
27-
array::{new_null_array, ArrayRef, BooleanArray},
20+
array::{new_null_array, ArrayRef, AsArray, BooleanArray},
2821
datatypes::{DataType, Field, Schema, SchemaRef},
2922
record_batch::{RecordBatch, RecordBatchOptions},
3023
};
3124
use log::trace;
25+
use std::collections::HashSet;
26+
use std::sync::Arc;
3227

3328
use datafusion_common::error::{DataFusionError, Result};
3429
use datafusion_common::tree_node::TransformedResult;
@@ -505,6 +500,21 @@ impl UnhandledPredicateHook for ConstantUnhandledPredicateHook {
505500
}
506501
}
507502

503+
fn decode_dictionary_to_decimal(
504+
array: &ArrayRef,
505+
precision: u8,
506+
scale: u8,
507+
) -> arrow::error::Result<ArrayRef> {
508+
// e.g. Decimal128(4, 1), or whatever your stats require
509+
let target_type = DataType::Decimal128(
510+
(precision as usize).try_into().unwrap(),
511+
(scale as usize).try_into().unwrap(),
512+
);
513+
// The CastOptions can specify whether to allow loss of precision, etc.
514+
let casted = cast(array.as_ref(), &target_type)?;
515+
Ok(casted)
516+
}
517+
508518
impl PruningPredicate {
509519
/// Try to create a new instance of [`PruningPredicate`]
510520
///
@@ -622,10 +632,36 @@ impl PruningPredicate {
622632
// appropriate statistics columns for the min/max predicate
623633
let statistics_batch =
624634
build_statistics_record_batch(statistics, &self.required_columns)?;
625-
println!("==> Statistics batch columns: {:#?}", statistics_batch);
626635

636+
println!("==> Statistics batch columns: {:#?}", statistics_batch);
637+
// Construct a new, decoded record batch if you detect dictionary-of-decimal columns
638+
let decoded_columns = statistics_batch
639+
.columns()
640+
.iter()
641+
.zip(statistics_batch.schema().fields())
642+
.map(|(arr, field)| {
643+
if let DataType::Dictionary(_, inner_ty) = field.data_type() {
644+
// if it's decimal
645+
if let DataType::Decimal128(precision, scale) = &**inner_ty {
646+
return decode_dictionary_to_decimal(
647+
arr,
648+
*precision as u8,
649+
*scale as u8,
650+
);
651+
}
652+
}
653+
// fallback: no decode
654+
Ok(Arc::clone(arr))
655+
})
656+
.collect::<Result<Vec<ArrayRef>, arrow::error::ArrowError>>()?;
657+
658+
// Build a new RecordBatch with these columns
659+
let decoded_stats_batch = RecordBatch::try_new(
660+
Arc::clone(&statistics_batch.schema()),
661+
decoded_columns,
662+
)?;
627663
// Evaluate the pruning predicate on that record batch and append any results to the builder
628-
let eval_result = self.predicate_expr.evaluate(&statistics_batch)?;
664+
let eval_result = self.predicate_expr.evaluate(&decoded_stats_batch)?;
629665
println!(
630666
"==> Evaluating expression: {:?} => {:?}",
631667
self.predicate_expr, eval_result
@@ -981,7 +1017,7 @@ fn build_statistics_record_batch<S: PruningStatistics>(
9811017

9821018
// cast statistics array to required data type (e.g. parquet
9831019
// provides timestamp statistics as "Int64")
984-
let array = arrow::compute::cast(&array, data_type)?;
1020+
let array = cast(&array, data_type)?;
9851021

9861022
fields.push(stat_field.clone());
9871023
arrays.push(array);

0 commit comments

Comments
 (0)