Skip to content

Commit 5ca0d8a

Browse files
fix: Fix list aggregates on empty series (#4155)
## Changes Made

When performing a list agg on an empty partition, we throw `daft.exceptions.DaftCoreException: DaftError::ValueError Need at least 1 series to perform concat`. To fix this, this PR adds a check for whether the series is empty. In that case we simply return an empty series.

## Related Issues

Addresses #4153
1 parent f0b4469 commit 5ca0d8a

File tree

2 files changed

+90
-7
lines changed

2 files changed

+90
-7
lines changed

src/daft-core/src/array/ops/list.rs

+16-7
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,10 @@ use crate::{
1515
FixedSizeListArray, ListArray, StructArray,
1616
},
1717
count_mode::CountMode,
18-
datatypes::{BooleanArray, DataType, Field, Int64Array, UInt64Array, Utf8Array},
18+
datatypes::{
19+
try_mean_aggregation_supertype, BooleanArray, DataType, Field, Int64Array, UInt64Array,
20+
Utf8Array,
21+
},
1922
kernels::search_sorted::build_is_valid,
2023
prelude::MapArray,
2124
series::{IntoSeries, Series},
@@ -1035,9 +1038,10 @@ impl FixedSizeListArray {
10351038
macro_rules! impl_aggs_list_array {
10361039
($la:ident) => {
10371040
impl $la {
1038-
fn agg_helper<T>(&self, op: T) -> DaftResult<Series>
1041+
fn agg_helper<T, F>(&self, op: T, target_type_getter: F) -> DaftResult<Series>
10391042
where
10401043
T: Fn(&Series) -> DaftResult<Series>,
1044+
F: Fn(&DataType) -> DaftResult<DataType>,
10411045
{
10421046
// TODO(Kevin): Currently this requires full materialization of one Series for every list. We could avoid this by implementing either sorted aggregation or an array builder
10431047

@@ -1050,23 +1054,28 @@ macro_rules! impl_aggs_list_array {
10501054

10511055
let agg_refs: Vec<_> = aggs.iter().collect();
10521056

1053-
Series::concat(agg_refs.as_slice()).map(|s| s.rename(self.name()))
1057+
if agg_refs.is_empty() {
1058+
let target_type = target_type_getter(self.child_data_type())?;
1059+
Ok(Series::empty(self.name(), &target_type))
1060+
} else {
1061+
Series::concat(agg_refs.as_slice()).map(|s| s.rename(self.name()))
1062+
}
10541063
}
10551064

10561065
pub fn sum(&self) -> DaftResult<Series> {
1057-
self.agg_helper(|s| s.sum(None))
1066+
self.agg_helper(|s| s.sum(None), |dtype| Ok(dtype.clone()))
10581067
}
10591068

10601069
pub fn mean(&self) -> DaftResult<Series> {
1061-
self.agg_helper(|s| s.mean(None))
1070+
self.agg_helper(|s| s.mean(None), try_mean_aggregation_supertype)
10621071
}
10631072

10641073
pub fn min(&self) -> DaftResult<Series> {
1065-
self.agg_helper(|s| s.min(None))
1074+
self.agg_helper(|s| s.min(None), |dtype| Ok(dtype.clone()))
10661075
}
10671076

10681077
pub fn max(&self) -> DaftResult<Series> {
1069-
self.agg_helper(|s| s.max(None))
1078+
self.agg_helper(|s| s.max(None), |dtype| Ok(dtype.clone()))
10701079
}
10711080
}
10721081
};

tests/recordbatch/list/test_list_numeric_aggs.py

+74
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3+
import pyarrow as pa
34
import pytest
45

6+
import daft
57
from daft.datatype import DataType
68
from daft.expressions import col
79
from daft.recordbatch import MicroPartition
@@ -33,3 +35,75 @@ def test_list_min(table):
3335
def test_list_max(table):
3436
result = table.eval_expression_list([col("a").list.max()])
3537
assert result.to_pydict() == {"a": [2, 4, 5, None, None]}
38+
39+
40+
def test_list_numeric_aggs_empty_table():
41+
empty_table = MicroPartition.from_pydict(
42+
{
43+
"col": pa.array([], type=pa.list_(pa.int64())),
44+
"fixed_col": pa.array([], type=pa.list_(pa.int64(), 2)),
45+
}
46+
)
47+
48+
result = empty_table.eval_expression_list(
49+
[
50+
col("col").cast(DataType.list(DataType.int64())).list.sum().alias("col_sum"),
51+
col("col").list.mean().alias("col_mean"),
52+
col("col").list.min().alias("col_min"),
53+
col("col").list.max().alias("col_max"),
54+
col("fixed_col").list.sum().alias("fixed_col_sum"),
55+
col("fixed_col").list.mean().alias("fixed_col_mean"),
56+
col("fixed_col").list.min().alias("fixed_col_min"),
57+
col("fixed_col").list.max().alias("fixed_col_max"),
58+
]
59+
)
60+
assert result.to_pydict() == {
61+
"col_sum": [],
62+
"col_mean": [],
63+
"col_min": [],
64+
"col_max": [],
65+
"fixed_col_sum": [],
66+
"fixed_col_mean": [],
67+
"fixed_col_min": [],
68+
"fixed_col_max": [],
69+
}
70+
71+
72+
def test_list_numeric_aggs_with_groupby():
73+
df = daft.from_pydict(
74+
{
75+
"group_col": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
76+
"id_col": [3, 1, 2, 2, 5, 4, None, 3, None, None, None, None],
77+
}
78+
)
79+
80+
# Group by and test aggregates.
81+
grouped_df = df.groupby("group_col").agg(daft.col("id_col").agg_list().alias("ids_col"))
82+
result = grouped_df.select(
83+
col("group_col"),
84+
col("ids_col").list.sum().alias("ids_col_sum"),
85+
col("ids_col").list.mean().alias("ids_col_mean"),
86+
col("ids_col").list.min().alias("ids_col_min"),
87+
col("ids_col").list.max().alias("ids_col_max"),
88+
).sort("group_col", desc=False)
89+
result_dict = result.to_pydict()
90+
expected = {
91+
"group_col": [1, 2, 3],
92+
"ids_col_sum": [8, 12, None],
93+
"ids_col_mean": [2.0, 4.0, None],
94+
"ids_col_min": [1, 3, None],
95+
"ids_col_max": [3, 5, None],
96+
}
97+
assert result_dict == expected
98+
99+
# Cast to fixed size list, group by, and test aggregates.
100+
grouped_df = grouped_df.with_column("ids_col", col("ids_col").cast(DataType.fixed_size_list(DataType.int64(), 4)))
101+
result = grouped_df.select(
102+
col("group_col"),
103+
col("ids_col").list.sum().alias("ids_col_sum"),
104+
col("ids_col").list.mean().alias("ids_col_mean"),
105+
col("ids_col").list.min().alias("ids_col_min"),
106+
col("ids_col").list.max().alias("ids_col_max"),
107+
).sort("group_col", desc=False)
108+
result_dict = result.to_pydict()
109+
assert result_dict == expected

0 commit comments

Comments (0)