@@ -455,33 +455,18 @@ def format_limit(self):
455
455
456
456
457
457
class Union (SetOp ):
458
- def __init__ (self , tables , expr , context , distincts ):
459
- super ().__init__ (tables , expr , context )
460
- self .distincts = distincts
461
-
462
- @staticmethod
463
- def keyword (distinct ):
464
- return 'UNION' if distinct else 'UNION ALL'
465
-
466
- def _get_keyword_list (self ):
467
- return map (self .keyword , self .distincts )
458
+ _keyword = "UNION"
468
459
469
460
470
461
class Intersection (SetOp ):
471
462
_keyword = "INTERSECT"
472
463
473
- def _get_keyword_list (self ):
474
- return [self ._keyword ] * (len (self .tables ) - 1 )
475
-
476
464
477
465
class Difference (SetOp ):
478
466
_keyword = "EXCEPT"
479
467
480
- def _get_keyword_list (self ):
481
- return [self ._keyword ] * (len (self .tables ) - 1 )
482
-
483
468
484
- def flatten_union (table : ir .Table ):
469
+ def flatten_set_op (table : ir .Table ):
485
470
"""Extract all union queries from `table`.
486
471
487
472
Parameters
@@ -493,14 +478,14 @@ def flatten_union(table: ir.Table):
493
478
Iterable[Union[Table, bool]]
494
479
"""
495
480
op = table .op ()
496
- if isinstance (op , ops .Union ):
481
+ if isinstance (op , ops .SetOp ):
497
482
# For some reason mypy considers `op.left` and `op.right`
498
483
# of `Argument` type, and fails the validation. While in
499
484
# `flatten` types are the same, and it works
500
485
return toolz .concatv (
501
- flatten_union (op .left ), # type: ignore
486
+ flatten_set_op (op .left ), # type: ignore
502
487
[op .distinct ],
503
- flatten_union (op .right ), # type: ignore
488
+ flatten_set_op (op .right ), # type: ignore
504
489
)
505
490
return [table ]
506
491
@@ -517,7 +502,9 @@ def flatten(table: ir.Table):
517
502
Iterable[Union[Table]]
518
503
"""
519
504
op = table .op ()
520
- return list (toolz .concatv (flatten_union (op .left ), flatten_union (op .right )))
505
+ return list (
506
+ toolz .concatv (flatten_set_op (op .left ), flatten_set_op (op .right ))
507
+ )
521
508
522
509
523
510
class Compiler :
@@ -617,35 +604,37 @@ def _generate_setup_queries(expr, context):
617
604
def _generate_teardown_queries (expr , context ):
618
605
return []
619
606
620
- @classmethod
621
- def _make_union (cls , expr , context ):
607
+ @staticmethod
608
+ def _make_set_op (cls , expr , context ):
622
609
# flatten unions so that we can codegen them all at once
623
- union_info = list (flatten_union (expr ))
610
+ set_op_info = list (flatten_set_op (expr ))
624
611
625
612
# since op is a union, we have at least 3 elements in union_info (left
626
613
# distinct right) and if there is more than a single union we have an
627
614
# additional two elements per union (distinct right) which means the
628
615
# total number of elements is at least 3 + (2 * number of unions - 1)
629
616
# and is therefore an odd number
630
- npieces = len (union_info )
631
- assert npieces >= 3 and npieces % 2 != 0 , 'Invalid union expression'
617
+ npieces = len (set_op_info )
618
+ assert (
619
+ npieces >= 3 and npieces % 2 != 0
620
+ ), 'Invalid set operation expression'
632
621
633
622
# 1. every other object starting from 0 is a Table instance
634
623
# 2. every other object starting from 1 is a bool indicating the type
635
- # of union (distinct or not distinct)
636
- table_exprs , distincts = union_info [::2 ], union_info [1 ::2 ]
637
- return cls .union_class (
638
- table_exprs , expr , distincts = distincts , context = context
639
- )
624
+ # of $set_op (distinct or not distinct)
625
+ table_exprs , distincts = set_op_info [::2 ], set_op_info [1 ::2 ]
626
+ return cls (table_exprs , expr , distincts = distincts , context = context )
627
+
628
+ @classmethod
629
+ def _make_union (cls , expr , context ):
630
+ return cls ._make_set_op (cls .union_class , expr , context )
640
631
641
632
@classmethod
642
633
def _make_intersect (cls , expr , context ):
643
634
# flatten intersections so that we can codegen them all at once
644
- table_exprs = list (flatten (expr ))
645
- return cls .intersect_class (table_exprs , expr , context = context )
635
+ return cls ._make_set_op (cls .intersect_class , expr , context )
646
636
647
637
@classmethod
648
638
def _make_difference (cls , expr , context ):
649
639
# flatten differences so that we can codegen them all at once
650
- table_exprs = list (flatten (expr ))
651
- return cls .difference_class (table_exprs , expr , context = context )
640
+ return cls ._make_set_op (cls .difference_class , expr , context )
0 commit comments