Skip to content

Commit 2bc0b69

Browse files
committed
feat(sqlalchemy): properly implement Intersection and Difference
1 parent cd9a34c commit 2bc0b69

File tree

3 files changed

+62
-47
lines changed

3 files changed

+62
-47
lines changed

ibis/backends/base/sql/alchemy/query_builder.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
Select,
2020
SelectBuilder,
2121
TableSetFormatter,
22-
Union,
2322
)
23+
from ibis.backends.base.sql.compiler.base import SetOp
2424

2525

2626
def _schema_to_sqlalchemy_columns(schema: sch.Schema) -> list[sa.Column]:
@@ -343,21 +343,40 @@ def _convert_group_by(self, exprs):
343343
return exprs
344344

345345

346-
class AlchemyUnion(Union):
347-
def compile(self):
348-
def reduce_union(left, right, distincts=iter(self.distincts)):
349-
distinct = next(distincts)
350-
sa_func = sa.union if distinct else sa.union_all
351-
return sa_func(left, right)
346+
class AlchemySetOp(SetOp):
347+
@classmethod
348+
def reduce(cls, left, right, distincts):
349+
distinct = next(distincts)
350+
sa_func = cls.distinct_func if distinct else cls.non_distinct_func
351+
return sa_func(left, right)
352352

353+
def compile(self):
353354
context = self.context
354355
selects = []
355356

356357
for table in self.tables:
357358
table_set = context.get_compiled_expr(table)
358359
selects.append(table_set.cte().select())
359360

360-
return functools.reduce(reduce_union, selects)
361+
return functools.reduce(
362+
functools.partial(self.reduce, distincts=iter(self.distincts)),
363+
selects,
364+
)
365+
366+
367+
class AlchemyUnion(AlchemySetOp):
368+
distinct_func = sa.union
369+
non_distinct_func = sa.union_all
370+
371+
372+
class AlchemyIntersection(AlchemySetOp):
373+
distinct_func = sa.intersect
374+
non_distinct_func = sa.intersect_all
375+
376+
377+
class AlchemyDifference(AlchemySetOp):
378+
distinct_func = sa.except_
379+
non_distinct_func = sa.except_all
361380

362381

363382
class AlchemyCompiler(Compiler):
@@ -367,6 +386,8 @@ class AlchemyCompiler(Compiler):
367386
select_builder_class = AlchemySelectBuilder
368387
select_class = AlchemySelect
369388
union_class = AlchemyUnion
389+
intersect_class = AlchemyIntersection
390+
difference_class = AlchemyDifference
370391

371392
@classmethod
372393
def to_sql(cls, expr, context=None, params=None, exists=False):

ibis/backends/base/sql/compiler/base.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,20 @@ def compile(self):
5353

5454

5555
class SetOp(DML):
56-
def __init__(self, tables, expr, context):
56+
def __init__(self, tables, expr, context, distincts):
5757
self.context = context
5858
self.tables = tables
5959
self.table_set = expr
60+
self.distincts = distincts
6061
self.filters = []
6162

