Skip to content

Commit bb0a6f0

Browse files
cpcloud authored and gforsyth committed
feat(api): allow single argument set operations
1 parent dc80512 commit bb0a6f0

File tree

2 files changed

+104
-27
lines changed

2 files changed

+104
-27
lines changed

ibis/backends/tests/test_set_ops.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import ibis
99
import ibis.common.exceptions as com
10+
import ibis.expr.types as ir
1011
from ibis import _
1112

1213

@@ -141,9 +142,10 @@ def test_difference(backend, alltypes, df, distinct):
141142

142143

143144
@pytest.mark.parametrize("method", ["intersect", "difference", "union"])
144-
def test_empty_set_op(alltypes, method):
145-
with pytest.raises(com.IbisTypeError, match="requires a table or tables"):
146-
getattr(alltypes, method)()
145+
@pytest.mark.parametrize("source", [ibis, ir.Table], ids=["top_level", "method"])
146+
def test_empty_set_op(alltypes, method, source):
147+
result = getattr(source, method)(alltypes)
148+
assert result.equals(alltypes)
147149

148150

149151
@pytest.mark.parametrize(

ibis/expr/types/relations.py

Lines changed: 99 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,10 @@ def difference(self, *tables: Table, distinct: bool = True) -> Table:
760760
distinct
761761
Only diff distinct rows not occurring in the calling table
762762
763+
See Also
764+
--------
765+
[`ibis.difference`][ibis.difference]
766+
763767
Returns
764768
-------
765769
Table
@@ -797,15 +801,36 @@ def difference(self, *tables: Table, distinct: bool = True) -> Table:
797801
├───────┤
798802
│ 1 │
799803
└───────┘
804+
805+
Passing no arguments to `difference` returns the table expression.
806+
807+
This can be useful when you have a sequence of tables to process, and
808+
you don't know the length prior to running your program (for example, user input).
809+
810+
>>> t1
811+
┏━━━━━━━┓
812+
┃ a ┃
813+
┡━━━━━━━┩
814+
│ int64 │
815+
├───────┤
816+
│ 1 │
817+
│ 2 │
818+
└───────┘
819+
>>> t1.difference()
820+
┏━━━━━━━┓
821+
┃ a ┃
822+
┡━━━━━━━┩
823+
│ int64 │
824+
├───────┤
825+
│ 1 │
826+
│ 2 │
827+
└───────┘
828+
>>> t1.difference().equals(t1)
829+
True
800830
"""
801-
left = self
802-
if not tables:
803-
raise com.IbisTypeError(
804-
"difference requires a table or tables to compare against"
805-
)
806-
for right in tables:
807-
left = ops.Difference(left, right, distinct=distinct)
808-
return left.to_expr()
831+
return functools.reduce(
832+
functools.partial(ops.Difference, distinct=distinct), tables, self.op()
833+
).to_expr()
809834

810835
def aggregate(
811836
self,
@@ -1102,6 +1127,10 @@ def union(self, *tables: Table, distinct: bool = False) -> Table:
11021127
Table
11031128
A new table containing the union of all input tables.
11041129
1130+
See Also
1131+
--------
1132+
[`ibis.union`][ibis.union]
1133+
11051134
Examples
11061135
--------
11071136
>>> import ibis
@@ -1147,15 +1176,36 @@ def union(self, *tables: Table, distinct: bool = False) -> Table:
11471176
│ 2 │
11481177
│ 3 │
11491178
└───────┘
1179+
1180+
Passing no arguments to `union` returns the table expression.
1181+
1182+
This can be useful when you have a sequence of tables to process, and
1183+
you don't know the length prior to running your program (for example, user input).
1184+
1185+
>>> t1
1186+
┏━━━━━━━┓
1187+
┃ a ┃
1188+
┡━━━━━━━┩
1189+
│ int64 │
1190+
├───────┤
1191+
│ 1 │
1192+
│ 2 │
1193+
└───────┘
1194+
>>> t1.union()
1195+
┏━━━━━━━┓
1196+
┃ a ┃
1197+
┡━━━━━━━┩
1198+
│ int64 │
1199+
├───────┤
1200+
│ 1 │
1201+
│ 2 │
1202+
└───────┘
1203+
>>> t1.union().equals(t1)
1204+
True
11501205
"""
1151-
left = self
1152-
if not tables:
1153-
raise com.IbisTypeError(
1154-
"union requires a table or tables to compare against"
1155-
)
1156-
for right in tables:
1157-
left = ops.Union(left, right, distinct=distinct)
1158-
return left.to_expr()
1206+
return functools.reduce(
1207+
functools.partial(ops.Union, distinct=distinct), tables, self.op()
1208+
).to_expr()
11591209

11601210
def intersect(self, *tables: Table, distinct: bool = True) -> Table:
11611211
"""Compute the set intersection of multiple table expressions.
@@ -1174,6 +1224,10 @@ def intersect(self, *tables: Table, distinct: bool = True) -> Table:
11741224
Table
11751225
A new table containing the intersection of all input tables.
11761226
1227+
See Also
1228+
--------
1229+
[`ibis.intersect`][ibis.intersect]
1230+
11771231
Examples
11781232
--------
11791233
>>> import ibis
@@ -1206,15 +1260,36 @@ def intersect(self, *tables: Table, distinct: bool = True) -> Table:
12061260
├───────┤
12071261
│ 2 │
12081262
└───────┘
1263+
1264+
Passing no arguments to `intersect` returns the table expression.
1265+
1266+
This can be useful when you have a sequence of tables to process, and
1267+
you don't know the length prior to running your program (for example, user input).
1268+
1269+
>>> t1
1270+
┏━━━━━━━┓
1271+
┃ a ┃
1272+
┡━━━━━━━┩
1273+
│ int64 │
1274+
├───────┤
1275+
│ 1 │
1276+
│ 2 │
1277+
└───────┘
1278+
>>> t1.intersect()
1279+
┏━━━━━━━┓
1280+
┃ a ┃
1281+
┡━━━━━━━┩
1282+
│ int64 │
1283+
├───────┤
1284+
│ 1 │
1285+
│ 2 │
1286+
└───────┘
1287+
>>> t1.intersect().equals(t1)
1288+
True
12091289
"""
1210-
left = self
1211-
if not tables:
1212-
raise com.IbisTypeError(
1213-
"intersect requires a table or tables to compare against"
1214-
)
1215-
for right in tables:
1216-
left = ops.Intersection(left, right, distinct=distinct)
1217-
return left.to_expr()
1290+
return functools.reduce(
1291+
functools.partial(ops.Intersection, distinct=distinct), tables, self.op()
1292+
).to_expr()
12181293

12191294
def to_array(self) -> ir.Column:
12201295
"""View a single column table as an array.

0 commit comments

Comments
 (0)