Skip to content

Commit 40c5f0d

Browse files
committed
feat(bigquery): implement argmin and argmax
1 parent 4df9f8b commit 40c5f0d

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

ibis/backends/bigquery/registry.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import base64
44
import datetime
5+
from typing import Literal
56

67
import numpy as np
78
from multipledispatch import Dispatcher
@@ -464,6 +465,18 @@ def _array_agg(t, op):
464465
return f"ARRAY_AGG({t.translate(arg)} IGNORE NULLS)"
465466

466467

468+
def _arg_min_max(sort_dir: Literal["ASC", "DESC"]):
469+
def translate(t, op: ops.ArgMin | ops.ArgMax) -> str:
470+
arg = op.arg
471+
if (where := op.where) is not None:
472+
arg = ops.Where(where, arg, None)
473+
arg = t.translate(arg)
474+
key = t.translate(op.key)
475+
return f"ARRAY_AGG({arg} IGNORE NULLS ORDER BY {key} {sort_dir} LIMIT 1)[SAFE_OFFSET(0)]"
476+
477+
return translate
478+
479+
467480
OPERATION_REGISTRY = {
468481
**operation_registry,
469482
# Literal
@@ -587,6 +600,8 @@ def _array_agg(t, op):
587600
ops.FloorDivide: _floor_divide,
588601
ops.IsNan: _is_nan,
589602
ops.IsInf: _is_inf,
603+
ops.ArgMin: _arg_min_max("ASC"),
604+
ops.ArgMax: _arg_min_max("DESC"),
590605
}
591606

592607
_invalid_operations = {

ibis/backends/tests/test_aggregation.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ def mean_udf(s):
8484
]
8585

8686
argidx_not_grouped_marks = [
87-
"bigquery",
8887
"datafusion",
8988
"impala",
9089
"mysql",
@@ -305,7 +304,6 @@ def mean_and_std(v):
305304
id='argmin',
306305
marks=pytest.mark.notyet(
307306
[
308-
"bigquery",
309307
"impala",
310308
"mysql",
311309
"postgres",
@@ -324,7 +322,6 @@ def mean_and_std(v):
324322
id='argmax',
325323
marks=pytest.mark.notyet(
326324
[
327-
"bigquery",
328325
"impala",
329326
"mysql",
330327
"postgres",

0 commit comments

Comments
 (0)