|
15 | 15 | // specific language governing permissions and limitations
|
16 | 16 | // under the License.
|
17 | 17 |
|
18 |
| -//! [`PruningPredicate`] to apply filter [`Expr`] to prune "containers" |
19 |
| -//! based on statistics (e.g. Parquet Row Groups) |
20 |
| -//! |
21 |
| -//! [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html |
22 |
| -use std::collections::HashSet; |
23 |
| -use std::sync::Arc; |
24 |
| - |
25 |
| -use arrow::array::AsArray; |
| 18 | +use arrow::compute::{cast, CastOptions}; |
26 | 19 | use arrow::{
|
27 |
| - array::{new_null_array, ArrayRef, BooleanArray}, |
| 20 | + array::{new_null_array, ArrayRef, AsArray, BooleanArray}, |
28 | 21 | datatypes::{DataType, Field, Schema, SchemaRef},
|
29 | 22 | record_batch::{RecordBatch, RecordBatchOptions},
|
30 | 23 | };
|
31 | 24 | use log::trace;
|
| 25 | +use std::collections::HashSet; |
| 26 | +use std::sync::Arc; |
32 | 27 |
|
33 | 28 | use datafusion_common::error::{DataFusionError, Result};
|
34 | 29 | use datafusion_common::tree_node::TransformedResult;
|
@@ -505,6 +500,21 @@ impl UnhandledPredicateHook for ConstantUnhandledPredicateHook {
|
505 | 500 | }
|
506 | 501 | }
|
507 | 502 |
|
| 503 | +fn decode_dictionary_to_decimal( |
| 504 | + array: &ArrayRef, |
| 505 | + precision: u8, |
| 506 | + scale: u8, |
| 507 | +) -> arrow::error::Result<ArrayRef> { |
| 508 | + // e.g. Decimal128(4, 1), or whatever your stats require |
| 509 | + let target_type = DataType::Decimal128( |
| 510 | + (precision as usize).try_into().unwrap(), |
| 511 | + (scale as usize).try_into().unwrap(), |
| 512 | + ); |
| 513 | + // The CastOptions can specify whether to allow loss of precision, etc. |
| 514 | + let casted = cast(array.as_ref(), &target_type)?; |
| 515 | + Ok(casted) |
| 516 | +} |
| 517 | + |
508 | 518 | impl PruningPredicate {
|
509 | 519 | /// Try to create a new instance of [`PruningPredicate`]
|
510 | 520 | ///
|
@@ -622,10 +632,36 @@ impl PruningPredicate {
|
622 | 632 | // appropriate statistics columns for the min/max predicate
|
623 | 633 | let statistics_batch =
|
624 | 634 | build_statistics_record_batch(statistics, &self.required_columns)?;
|
625 |
| - println!("==> Statistics batch columns: {:#?}", statistics_batch); |
626 | 635 |
|
| 636 | + println!("==> Statistics batch columns: {:#?}", statistics_batch); |
| 637 | + // Construct a new, decoded record batch if you detect dictionary-of-decimal columns |
| 638 | + let decoded_columns = statistics_batch |
| 639 | + .columns() |
| 640 | + .iter() |
| 641 | + .zip(statistics_batch.schema().fields()) |
| 642 | + .map(|(arr, field)| { |
| 643 | + if let DataType::Dictionary(_, inner_ty) = field.data_type() { |
| 644 | + // if it's decimal |
| 645 | + if let DataType::Decimal128(precision, scale) = &**inner_ty { |
| 646 | + return decode_dictionary_to_decimal( |
| 647 | + arr, |
| 648 | + *precision as u8, |
| 649 | + *scale as u8, |
| 650 | + ); |
| 651 | + } |
| 652 | + } |
| 653 | + // fallback: no decode |
| 654 | + Ok(Arc::clone(arr)) |
| 655 | + }) |
| 656 | + .collect::<Result<Vec<ArrayRef>, arrow::error::ArrowError>>()?; |
| 657 | + |
| 658 | + // Build a new RecordBatch with these columns |
| 659 | + let decoded_stats_batch = RecordBatch::try_new( |
| 660 | + Arc::clone(&statistics_batch.schema()), |
| 661 | + decoded_columns, |
| 662 | + )?; |
627 | 663 | // Evaluate the pruning predicate on that record batch and append any results to the builder
|
628 |
| - let eval_result = self.predicate_expr.evaluate(&statistics_batch)?; |
| 664 | + let eval_result = self.predicate_expr.evaluate(&decoded_stats_batch)?; |
629 | 665 | println!(
|
630 | 666 | "==> Evaluating expression: {:?} => {:?}",
|
631 | 667 | self.predicate_expr, eval_result
|
@@ -981,7 +1017,7 @@ fn build_statistics_record_batch<S: PruningStatistics>(
|
981 | 1017 |
|
982 | 1018 | // cast statistics array to required data type (e.g. parquet
|
983 | 1019 | // provides timestamp statistics as "Int64")
|
984 |
| - let array = arrow::compute::cast(&array, data_type)?; |
| 1020 | + let array = cast(&array, data_type)?; |
985 | 1021 |
|
986 | 1022 | fields.push(stat_field.clone());
|
987 | 1023 | arrays.push(array);
|
|
0 commit comments