Skip to content

Commit 9377966

Browse files
committed
fix(duckdb): thread udf parameters through
1 parent 3f5d090 commit 9377966

File tree

4 files changed

+69
-48
lines changed

4 files changed

+69
-48
lines changed

ibis/backends/tests/errors.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,12 @@
127127
from psycopg2.errors import ProgrammingError as PsycoPg2ProgrammingError
128128
from psycopg2.errors import SyntaxError as PsycoPg2SyntaxError
129129
from psycopg2.errors import UndefinedObject as PsycoPg2UndefinedObject
130-
from psycopg2.errors import UniqueViolation as PsycoPg2UniqueViolation
131130
except ImportError:
132131
PsycoPg2SyntaxError = PsycoPg2IndeterminateDatatype = (
133132
PsycoPg2InvalidTextRepresentation
134133
) = PsycoPg2DivisionByZero = PsycoPg2InternalError = PsycoPg2ProgrammingError = (
135134
PsycoPg2OperationalError
136-
) = PsycoPg2UndefinedObject = PsycoPg2ArraySubscriptError = PsycoPg2UniqueViolation = None
135+
) = PsycoPg2UndefinedObject = PsycoPg2ArraySubscriptError = None
137136

138137
try:
139138
from psycopg.errors import ArraySubscriptError as PsycoPgArraySubscriptError

ibis/backends/tests/test_impure.py

Lines changed: 61 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,20 @@
22

33
import sys
44

5-
import pandas.testing as tm
65
import pytest
76

87
import ibis
98
import ibis.common.exceptions as com
109
from ibis import _
11-
from ibis.backends.tests.errors import (
12-
PsycoPg2InternalError,
13-
Py4JJavaError,
14-
PyDruidProgrammingError,
15-
)
10+
from ibis.backends.tests.errors import Py4JJavaError
11+
12+
tm = pytest.importorskip("pandas.testing")
13+
14+
pytestmark = pytest.mark.xdist_group("impure")
1615

1716
no_randoms = [
1817
pytest.mark.notimpl(
19-
["dask", "pandas", "polars"], raises=com.OperationNotDefinedError
20-
),
21-
pytest.mark.notimpl("druid", raises=PyDruidProgrammingError),
22-
pytest.mark.notyet(
23-
"risingwave",
24-
raises=PsycoPg2InternalError,
25-
reason="function random() does not exist",
18+
["polars", "druid", "risingwave"], raises=com.OperationNotDefinedError
2619
),
2720
]
2821

