Skip to content

Commit 5e307b3

Browse files
blaginin and Dmitrii Blaginin authored
Handle dicts for distinct count (#15871)
* Handle dicts for distinct count
* Fix sqllogictests
* Add bench
* Fix no fix the bench
* Do not panic if error type is bad
* Add full bench query
* Set the bench
* Add dict of dict test
* Fix tests
* Rename method
* Increase the grouping test
* Increase the grouping test a bit more :)
* Fix flakiness

---------

Co-authored-by: Dmitrii Blaginin <[email protected]>
1 parent ab8cd8c commit 5e307b3

File tree

5 files changed

+285
-118
lines changed

5 files changed

+285
-118
lines changed

datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
// under the License.
1717

1818
mod bytes;
19+
mod dict;
1920
mod native;
2021

2122
pub use bytes::BytesDistinctCountAccumulator;
2223
pub use bytes::BytesViewDistinctCountAccumulator;
24+
pub use dict::DictionaryCountAccumulator;
2325
pub use native::FloatDistinctCountAccumulator;
2426
pub use native::PrimitiveDistinctCountAccumulator;
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{ArrayRef, BooleanArray};
19+
use arrow::downcast_dictionary_array;
20+
use datafusion_common::{arrow_datafusion_err, ScalarValue};
21+
use datafusion_common::{internal_err, DataFusionError};
22+
use datafusion_expr_common::accumulator::Accumulator;
23+
24+
#[derive(Debug)]
25+
pub struct DictionaryCountAccumulator {
26+
inner: Box<dyn Accumulator>,
27+
}
28+
29+
impl DictionaryCountAccumulator {
30+
pub fn new(inner: Box<dyn Accumulator>) -> Self {
31+
Self { inner }
32+
}
33+
}
34+
35+
impl Accumulator for DictionaryCountAccumulator {
36+
fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
37+
let values: Vec<_> = values
38+
.iter()
39+
.map(|dict| {
40+
downcast_dictionary_array! {
41+
dict => {
42+
let buff: BooleanArray = dict.occupancy().into();
43+
arrow::compute::filter(
44+
dict.values(),
45+
&buff
46+
).map_err(|e| arrow_datafusion_err!(e))
47+
},
48+
_ => internal_err!("DictionaryCountAccumulator only supports dictionary arrays")
49+
}
50+
})
51+
.collect::<Result<Vec<_>, _>>()?;
52+
self.inner.update_batch(values.as_slice())
53+
}
54+
55+
fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
56+
self.inner.evaluate()
57+
}
58+
59+
fn size(&self) -> usize {
60+
self.inner.size()
61+
}
62+
63+
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
64+
self.inner.state()
65+
}
66+
67+
fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
68+
self.inner.merge_batch(states)
69+
}
70+
}

datafusion/functions-aggregate/benches/count.rs

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,20 @@
1717

1818
use arrow::array::{ArrayRef, BooleanArray};
1919
use arrow::datatypes::{DataType, Field, Int32Type, Schema};
20-
use arrow::util::bench_util::{create_boolean_array, create_primitive_array};
20+
use arrow::util::bench_util::{
21+
create_boolean_array, create_dict_from_values, create_primitive_array,
22+
create_string_array_with_len,
23+
};
2124
use criterion::{black_box, criterion_group, criterion_main, Criterion};
22-
use datafusion_expr::{function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator};
25+
use datafusion_expr::{
26+
function::AccumulatorArgs, Accumulator, AggregateUDFImpl, GroupsAccumulator,
27+
};
2328
use datafusion_functions_aggregate::count::Count;
2429
use datafusion_physical_expr::expressions::col;
2530
use datafusion_physical_expr_common::sort_expr::LexOrdering;
2631
use std::sync::Arc;
2732

28-
fn prepare_accumulator() -> Box<dyn GroupsAccumulator> {
33+
fn prepare_group_accumulator() -> Box<dyn GroupsAccumulator> {
2934
let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Int32, true)]));
3035
let accumulator_args = AccumulatorArgs {
3136
return_field: Field::new("f", DataType::Int64, true).into(),
@@ -44,13 +49,34 @@ fn prepare_accumulator() -> Box<dyn GroupsAccumulator> {
4449
.unwrap()
4550
}
4651

52+
fn prepare_accumulator() -> Box<dyn Accumulator> {
53+
let schema = Arc::new(Schema::new(vec![Field::new(
54+
"f",
55+
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
56+
true,
57+
)]));
58+
let accumulator_args = AccumulatorArgs {
59+
return_field: Arc::new(Field::new_list_field(DataType::Int64, true)),
60+
schema: &schema,
61+
ignore_nulls: false,
62+
ordering_req: &LexOrdering::default(),
63+
is_reversed: false,
64+
name: "COUNT(f)",
65+
is_distinct: true,
66+
exprs: &[col("f", &schema).unwrap()],
67+
};
68+
let count_fn = Count::new();
69+
70+
count_fn.accumulator(accumulator_args).unwrap()
71+
}
72+
4773
fn convert_to_state_bench(
4874
c: &mut Criterion,
4975
name: &str,
5076
values: ArrayRef,
5177
opt_filter: Option<&BooleanArray>,
5278
) {
53-
let accumulator = prepare_accumulator();
79+
let accumulator = prepare_group_accumulator();
5480
c.bench_function(name, |b| {
5581
b.iter(|| {
5682
black_box(
@@ -89,6 +115,18 @@ fn count_benchmark(c: &mut Criterion) {
89115
values,
90116
Some(&filter),
91117
);
118+
119+
let arr = create_string_array_with_len::<i32>(20, 0.0, 50);
120+
let values =
121+
Arc::new(create_dict_from_values::<Int32Type>(200_000, 0.8, &arr)) as ArrayRef;
122+
123+
let mut accumulator = prepare_accumulator();
124+
c.bench_function("count low cardinality dict 20% nulls, no filter", |b| {
125+
b.iter(|| {
126+
#[allow(clippy::unit_arg)]
127+
black_box(accumulator.update_batch(&[values.clone()]).unwrap())
128+
})
129+
});
92130
}
93131

94132
criterion_group!(benches, count_benchmark);

0 commit comments

Comments
 (0)