Skip to content

Commit 51335ed

Browse files
jcristcpcloud
andauthored
fix(sql): standardize NULL handling of argmin/argmax (#10227)
Co-authored-by: Phillip Cloud <[email protected]>
1 parent 428d1a3 commit 51335ed

20 files changed

+89
-49
lines changed

ibis/backends/polars/compiler.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,20 +1256,15 @@ def execute_hash(op, **kw):
12561256

12571257

12581258
def _arg_min_max(op, func, **kw):
1259-
key = op.key
1260-
arg = op.arg
1261-
1262-
if (op_where := op.where) is not None:
1263-
key = ops.IfElse(op_where, key, None)
1264-
arg = ops.IfElse(op_where, arg, None)
1259+
key = translate(op.key, **kw)
1260+
arg = translate(op.arg, **kw)
12651261

1266-
translate_arg = translate(arg, **kw)
1267-
translate_key = translate(key, **kw)
1262+
if op.where is not None:
1263+
where = translate(op.where, **kw)
1264+
arg = arg.filter(where)
1265+
key = key.filter(where)
12681266

1269-
not_null_mask = translate_arg.is_not_null() & translate_key.is_not_null()
1270-
return translate_arg.filter(not_null_mask).get(
1271-
func(translate_key.filter(not_null_mask))
1272-
)
1267+
return arg.get(func(key))
12731268

12741269

12751270
@translate.register(ops.ArgMax)

ibis/backends/sql/compilers/base.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,6 @@ class SQLGlotCompiler(abc.ABC):
306306
ops.All: "bool_and",
307307
ops.Any: "bool_or",
308308
ops.ApproxCountDistinct: "approx_distinct",
309-
ops.ArgMax: "max_by",
310-
ops.ArgMin: "min_by",
311309
ops.ArrayContains: "array_contains",
312310
ops.ArrayFlatten: "flatten",
313311
ops.ArrayLength: "array_size",

ibis/backends/sql/compilers/bigquery/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ class BigQueryCompiler(SQLGlotCompiler):
200200
ops.TimeFromHMS: "time_from_parts",
201201
ops.TimestampNow: "current_timestamp",
202202
ops.ExtractHost: "net.host",
203+
ops.ArgMin: "min_by",
204+
ops.ArgMax: "max_by",
203205
}
204206

205207
def to_sqlglot(

ibis/backends/sql/compilers/clickhouse.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,6 @@ class ClickHouseCompiler(SQLGlotCompiler):
6262
ops.ApproxCountDistinct: "uniqHLL12",
6363
ops.ApproxMedian: "median",
6464
ops.Arbitrary: "any",
65-
ops.ArgMax: "argMax",
66-
ops.ArgMin: "argMin",
6765
ops.ArrayContains: "has",
6866
ops.ArrayFlatten: "arrayFlatten",
6967
ops.ArrayIntersect: "arrayIntersect",
@@ -673,6 +671,18 @@ def visit_Last(self, op, *, arg, where, order_by, include_null):
673671
)
674672
return self.agg.anyLast(arg, where=where, order_by=order_by)
675673

674+
def visit_ArgMin(self, op, *, arg, key, where):
675+
return sge.Dot(
676+
this=self.agg.argMin(self.f.tuple(arg), key, where=where),
677+
expression=sge.convert(1),
678+
)
679+
680+
def visit_ArgMax(self, op, *, arg, key, where):
681+
return sge.Dot(
682+
this=self.agg.argMax(self.f.tuple(arg), key, where=where),
683+
expression=sge.convert(1),
684+
)
685+
676686
def visit_CountDistinctStar(
677687
self, op: ops.CountDistinctStar, *, where, **_: Any
678688
) -> str:

ibis/backends/sql/compilers/datafusion.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ class DataFusionCompiler(SQLGlotCompiler):
3030
post_rewrites = (split_select_distinct_with_order_by,)
3131

