Skip to content

Commit 8cf1e02

Browse files
authored
fix: Maintain float32 type in partitioned group-by (#22340)
1 parent 33f04bd commit 8cf1e02

File tree

2 files changed: 18 additions and 1 deletion.

crates/polars-expr/src/expressions/aggregation.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -580,7 +580,7 @@ impl PartitionedAggregation for AggregationExpr {
580580
let mask = agg_count.equal(0 as IdxSize);
581581
let agg_count = agg_count.set(&mask, None).unwrap().into_series();
582582

583-
let agg_s = &agg_s / &agg_count;
583+
let agg_s = &agg_s / &agg_count.cast(agg_s.dtype()).unwrap();
584584
Ok(agg_s?.with_name(new_name).into_column())
585585
},
586586
_ => Ok(Column::full_null(

py-polars/tests/unit/operations/test_group_by.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1251,3 +1251,20 @@ def test_group_by_cse_dup_key_alias_22238() -> None:
12511251
pl.DataFrame({"a": [1, 2], "a_with_alias": [1, 2], "x": [11, 5]}),
12521252
check_row_order=False,
12531253
)
1254+
1255+
1256+
def test_group_by_22328() -> None:
1257+
N = 20
1258+
1259+
df1 = pl.select(
1260+
x=pl.repeat(1, N // 2).append(pl.repeat(2, N // 2)).shuffle(),
1261+
y=pl.lit(3.0, pl.Float32),
1262+
).lazy()
1263+
1264+
df2 = pl.select(x=pl.repeat(4, N)).lazy()
1265+
1266+
assert (
1267+
df2.join(df1.group_by("x").mean().with_columns(z="y"), how="left", on="x")
1268+
.with_columns(pl.col("z").fill_null(0))
1269+
.collect()
1270+
).shape == (20, 3)

0 commit comments

Comments
 (0)