@@ -32,19 +25,16 @@
3225
[
3326
"bigquery",
3427
"clickhouse",
35-
"dask",
3628
"druid",
3729
"exasol",
3830
"impala",
3931
"mssql",
4032
"mysql",
4133
"oracle",
42-
"pandas",
4334
"trino",
4435
"risingwave",
4536
]
4637
),
47-
pytest.mark.notimpl("pyspark", reason="only supports pandas UDFs"),
4838
pytest.mark.notyet(
4939
"flink",
5040
condition=sys.version_info >= (3, 11),
@@ -55,16 +45,7 @@
5545

5646
no_uuids = [
5747
pytest.mark.notimpl(
58-
[
59-
"druid",
60-
"exasol",
61-
"oracle",
62-
"polars",
63-
"pyspark",
64-
"risingwave",
65-
"pandas",
66-
"dask",
67-
],
48+
["druid", "exasol", "oracle", "polars", "pyspark", "risingwave"],
6849
raises=com.OperationNotDefinedError,
6950
),
7051
pytest.mark.notyet("mssql", reason="Unrelated bug: Incorrect syntax near '('"),
@@ -82,11 +63,7 @@ def my_random(x: float) -> float:
8263
mark_impures = pytest.mark.parametrize(
8364
"impure",
8465
[
85-
pytest.param(
86-
lambda _: ibis.random(),
87-
marks=no_randoms,
88-
id="random",
89-
),
66+
pytest.param(lambda _: ibis.random(), marks=no_randoms, id="random"),
9067
pytest.param(
9168
lambda _: ibis.uuid().cast(str).contains("a").ifelse(1, 0),
9269
marks=[
@@ -107,6 +84,7 @@ def my_random(x: float) -> float:
10784
)
10885

10986

87+
# You can work around this by .cache()ing the table.
11088
@pytest.mark.notyet("sqlite", reason="instances are uncorrelated")
11189
@mark_impures
11290
def test_impure_correlated(alltypes, impure):
@@ -120,14 +98,12 @@ def test_impure_correlated(alltypes, impure):
12098
# t AS (SELECT random() AS common)
12199
# SELECT common as x, common as y FROM t
122100
# Then both x and y should have the same value.
123-
df = (
124-
alltypes.select(common=impure(alltypes))
125-
.select(x=_.common, y=_.common)
126-
.execute()
127-
)
101+
expr = alltypes.select(common=impure(alltypes)).select(x=_.common, y=_.common)
102+
df = expr.execute()
128103
tm.assert_series_equal(df.x, df.y, check_names=False)
129104

130105

106+
# You can work around this by .cache()ing the table.
131107
@pytest.mark.notyet("sqlite", reason="instances are uncorrelated")
132108
@mark_impures
133109
def test_chained_selections(alltypes, impure):
@@ -153,9 +129,7 @@ def test_chained_selections(alltypes, impure):
153129
lambda _: ibis.random(),
154130
marks=[
155131
*no_randoms,
156-
pytest.mark.notyet(
157-
["impala", "trino"], reason="instances are correlated"
158-
),
132+
pytest.mark.notyet(["impala"], reason="instances are correlated"),
159133
],
160134
id="random",
161135
),
@@ -164,24 +138,24 @@ def test_chained_selections(alltypes, impure):
164138
lambda _: ibis.uuid().cast(str).contains("a").ifelse(1, 0),
165139
marks=[
166140
*no_uuids,
167-
pytest.mark.notyet(
168-
["mysql", "trino"], reason="instances are correlated"
169-
),
141+
pytest.mark.notyet(["mysql"], reason="instances are correlated"),
170142
],
171143
id="uuid",
172144
),
173145
pytest.param(
174146
lambda table: my_random(table.float_col),
175147
marks=[
176148
*no_udfs,
177-
pytest.mark.notyet("duckdb", reason="instances are correlated"),
149+
# no "impure" argument for pyspark yet
150+
pytest.mark.notimpl("pyspark"),
178151
],
179152
id="udf",
180153
),
181154
],
182155
)
183156

184157

158+
# You can work around this by doing .select().cache().select()
185159
@pytest.mark.notyet(["clickhouse"], reason="instances are correlated")
186160
@impure_params_uncorrelated
187161
def test_impure_uncorrelated_different_id(alltypes, impure):
@@ -191,15 +165,57 @@ def test_impure_uncorrelated_different_id(alltypes, impure):
191165
# eg if you look at the following SQL:
192166
# select random() as x, random() as y
193167
# Then x and y should be uncorrelated.
194-
df = alltypes.select(x=impure(alltypes), y=impure(alltypes)).execute()
168+
expr = alltypes.select(x=impure(alltypes), y=impure(alltypes))
169+
df = expr.execute()
195170
assert (df.x != df.y).any()
196171

197172

173+
# You can work around this by doing .select().cache().select()
198174
@pytest.mark.notyet(["clickhouse"], reason="instances are correlated")
199175
@impure_params_uncorrelated
200176
def test_impure_uncorrelated_same_id(alltypes, impure):
201177
# Similar to test_impure_uncorrelated_different_id, but the two expressions
202178
# have the same ID. Still, they should be uncorrelated.
203179
common = impure(alltypes)
204-
df = alltypes.select(x=common, y=common).execute()
180+
expr = alltypes.select(x=common, y=common)
181+
df = expr.execute()
205182
assert (df.x != df.y).any()
183+
184+
185+
@pytest.mark.notyet(
186+
[
187+
"duckdb",
188+
"clickhouse",
189+
"datafusion",
190+
"mysql",
191+
"impala",
192+
"mssql",
193+
"trino",
194+
"flink",
195+
"bigquery",
196+
],
197+
raises=AssertionError,
198+
reason="instances are not correlated but ideally they would be",
199+
)
200+
@pytest.mark.notyet(
201+
["sqlite"],
202+
raises=AssertionError,
203+
reason="instances are *sometimes* correlated but ideally they would always be",
204+
strict=False,
205+
)
206+
@pytest.mark.notimpl(
207+
["polars", "risingwave", "druid", "exasol", "oracle", "pyspark"],
208+
raises=com.OperationNotDefinedError,
209+
)
210+
def test_self_join_with_generated_keys(con):
211+
# Even with CTEs in the generated SQL, the backends still
212+
# materialize a new value every time it is referenced.
213+
# This isn't ideal behavior, but there is nothing we can do about it
214+
# on the ibis side. The best you can do is to .cache() the table
215+
# right after you assign the uuid().
216+
# https://github.com/ibis-project/ibis/pull/9014#issuecomment-2399449665
217+
left = ibis.memtable({"idx": list(range(5))}).mutate(key=ibis.uuid())
218+
right = left.filter(left.idx < 3)
219+
expr = left.join(right, "key")
220+
result = con.execute(expr.count())
221+
assert result == 3

ibis/expr/decompile.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
ops.StringContains: "contains",
3939
ops.StringSQLILike: "ilike",
4040
ops.StringSQLLike: "like",
41-
ops.TimestampNow: "now",
4241
}
4342

4443

@@ -84,6 +83,11 @@ def translate(op, *args, **kwargs):
8483
raise NotImplementedError(op)
8584

8685

86+
@translate.register(ops.TimestampNow)
87+
def now(_):
88+
return "ibis.now()"
89+
90+
8791
@translate.register(ops.Value)
8892
def value(op, *args, **kwargs):
8993
method = _get_method_name(op)

ibis/expr/operations/generic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,13 +193,15 @@ class TimestampNow(Constant):
193193
"""Return the current timestamp."""
194194

195195
dtype = dt.timestamp
196+
shape = ds.scalar
196197

197198

198199
@public
199200
class DateNow(Constant):
200201
"""Return the current date."""
201202

202203
dtype = dt.date
204+
shape = ds.scalar
203205

204206

205207
@public

0 commit comments

Comments
 (0)