3232
UNSUPPORTED_OPS = (
33-
ops.ArgMax,
34-
ops.ArgMin,
3533
ops.ArrayDistinct,
3634
ops.ArrayFilter,
3735
ops.ArrayMap,
@@ -457,6 +455,14 @@ def visit_Last(self, op, *, arg, where, order_by, include_null):
457455
where = cond if where is None else sge.And(this=cond, expression=where)
458456
return self.agg.last_value(arg, where=where, order_by=order_by)
459457

458+
def visit_ArgMin(self, op, *, arg, key, where):
459+
return self.agg.first_value(arg, where=where, order_by=[sge.Ordered(this=key)])
460+
461+
def visit_ArgMax(self, op, *, arg, key, where):
462+
return self.agg.first_value(
463+
arg, where=where, order_by=[sge.Ordered(this=key, desc=True)]
464+
)
465+
460466
def visit_Aggregate(self, op, *, parent, groups, metrics):
461467
"""Support `GROUP BY` expressions in `SELECT` since DataFusion does not."""
462468
quoted = self.quoted

ibis/backends/sql/compilers/druid.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ class DruidCompiler(SQLGlotCompiler):
2525

2626
UNSUPPORTED_OPS = (
2727
ops.ApproxMedian,
28-
ops.ArgMax,
29-
ops.ArgMin,
3028
ops.ArrayDistinct,
3129
ops.ArrayFilter,
3230
ops.ArrayFlatten,

ibis/backends/sql/compilers/duckdb.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,14 @@ def visit_Last(self, op, *, arg, where, order_by, include_null):
543543
where = cond if where is None else sge.And(this=cond, expression=where)
544544
return self.agg.last(arg, where=where, order_by=order_by)
545545

546+
def visit_ArgMin(self, op, *, arg, key, where):
547+
return self.agg.first(arg, where=where, order_by=[sge.Ordered(this=key)])
548+
549+
def visit_ArgMax(self, op, *, arg, key, where):
550+
return self.agg.first(
551+
arg, where=where, order_by=[sge.Ordered(this=key, desc=True)]
552+
)
553+
546554
def visit_Quantile(self, op, *, arg, quantile, where):
547555
suffix = "cont" if op.arg.dtype.is_numeric() else "disc"
548556
funcname = f"percentile_{suffix}"

ibis/backends/sql/compilers/exasol.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ class ExasolCompiler(SQLGlotCompiler):
3232

3333
UNSUPPORTED_OPS = (
3434
ops.AnalyticVectorizedUDF,
35-
ops.ArgMax,
36-
ops.ArgMin,
3735
ops.ArrayDistinct,
3836
ops.ArrayFilter,
3937
ops.ArrayFlatten,

ibis/backends/sql/compilers/flink.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,6 @@ class FlinkCompiler(SQLGlotCompiler):
6969
UNSUPPORTED_OPS = (
7070
ops.AnalyticVectorizedUDF,
7171
ops.ApproxMedian,
72-
ops.ArgMax,
73-
ops.ArgMin,
7472
ops.ArrayFlatten,
7573
ops.ArrayStringJoin,
7674
ops.Correlation,

ibis/backends/sql/compilers/impala.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ class ImpalaCompiler(SQLGlotCompiler):
3030
}
3131

3232
UNSUPPORTED_OPS = (
33-
ops.ArgMax,
34-
ops.ArgMin,
3533
ops.ArrayPosition,
3634
ops.Array,
3735
ops.Covariance,

ibis/backends/sql/compilers/mssql.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,6 @@ class MSSQLCompiler(SQLGlotCompiler):
8282

8383
UNSUPPORTED_OPS = (
8484
ops.ApproxMedian,
85-
ops.ArgMax,
86-
ops.ArgMin,
8785
ops.Array,
8886
ops.ArrayDistinct,
8987
ops.ArrayFlatten,

ibis/backends/sql/compilers/mysql.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,6 @@ def POS_INF(self):
6565
NEG_INF = POS_INF
6666
UNSUPPORTED_OPS = (
6767
ops.ApproxMedian,
68-
ops.ArgMax,
69-
ops.ArgMin,
7068
ops.Array,
7169
ops.ArrayFlatten,
7270
ops.ArrayMap,

ibis/backends/sql/compilers/oracle.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@ class OracleCompiler(SQLGlotCompiler):
5151
}
5252

5353
UNSUPPORTED_OPS = (
54-
ops.ArgMax,
55-
ops.ArgMin,
5654
ops.Array,
5755
ops.ArrayFlatten,
5856
ops.ArrayMap,

ibis/backends/sql/compilers/postgres.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -192,23 +192,21 @@ def visit_Mode(self, op, *, arg, where):
192192
expr = sge.Filter(this=expr, expression=sge.Where(this=where))
193193
return expr
194194

195-
def visit_ArgMinMax(self, op, *, arg, key, where, desc: bool):
196-
conditions = [arg.is_(sg.not_(NULL)), key.is_(sg.not_(NULL))]
197-
198-
if where is not None:
199-
conditions.append(where)
195+
def _argminmax(self, op, *, arg, key, where, desc: bool):
196+
cond = key.is_(sg.not_(NULL))
197+
where = cond if where is None else sge.And(this=cond, expression=where)
200198

201199
agg = self.agg.array_agg(
202200
sge.Ordered(this=sge.Order(this=arg, expressions=[key]), desc=desc),
203-
where=sg.and_(*conditions),
201+
where=where,
204202
)
205203
return sge.paren(agg, copy=False)[0]
206204

207205
def visit_ArgMin(self, op, *, arg, key, where):
208-
return self.visit_ArgMinMax(op, arg=arg, key=key, where=where, desc=False)
206+
return self._argminmax(op, arg=arg, key=key, where=where, desc=False)
209207

210208
def visit_ArgMax(self, op, *, arg, key, where):
211-
return self.visit_ArgMinMax(op, arg=arg, key=key, where=where, desc=True)
209+
return self._argminmax(op, arg=arg, key=key, where=where, desc=True)
212210

213211
def visit_Sum(self, op, *, arg, where):
214212
arg = (

ibis/backends/sql/compilers/pyspark.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ class PySparkCompiler(SQLGlotCompiler):
7070
}
7171

7272
SIMPLE_OPS = {
73+
ops.ArgMax: "max_by",
74+
ops.ArgMin: "min_by",
7375
ops.ArrayDistinct: "array_distinct",
7476
ops.ArrayFlatten: "flatten",
7577
ops.ArrayIntersect: "array_intersect",

ibis/backends/sql/compilers/snowflake.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ class SnowflakeCompiler(SQLGlotCompiler):
106106
SIMPLE_OPS = {
107107
ops.All: "min",
108108
ops.Any: "max",
109+
ops.ArgMax: "max_by",
110+
ops.ArgMin: "min_by",
109111
ops.ArrayDistinct: "array_distinct",
110112
ops.ArrayFlatten: "array_flatten",
111113
ops.ArrayIndex: "get",

ibis/backends/sql/compilers/sqlite.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -206,12 +206,7 @@ def visit_ArgMax(self, *args, **kwargs):
206206
return self._visit_arg_reduction("max", *args, **kwargs)
207207

208208
def _visit_arg_reduction(self, func, op, *, arg, key, where):
209-
cond = arg.is_(sg.not_(NULL))
210-
211-
if op.where is not None:
212-
cond = sg.and_(cond, where)
213-
214-
agg = self.agg[func](key, where=cond)
209+
agg = self.agg[func](key, where=where)
215210
return self.f.anon.json_extract(self.f.json_array(arg, agg), "$[0]")
216211

217212
def visit_UnwrapJSONString(self, op, *, arg):

ibis/backends/sql/compilers/trino.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ class TrinoCompiler(SQLGlotCompiler):
6060

6161
SIMPLE_OPS = {
6262
ops.Arbitrary: "any_value",
63+
ops.ArgMax: "max_by",
64+
ops.ArgMin: "min_by",
6365
ops.Pi: "pi",
6466
ops.E: "e",
6567
ops.RegexReplace: "regexp_replace",

ibis/backends/tests/test_aggregation.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ def mean_udf(s):
123123
]
124124

125125
argidx_not_grouped_marks = [
126-
"datafusion",
127126
"impala",
128127
"mysql",
129128
"mssql",
@@ -411,7 +410,6 @@ def mean_and_std(v):
411410
[
412411
"impala",
413412
"mysql",
414-
"datafusion",
415413
"mssql",
416414
"druid",
417415
"oracle",
@@ -431,7 +429,6 @@ def mean_and_std(v):
431429
[
432430
"impala",
433431
"mysql",
434-
"datafusion",
435432
"mssql",
436433
"druid",
437434
"oracle",
@@ -689,6 +686,39 @@ def test_first_last_ordered(alltypes, method, filtered, include_null):
689686
assert res == sol
690687

691688

689+
@pytest.mark.notimpl(
690+
[
691+
"druid",
692+
"exasol",
693+
"flink",
694+
"impala",
695+
"mssql",
696+
"mysql",
697+
"oracle",
698+
],
699+
raises=com.OperationNotDefinedError,
700+
)
701+
@pytest.mark.parametrize("method", ["argmin", "argmax"])
702+
@pytest.mark.parametrize("filtered", [True, False], ids=["filtered", "unfiltered"])
703+
@pytest.mark.parametrize("null_result", [True, False], ids=["null", "non-null"])
704+
def test_argmin_argmax(alltypes, method, filtered, null_result):
705+
t = alltypes.mutate(by_col=_.int_col.nullif(0).nullif(9), val_col=10 * _.int_col)
706+
707+
if filtered:
708+
where = _.int_col != (1 if method == "argmin" else 8)
709+
sol = 20 if method == "argmin" else 70
710+
else:
711+
where = None
712+
sol = 10 if method == "argmin" else 80
713+
714+
if null_result:
715+
t = t.mutate(val_col=_.val_col.nullif(sol))
716+
717+
expr = getattr(t.val_col, method)("by_col", where=where)
718+
res = expr.execute()
719+
assert pd.isna(res) if null_result else res == sol
720+
721+
692722
@pytest.mark.notimpl(
693723
[
694724
"impala",

ibis/expr/types/generic.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1766,6 +1766,9 @@ def min(self, where: ir.BooleanValue | None = None) -> Scalar:
17661766
def argmax(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
17671767
"""Return the value of `self` that maximizes `key`.
17681768
1769+
If more than one value maximizes `key`, the returned value is backend
1770+
specific. The result may be `NULL`.
1771+
17691772
Parameters
17701773
----------
17711774
key
@@ -1801,6 +1804,9 @@ def argmax(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
18011804
def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
18021805
"""Return the value of `self` that minimizes `key`.
18031806
1807+
If more than one value minimizes `key`, the returned value is backend
1808+
specific. The result may be `NULL`.
1809+
18041810
Parameters
18051811
----------
18061812
key

0 commit comments

Comments
 (0)