Skip to content

Commit 69b848a

Browse files
authored
feat(polars): add Intersection and Difference ops (#10623)
1 parent 43069bd commit 69b848a

File tree

2 files changed

+50
-4
lines changed

2 files changed

+50
-4
lines changed

ibis/backends/polars/compiler.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,6 +1256,56 @@ def execute_union(op, **kw):
12561256
return result
12571257

12581258

1259+
@translate.register(ops.Intersection)
1260+
def execute_intersection(op, *, ctx, **kw):
1261+
left = gen_name("polars_intersect_left")
1262+
right = gen_name("polars_intersect_right")
1263+
1264+
ctx.register_many(
1265+
frames={
1266+
left: translate(op.left, ctx=ctx, **kw),
1267+
right: translate(op.right, ctx=ctx, **kw),
1268+
}
1269+
)
1270+
1271+
sql = (
1272+
sg.select(STAR)
1273+
.from_(sg.to_identifier(left, quoted=True))
1274+
.intersect(sg.select(STAR).from_(sg.to_identifier(right, quoted=True)))
1275+
)
1276+
1277+
result = ctx.execute(sql.sql(Polars), eager=False)
1278+
1279+
if op.distinct is True:
1280+
return result.unique()
1281+
return result
1282+
1283+
1284+
@translate.register(ops.Difference)
1285+
def execute_difference(op, *, ctx, **kw):
1286+
left = gen_name("polars_diff_left")
1287+
right = gen_name("polars_diff_right")
1288+
1289+
ctx.register_many(
1290+
frames={
1291+
left: translate(op.left, ctx=ctx, **kw),
1292+
right: translate(op.right, ctx=ctx, **kw),
1293+
}
1294+
)
1295+
1296+
sql = (
1297+
sg.select(STAR)
1298+
.from_(sg.to_identifier(left, quoted=True))
1299+
.except_(sg.select(STAR).from_(sg.to_identifier(right, quoted=True)))
1300+
)
1301+
1302+
result = ctx.execute(sql.sql(Polars), eager=False)
1303+
1304+
if op.distinct is True:
1305+
return result.unique()
1306+
return result
1307+
1308+
12591309
@translate.register(ops.Hash)
12601310
def execute_hash(op, **kw):
12611311
# polars' hash() returns a uint64, but we want to return an int64

ibis/backends/tests/test_set_ops.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from pytest import param
77

88
import ibis
9-
import ibis.common.exceptions as com
109
import ibis.expr.types as ir
1110
from ibis import _
1211
from ibis.backends.tests.errors import PsycoPg2InternalError, PyDruidProgrammingError
@@ -84,7 +83,6 @@ def test_union_mixed_distinct(backend, union_subsets):
8483
param(True, id="distinct"),
8584
],
8685
)
87-
@pytest.mark.notimpl(["polars"])
8886
@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
8987
def test_intersect(backend, alltypes, df, distinct):
9088
a = alltypes.filter((_.id >= 5200) & (_.id <= 5210))
@@ -129,7 +127,6 @@ def test_intersect(backend, alltypes, df, distinct):
129127
param(True, id="distinct"),
130128
],
131129
)
132-
@pytest.mark.notimpl(["polars"])
133130
@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
134131
def test_difference(backend, alltypes, df, distinct):
135132
a = alltypes.filter((_.id >= 5200) & (_.id <= 5210))
@@ -238,7 +235,6 @@ def test_top_level_union(backend, con, alltypes, distinct, ordered):
238235
),
239236
],
240237
)
241-
@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError)
242238
@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
243239
def test_top_level_intersect_difference(
244240
backend, con, alltypes, distinct, opname, expected, ordered

0 commit comments

Comments
 (0)