Skip to content

Commit ee8dbab

Browse files
cpcloudkszucs
authored andcommitted
feat(postgres): implement ops.Arbitrary
1 parent 9a19302 commit ee8dbab

File tree

3 files changed

+47
-4
lines changed

3 files changed

+47
-4
lines changed

ibis/backends/postgres/__init__.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import contextlib
56
from typing import TYPE_CHECKING, Iterable, Literal
67

78
import sqlalchemy as sa
@@ -115,6 +116,37 @@ def do_connect(
115116
alchemy_url, connect_args=connect_args, poolclass=sa.pool.StaticPool
116117
)
117118

119+
# define first/last aggs for ops.Arbitrary
120+
#
121+
# ignore exceptions so the rest of ibis still works: a user may not
122+
# have permissions to define funtions and/or aggregates
123+
with engine.begin() as con, contextlib.suppress(Exception):
124+
# adapted from https://wiki.postgresql.org/wiki/First/last_%28aggregate%29
125+
con.exec_driver_sql(
126+
"""\
127+
CREATE OR REPLACE FUNCTION public._ibis_first_agg (anyelement, anyelement)
128+
RETURNS anyelement
129+
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE AS
130+
'SELECT $1';
131+
132+
CREATE OR REPLACE AGGREGATE public._ibis_first (anyelement) (
133+
SFUNC = public._ibis_first_agg,
134+
STYPE = anyelement,
135+
PARALLEL = safe
136+
);
137+
138+
CREATE OR REPLACE FUNCTION public._ibis_last_agg (anyelement, anyelement)
139+
RETURNS anyelement
140+
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE AS
141+
'SELECT $2';
142+
143+
CREATE OR REPLACE AGGREGATE public._ibis_last (anyelement) (
144+
SFUNC = public._ibis_last_agg,
145+
STYPE = anyelement,
146+
PARALLEL = safe
147+
);"""
148+
)
149+
118150
@sa.event.listens_for(engine, "connect")
119151
def connect(dbapi_connection, connection_record):
120152
with dbapi_connection.cursor() as cur:

ibis/backends/postgres/registry.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,15 @@ def translate(t, op: ops.ArgMin | ops.ArgMax) -> str:
479479
return translate
480480

481481

482+
def _arbitrary(t, op):
483+
if (how := op.how) == "heavy":
484+
raise com.UnsupportedOperationError(
485+
f"postgres backend doesn't support how={how!r} for the arbitrary() aggregate"
486+
)
487+
func = getattr(sa.func.public, f"_ibis_{op.how}")
488+
return t._reduction(func, op)
489+
490+
482491
operation_registry.update(
483492
{
484493
ops.Literal: _literal,
@@ -629,5 +638,6 @@ def translate(t, op: ops.ArgMin | ops.ArgMax) -> str:
629638
ops.LStrip: unary(lambda arg: sa.func.ltrim(arg, string.whitespace)),
630639
ops.RStrip: unary(lambda arg: sa.func.rtrim(arg, string.whitespace)),
631640
ops.StartsWith: fixed_arity(lambda arg, prefix: arg.op("^@")(prefix), 2),
641+
ops.Arbitrary: _arbitrary,
632642
}
633643
)

ibis/backends/tests/test_aggregation.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,6 @@ def mean_and_std(v):
553553
marks=pytest.mark.notimpl(
554554
[
555555
'impala',
556-
'postgres',
557556
'mysql',
558557
'sqlite',
559558
'polars',
@@ -571,7 +570,6 @@ def mean_and_std(v):
571570
marks=pytest.mark.notimpl(
572571
[
573572
'impala',
574-
'postgres',
575573
'mysql',
576574
'sqlite',
577575
'polars',
@@ -590,7 +588,6 @@ def mean_and_std(v):
590588
pytest.mark.notimpl(
591589
[
592590
'impala',
593-
'postgres',
594591
'mysql',
595592
'sqlite',
596593
'polars',
@@ -629,14 +626,18 @@ def mean_and_std(v):
629626
"impala",
630627
"mysql",
631628
"pandas",
632-
"postgres",
633629
"sqlite",
634630
"polars",
635631
"mssql",
636632
"druid",
637633
],
638634
raises=com.OperationNotDefinedError,
639635
),
636+
pytest.mark.notimpl(
637+
["postgres"],
638+
raises=com.UnsupportedOperationError,
639+
reason="how='heavy' not supported in the postgres backend",
640+
),
640641
pytest.mark.notimpl(
641642
["duckdb"],
642643
raises=com.UnsupportedOperationError,

0 commit comments

Comments
 (0)