Skip to content

Commit f83d84f

Browse files
committed
fix(polars): fix polars std/var to properly handle sample/population
1 parent 8717629 commit f83d84f

File tree

2 files changed

+46
-20
lines changed

2 files changed

+46
-20
lines changed

ibis/backends/polars/compiler.py

Lines changed: 40 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -721,39 +721,59 @@ def struct_column(op, **kw):
721721
ops.All: "all",
722722
ops.Any: "any",
723723
ops.ApproxMedian: "median",
724-
ops.Arbitrary: "first",
725724
ops.Count: "count",
726725
ops.CountDistinct: "n_unique",
727-
ops.First: "first",
728-
ops.Last: "last",
729726
ops.Max: "max",
730727
ops.Mean: "mean",
731728
ops.Median: "median",
732729
ops.Min: "min",
733-
ops.StandardDev: "std",
734730
ops.Sum: "sum",
735-
ops.Variance: "var",
736731
}
737732

738-
for reduction in _reductions.keys():
739733

740-
@translate.register(reduction)
741-
def reduction(op, **kw):
742-
args = [
743-
translate(arg, **kw)
744-
for name, arg in zip(op.argnames, op.args)
745-
if name not in ("where", "how")
746-
]
734+
def execute_reduction(op, **kw):
735+
arg = translate(op.arg, **kw)
736+
737+
if op.where is not None:
738+
arg = arg.filter(translate(op.where, **kw))
739+
740+
method = _reductions[type(op)]
741+
742+
return getattr(arg, method)()
743+
744+
745+
for cls in _reductions:
746+
translate.register(cls, execute_reduction)
747+
748+
749+
@translate.register(ops.First)
750+
@translate.register(ops.Last)
751+
@translate.register(ops.Arbitrary)
752+
def execute_first_last(op, **kw):
753+
arg = translate(op.arg, **kw)
754+
755+
# polars doesn't ignore nulls by default for these methods
756+
predicate = arg.is_not_null()
757+
if op.where is not None:
758+
predicate &= translate(op.where, **kw)
759+
760+
arg = arg.filter(predicate)
761+
762+
return arg.last() if isinstance(op, ops.Last) else arg.first()
747763

748-
agg = _reductions[type(op)]
749764

750-
predicates = [arg.is_not_null() for arg in args]
751-
if (where := op.where) is not None:
752-
predicates.append(translate(where, **kw))
765+
@translate.register(ops.StandardDev)
766+
@translate.register(ops.Variance)
767+
def execute_std_var(op, **kw):
768+
arg = translate(op.arg, **kw)
769+
770+
if op.where is not None:
771+
arg = arg.filter(translate(op.where, **kw))
772+
773+
method = "std" if isinstance(op, ops.StandardDev) else "var"
774+
ddof = 0 if op.how == "pop" else 1
753775

754-
first, *rest = args
755-
method = operator.methodcaller(agg, *rest)
756-
return method(first.filter(reduce(operator.and_, predicates)))
776+
return getattr(arg, method)(ddof=ddof)
757777

758778

759779
@translate.register(ops.Mode)

ibis/backends/tests/test_aggregation.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -606,6 +606,12 @@ def test_reduction_ops(
606606
ibis_cond,
607607
pandas_cond,
608608
):
609+
# Operate on a subset of the data, since aggregations like var/std with
610+
# sample/population can be too numerically similar for a larger number of
611+
# rows.
612+
alltypes = alltypes.filter(alltypes.id < 1550)
613+
df = df[df.id < 1550]
614+
609615
expr = alltypes.agg(tmp=result_fn(alltypes, ibis_cond(alltypes))).tmp
610616
result = expr.execute().squeeze()
611617
expected = expected_fn(df, pandas_cond(df))

0 commit comments

Comments
 (0)