File tree 2 files changed +15
-3
lines changed
2 files changed +15
-3
lines changed Original file line number Diff line number Diff line change 2
2
3
3
import base64
4
4
import datetime
5
+ from typing import Literal
5
6
6
7
import numpy as np
7
8
from multipledispatch import Dispatcher
@@ -464,6 +465,18 @@ def _array_agg(t, op):
464
465
return f"ARRAY_AGG({ t .translate (arg )} IGNORE NULLS)"
465
466
466
467
468
+ def _arg_min_max (sort_dir : Literal ["ASC" , "DESC" ]):
469
+ def translate (t , op : ops .ArgMin | ops .ArgMax ) -> str :
470
+ arg = op .arg
471
+ if (where := op .where ) is not None :
472
+ arg = ops .Where (where , arg , None )
473
+ arg = t .translate (arg )
474
+ key = t .translate (op .key )
475
+ return f"ARRAY_AGG({ arg } IGNORE NULLS ORDER BY { key } { sort_dir } LIMIT 1)[SAFE_OFFSET(0)]"
476
+
477
+ return translate
478
+
479
+
467
480
OPERATION_REGISTRY = {
468
481
** operation_registry ,
469
482
# Literal
@@ -587,6 +600,8 @@ def _array_agg(t, op):
587
600
ops .FloorDivide : _floor_divide ,
588
601
ops .IsNan : _is_nan ,
589
602
ops .IsInf : _is_inf ,
603
+ ops .ArgMin : _arg_min_max ("ASC" ),
604
+ ops .ArgMax : _arg_min_max ("DESC" ),
590
605
}
591
606
592
607
_invalid_operations = {
Original file line number Diff line number Diff line change @@ -84,7 +84,6 @@ def mean_udf(s):
84
84
]
85
85
86
86
argidx_not_grouped_marks = [
87
- "bigquery" ,
88
87
"datafusion" ,
89
88
"impala" ,
90
89
"mysql" ,
@@ -305,7 +304,6 @@ def mean_and_std(v):
305
304
id = 'argmin' ,
306
305
marks = pytest .mark .notyet (
307
306
[
308
- "bigquery" ,
309
307
"impala" ,
310
308
"mysql" ,
311
309
"postgres" ,
@@ -324,7 +322,6 @@ def mean_and_std(v):
324
322
id = 'argmax' ,
325
323
marks = pytest .mark .notyet (
326
324
[
327
- "bigquery" ,
328
325
"impala" ,
329
326
"mysql" ,
330
327
"postgres" ,
You can’t perform that action at this time.
0 commit comments