Skip to content

Commit e58ca73

Browse files
committed
fix(backends): ensure select after filter works
1 parent b643544 commit e58ca73

File tree

6 files changed

+65
-36
lines changed

6 files changed

+65
-36
lines changed

ibis/backends/pandas/execution/generic.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -969,7 +969,10 @@ def execute_series_notnnull(op, data, **kwargs):
969969

970970
@execute_node.register(ops.IsNan, (pd.Series, floating_types))
971971
def execute_isnan(op, data, **kwargs):
972-
return np.isnan(data)
972+
try:
973+
return np.isnan(data)
974+
except (TypeError, ValueError):
975+
return data != data
973976

974977

975978
@execute_node.register(ops.IsInf, (pd.Series, floating_types))

ibis/backends/tests/test_generic.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,3 +662,22 @@ def test_where_column(backend, alltypes, df):
662662
)
663663

664664
backend.assert_series_equal(result, expected)
665+
666+
667+
def test_select_filter(backend, alltypes, df):
668+
t = alltypes
669+
670+
expr = t.select("int_col").filter(t.string_col == "4")
671+
result = expr.execute()
672+
673+
expected = df.loc[df.string_col == "4", ["int_col"]].reset_index(drop=True)
674+
backend.assert_frame_equal(result, expected)
675+
676+
677+
def test_select_filter_select(backend, alltypes, df):
678+
t = alltypes
679+
expr = t.select("int_col").filter(t.string_col == "4").int_col
680+
result = expr.execute().rename("int_col")
681+
682+
expected = df.loc[df.string_col == "4", "int_col"].reset_index(drop=True)
683+
backend.assert_series_equal(result, expected)

ibis/expr/analysis.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
import ibis.expr.operations as ops
1313
import ibis.expr.types as ir
1414
from ibis import util
15-
from ibis.common.exceptions import ExpressionError, IbisTypeError
15+
from ibis.common.exceptions import (
16+
ExpressionError,
17+
IbisTypeError,
18+
IntegrityError,
19+
)
1620
from ibis.expr.window import window
1721

1822
# ---------------------------------------------------------------------
@@ -343,21 +347,31 @@ def _filter_selection(expr, predicates):
343347
# the parent tables in the join being projected
344348

345349
op = expr.op()
346-
if not op.blocks():
347-
# Potential fusion opportunity. The predicates may need to be
348-
# rewritten in terms of the child table. This prevents the broken
349-
# ref issue (described in more detail in #59)
350+
# Potential fusion opportunity. The predicates may need to be
351+
# rewritten in terms of the child table. This prevents the broken
352+
# ref issue (described in more detail in #59)
353+
try:
350354
simplified_predicates = tuple(
351355
sub_for(predicate, [(expr, op.table)])
352356
if not is_reduction(predicate)
353357
else predicate
354358
for predicate in predicates
355359
)
356-
357-
if shares_all_roots(simplified_predicates, op.table):
360+
except IntegrityError:
361+
pass
362+
else:
363+
if shares_all_roots(simplified_predicates, op.table) and not any(
364+
# we can't push down filters on unnest because unnest changes the
365+
# shape and potential values of the data: unnest can potentially
366+
# produce NULLs
367+
#
368+
# the getattr shenanigans is to handle Alias
369+
isinstance(getattr(sel.op(), "arg", sel).op(), ops.Unnest)
370+
for sel in op.selections
371+
):
358372
result = ops.Selection(
359373
op.table,
360-
[],
374+
selections=op.selections,
361375
predicates=op.predicates + simplified_predicates,
362376
sort_keys=op.sort_keys,
363377
)

ibis/tests/expr/test_analysis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def test_select_filter_mutate_fusion():
284284
#
285285
# eventually we will bring this back, but we're trading off the ability
286286
# to remove materialize for some performance in the short term
287-
assert len(first_selection.op().selections) == 0
287+
assert len(first_selection.op().selections) == 1
288288
assert len(first_selection.op().predicates) == 1
289289

290290

ibis/tests/sql/test_compiler.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -229,19 +229,15 @@ def test_agg_filter():
229229
FROM my_table
230230
),
231231
t1 AS (
232-
SELECT `a`, `b2`
232+
SELECT t0.`a`, t0.`b2`
233233
FROM t0
234+
WHERE t0.`a` < 100
234235
)
235-
SELECT t2.*
236-
FROM (
237-
SELECT t1.*
238-
FROM t1
239-
WHERE t1.`a` < 100
240-
) t2
241-
WHERE t2.`a` = (
242-
SELECT max(t1.`a`) AS `blah`
236+
SELECT t1.*
237+
FROM t1
238+
WHERE t1.`a` = (
239+
SELECT max(`a`) AS `blah`
243240
FROM t1
244-
WHERE t1.`a` < 100
245241
)"""
246242
assert result == expected
247243

@@ -259,19 +255,15 @@ def test_agg_filter_with_alias():
259255
FROM my_table
260256
),
261257
t1 AS (
262-
SELECT `a`, `b2`
258+
SELECT t0.`a`, t0.`b2`
263259
FROM t0
260+
WHERE t0.`a` < 100
264261
)
265-
SELECT t2.*
266-
FROM (
267-
SELECT t1.*
268-
FROM t1
269-
WHERE t1.`a` < 100
270-
) t2
271-
WHERE t2.`a` = (
272-
SELECT max(t1.`a`) AS `blah`
262+
SELECT t1.*
263+
FROM t1
264+
WHERE t1.`a` = (
265+
SELECT max(`a`) AS `blah`
273266
FROM t1
274-
WHERE t1.`a` < 100
275267
)"""
276268
assert result == expected
277269

@@ -358,8 +350,8 @@ def test_subquery_where_location():
358350
FROM (
359351
SELECT `float_col`, `timestamp_col`, `int_col`, `string_col`
360352
FROM alltypes
353+
WHERE `timestamp_col` < '20140101'
361354
) t1
362-
WHERE `timestamp_col` < '20140101'
363355
GROUP BY 1
364356
) t0"""
365357
assert result == expected

ibis/tests/sql/test_select_sql.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1488,16 +1488,17 @@ def test_filter_predicates():
14881488
expr = projected
14891489

14901490
expected = """\
1491-
SELECT t0.*
1491+
SELECT *
14921492
FROM (
14931493
SELECT *
14941494
FROM (
14951495
SELECT *
14961496
FROM t
14971497
WHERE (lower(`color`) LIKE '%de%') AND
14981498
(locate('de', lower(`color`)) - 1 >= 0)
1499-
) t2
1500-
) t0
1501-
WHERE regexp_like(lower(t0.`color`), '.*ge.*')"""
1499+
) t1
1500+
WHERE regexp_like(lower(`color`), '.*ge.*')
1501+
) t0"""
15021502

1503-
assert Compiler.to_sql(expr) == expected
1503+
result = Compiler.to_sql(expr)
1504+
assert result == expected

0 commit comments

Comments
 (0)