63+
@classmethod
64+
def keyword(cls, distinct):
65+
return cls._keyword + (not distinct) * " ALL"
66+
67+
def _get_keyword_list(self):
68+
return map(self.keyword, self.distincts)
69+
6270
def _extract_subqueries(self):
6371
self.subqueries = _extract_common_table_expressions(
6472
[self.table_set, *self.filters]
@@ -84,9 +92,6 @@ def format_relation(self, expr):
8492
return f'SELECT *\nFROM {ref}'
8593
return self.context.get_compiled_expr(expr)
8694

87-
def _get_keyword_list(self):
88-
raise NotImplementedError("Need objects to interleave")
89-
9095
def compile(self):
9196
self._extract_subqueries()
9297

ibis/backends/base/sql/compiler/query_builder.py

Lines changed: 24 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -455,33 +455,18 @@ def format_limit(self):
455455

456456

457457
class Union(SetOp):
458-
def __init__(self, tables, expr, context, distincts):
459-
super().__init__(tables, expr, context)
460-
self.distincts = distincts
461-
462-
@staticmethod
463-
def keyword(distinct):
464-
return 'UNION' if distinct else 'UNION ALL'
465-
466-
def _get_keyword_list(self):
467-
return map(self.keyword, self.distincts)
458+
_keyword = "UNION"
468459

469460

470461
class Intersection(SetOp):
471462
_keyword = "INTERSECT"
472463

473-
def _get_keyword_list(self):
474-
return [self._keyword] * (len(self.tables) - 1)
475-
476464

477465
class Difference(SetOp):
478466
_keyword = "EXCEPT"
479467

480-
def _get_keyword_list(self):
481-
return [self._keyword] * (len(self.tables) - 1)
482-
483468

484-
def flatten_union(table: ir.Table):
469+
def flatten_set_op(table: ir.Table):
485470
"""Extract all union queries from `table`.
486471
487472
Parameters
@@ -493,14 +478,14 @@ def flatten_union(table: ir.Table):
493478
Iterable[Union[Table, bool]]
494479
"""
495480
op = table.op()
496-
if isinstance(op, ops.Union):
481+
if isinstance(op, ops.SetOp):
497482
# For some reason mypy considers `op.left` and `op.right`
498483
# of `Argument` type, and fails the validation. While in
499484
# `flatten` types are the same, and it works
500485
return toolz.concatv(
501-
flatten_union(op.left), # type: ignore
486+
flatten_set_op(op.left), # type: ignore
502487
[op.distinct],
503-
flatten_union(op.right), # type: ignore
488+
flatten_set_op(op.right), # type: ignore
504489
)
505490
return [table]
506491

@@ -517,7 +502,9 @@ def flatten(table: ir.Table):
517502
Iterable[Union[Table]]
518503
"""
519504
op = table.op()
520-
return list(toolz.concatv(flatten_union(op.left), flatten_union(op.right)))
505+
return list(
506+
toolz.concatv(flatten_set_op(op.left), flatten_set_op(op.right))
507+
)
521508

522509

523510
class Compiler:
@@ -617,35 +604,37 @@ def _generate_setup_queries(expr, context):
617604
def _generate_teardown_queries(expr, context):
618605
return []
619606

620-
@classmethod
621-
def _make_union(cls, expr, context):
607+
@staticmethod
608+
def _make_set_op(cls, expr, context):
622609
# flatten unions so that we can codegen them all at once
623-
union_info = list(flatten_union(expr))
610+
set_op_info = list(flatten_set_op(expr))
624611

625612
# since op is a union, we have at least 3 elements in union_info (left
626613
# distinct right) and if there is more than a single union we have an
627614
# additional two elements per union (distinct right) which means the
628615
# total number of elements is at least 3 + (2 * number of unions - 1)
629616
# and is therefore an odd number
630-
npieces = len(union_info)
631-
assert npieces >= 3 and npieces % 2 != 0, 'Invalid union expression'
617+
npieces = len(set_op_info)
618+
assert (
619+
npieces >= 3 and npieces % 2 != 0
620+
), 'Invalid set operation expression'
632621

633622
# 1. every other object starting from 0 is a Table instance
634623
# 2. every other object starting from 1 is a bool indicating the type
635-
# of union (distinct or not distinct)
636-
table_exprs, distincts = union_info[::2], union_info[1::2]
637-
return cls.union_class(
638-
table_exprs, expr, distincts=distincts, context=context
639-
)
624+
# of $set_op (distinct or not distinct)
625+
table_exprs, distincts = set_op_info[::2], set_op_info[1::2]
626+
return cls(table_exprs, expr, distincts=distincts, context=context)
627+
628+
@classmethod
629+
def _make_union(cls, expr, context):
630+
return cls._make_set_op(cls.union_class, expr, context)
640631

641632
@classmethod
642633
def _make_intersect(cls, expr, context):
643634
# flatten intersections so that we can codegen them all at once
644-
table_exprs = list(flatten(expr))
645-
return cls.intersect_class(table_exprs, expr, context=context)
635+
return cls._make_set_op(cls.intersect_class, expr, context)
646636

647637
@classmethod
648638
def _make_difference(cls, expr, context):
649639
# flatten differences so that we can codegen them all at once
650-
table_exprs = list(flatten(expr))
651-
return cls.difference_class(table_exprs, expr, context=context)
640+
return cls._make_set_op(cls.difference_class, expr, context)

0 commit comments

Comments
 (0)