Skip to content

Commit 6e8d880

Browse files
committed
Adds support for aggregations (GROUP BY)
Adds support for COLL_AGGs
1 parent 2f7523b commit 6e8d880

File tree

67 files changed

+1701
-448
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+1701
-448
lines changed

partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeGroupBy.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
package org.partiql.ast.normalize
1616

17+
import org.partiql.ast.AstNode
1718
import org.partiql.ast.Expr
1819
import org.partiql.ast.GroupBy
1920
import org.partiql.ast.Statement
@@ -30,6 +31,13 @@ object NormalizeGroupBy : AstPass {
3031

3132
private object Visitor : AstRewriter<Int>() {
3233

34+
override fun visitGroupBy(node: GroupBy, ctx: Int): AstNode {
35+
val keys = node.keys.mapIndexed { index, key ->
36+
visitGroupByKey(key, index + 1)
37+
}
38+
return node.copy(keys = keys)
39+
}
40+
3341
override fun visitGroupByKey(node: GroupBy.Key, ctx: Int): GroupBy.Key {
3442
val expr = visitExpr(node.expr, 0) as Expr
3543
val alias = when (node.asAlias) {

partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package org.partiql.eval.internal
22

33
import org.partiql.eval.PartiQLEngine
44
import org.partiql.eval.internal.operator.Operator
5+
import org.partiql.eval.internal.operator.rel.RelAggregate
56
import org.partiql.eval.internal.operator.rel.RelDistinct
67
import org.partiql.eval.internal.operator.rel.RelExclude
78
import org.partiql.eval.internal.operator.rel.RelFilter
@@ -44,6 +45,7 @@ import org.partiql.plan.Rex
4445
import org.partiql.plan.Statement
4546
import org.partiql.plan.debug.PlanPrinter
4647
import org.partiql.plan.visitor.PlanBaseVisitor
48+
import org.partiql.spi.fn.Agg
4749
import org.partiql.spi.fn.FnExperimental
4850
import org.partiql.types.StaticType
4951
import org.partiql.value.PartiQLValueExperimental
@@ -141,7 +143,33 @@ internal class Compiler(
141143
return ExprVarLocal(node.ref)
142144
}
143145

144-
override fun visitRexOpVarGlobal(node: Rex.Op.Var.Global, ctx: StaticType?): Operator = symbols.getGlobal(node.ref)
146+
override fun visitRexOpVarGlobal(node: Rex.Op.Var.Global, ctx: StaticType?): Operator {
147+
return symbols.getGlobal(node.ref)
148+
}
149+
150+
override fun visitRelOpAggregate(node: Rel.Op.Aggregate, ctx: StaticType?): Operator.Relation {
151+
val input = visitRel(node.input, ctx)
152+
val calls = node.calls.map {
153+
visitRelOpAggregateCall(it, ctx)
154+
}
155+
val groups = node.groups.map { visitRex(it, ctx).modeHandled() }
156+
return RelAggregate(input, groups, calls)
157+
}
158+
159+
@OptIn(FnExperimental::class)
160+
override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: StaticType?): Operator.Accumulator {
161+
val args = node.args.map { visitRex(it, it.type).modeHandled() } // TODO: Should we support multiple arguments?
162+
val setQuantifier: Operator.Accumulator.SetQuantifier = when (node.setQuantifier) {
163+
Rel.Op.Aggregate.Call.SetQuantifier.ALL -> Operator.Accumulator.SetQuantifier.ALL
164+
Rel.Op.Aggregate.Call.SetQuantifier.DISTINCT -> Operator.Accumulator.SetQuantifier.DISTINCT
165+
}
166+
val agg = symbols.getAgg(node.agg)
167+
return object : Operator.Accumulator {
168+
override val delegate: Agg = agg
169+
override val args: List<Operator.Expr> = args
170+
override val setQuantifier: Operator.Accumulator.SetQuantifier = setQuantifier
171+
}
172+
}
145173

146174
override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: StaticType?): Operator {
147175
val root = visitRex(node.root, ctx)
@@ -179,7 +207,7 @@ internal class Compiler(
179207
val args = node.args.map { visitRex(it, ctx).modeHandled() }.toTypedArray()
180208
val candidates = node.candidates.map { candidate ->
181209
val fn = symbols.getFn(candidate.fn)
182-
val types = fn.signature.parameters.map { it.type }.toTypedArray()
210+
val types = candidate.parameters.toTypedArray()
183211
val coercions = candidate.coercions.toTypedArray()
184212
ExprCallDynamic.Candidate(fn, types, coercions)
185213
}

partiql-eval/src/main/kotlin/org/partiql/eval/internal/Symbols.kt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ import org.partiql.eval.internal.operator.rex.ExprVarGlobal
77
import org.partiql.plan.Catalog
88
import org.partiql.plan.PartiQLPlan
99
import org.partiql.plan.Ref
10+
import org.partiql.spi.connector.ConnectorAggProvider
1011
import org.partiql.spi.connector.ConnectorBindings
1112
import org.partiql.spi.connector.ConnectorFnProvider
1213
import org.partiql.spi.connector.ConnectorPath
14+
import org.partiql.spi.fn.Agg
1315
import org.partiql.spi.fn.Fn
1416
import org.partiql.spi.fn.FnExperimental
1517

@@ -25,6 +27,7 @@ internal class Symbols private constructor(private val catalogs: Array<C>) {
2527
val name: String,
2628
val bindings: ConnectorBindings,
2729
val functions: ConnectorFnProvider,
30+
val aggregations: ConnectorAggProvider,
2831
val items: Array<Catalog.Item>,
2932
) {
3033

@@ -53,6 +56,18 @@ internal class Symbols private constructor(private val catalogs: Array<C>) {
5356
?: error("Catalog `$catalog` has no entry for function $item")
5457
}
5558

59+
fun getAgg(ref: Ref): Agg {
60+
val catalog = catalogs[ref.catalog]
61+
val item = catalog.items.getOrNull(ref.symbol)
62+
if (item == null || item !is Catalog.Item.Agg) {
63+
error("Invalid reference $ref; missing function entry for catalog `$catalog`.")
64+
}
65+
// Lookup in connector
66+
val path = ConnectorPath(item.path)
67+
return catalog.aggregations.getAgg(path, item.specific)
68+
?: error("Catalog `$catalog` has no entry for aggregation function $item")
69+
}
70+
5671
companion object {
5772

5873
/**
@@ -71,6 +86,7 @@ internal class Symbols private constructor(private val catalogs: Array<C>) {
7186
name = it.name,
7287
bindings = connector.getBindings(),
7388
functions = connector.getFunctions(),
89+
aggregations = connector.getAggregations(),
7490
items = it.items.toTypedArray()
7591
)
7692
}.toTypedArray()

partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/Operator.kt

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package org.partiql.eval.internal.operator
22

33
import org.partiql.eval.internal.Record
4+
import org.partiql.spi.fn.Agg
5+
import org.partiql.spi.fn.FnExperimental
46
import org.partiql.value.PartiQLValue
57
import org.partiql.value.PartiQLValueExperimental
68

@@ -26,4 +28,19 @@ internal sealed interface Operator {
2628

2729
override fun close()
2830
}
31+
32+
interface Accumulator : Operator {
33+
34+
@OptIn(FnExperimental::class)
35+
val delegate: Agg
36+
37+
val args: List<Expr>
38+
39+
val setQuantifier: SetQuantifier
40+
41+
enum class SetQuantifier {
42+
ALL,
43+
DISTINCT
44+
}
45+
}
2946
}
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package org.partiql.eval.internal.operator.rel
2+
3+
import org.partiql.eval.internal.Record
4+
import org.partiql.eval.internal.operator.Operator
5+
import org.partiql.spi.fn.Agg
6+
import org.partiql.spi.fn.FnExperimental
7+
import org.partiql.value.ListValue
8+
import org.partiql.value.PartiQLValue
9+
import org.partiql.value.PartiQLValueExperimental
10+
import org.partiql.value.PartiQLValueType
11+
import org.partiql.value.listValue
12+
import org.partiql.value.nullValue
13+
import java.util.TreeMap
14+
import java.util.TreeSet
15+
16+
internal class RelAggregate(
17+
val input: Operator.Relation,
18+
val keys: List<Operator.Expr>,
19+
val functions: List<Operator.Accumulator>
20+
) : Operator.Relation {
21+
22+
lateinit var records: Iterator<Record>
23+
24+
@OptIn(PartiQLValueExperimental::class)
25+
val aggregationMap = TreeMap<PartiQLValue, List<AccumulatorWrapper>>(PartiQLValue.comparator(nullsFirst = false))
26+
27+
@OptIn(PartiQLValueExperimental::class)
28+
object PartiQLValueListComparator : Comparator<List<PartiQLValue>> {
29+
private val delegate = PartiQLValue.comparator(nullsFirst = false)
30+
override fun compare(o1: List<PartiQLValue>, o2: List<PartiQLValue>): Int {
31+
if (o1.size < o2.size) {
32+
return -1
33+
}
34+
if (o1.size > o2.size) {
35+
return 1
36+
}
37+
for (index in 0..o2.lastIndex) {
38+
val element1 = o1[index]
39+
val element2 = o2[index]
40+
val compared = delegate.compare(element1, element2)
41+
if (compared != 0) {
42+
return compared
43+
}
44+
}
45+
return 0
46+
}
47+
}
48+
49+
/**
50+
* Wraps an [Operator.Accumulator.Instance] to help with filtering distinct values.
51+
*
52+
* @property seen maintains which values have already been seen. If null, we accumulate all values coming through.
53+
*/
54+
class AccumulatorWrapper @OptIn(PartiQLValueExperimental::class, FnExperimental::class) constructor(
55+
val delegate: Agg.Accumulator,
56+
val args: List<Operator.Expr>,
57+
val seen: TreeSet<List<PartiQLValue>>?
58+
)
59+
60+
@OptIn(PartiQLValueExperimental::class, FnExperimental::class)
61+
override fun open() {
62+
input.open()
63+
var inputRecord = input.next()
64+
while (inputRecord != null) {
65+
// Initialize the AggregationMap
66+
val evaluatedGroupByKeys = listValue(
67+
keys.map {
68+
val key = it.eval(inputRecord!!)
69+
when (key.type == PartiQLValueType.MISSING) {
70+
true -> nullValue()
71+
false -> key
72+
}
73+
}
74+
)
75+
val accumulators = aggregationMap.getOrPut(evaluatedGroupByKeys) {
76+
functions.map {
77+
AccumulatorWrapper(
78+
delegate = it.delegate.accumulator(),
79+
args = it.args,
80+
seen = when (it.setQuantifier) {
81+
Operator.Accumulator.SetQuantifier.DISTINCT -> TreeSet(PartiQLValueListComparator)
82+
Operator.Accumulator.SetQuantifier.ALL -> null
83+
}
84+
)
85+
}
86+
}
87+
88+
// Aggregate Values in Aggregation State
89+
accumulators.forEachIndexed { index, function ->
90+
val valueToAggregate = function.args.map { it.eval(inputRecord!!) }
91+
// Skip over aggregation if NULL/MISSING
92+
if (valueToAggregate.any { it.type == PartiQLValueType.MISSING || it.isNull }) {
93+
return@forEachIndexed
94+
}
95+
// Skip over aggregation if DISTINCT and SEEN
96+
if (function.seen != null && (function.seen.add(valueToAggregate).not())) {
97+
return@forEachIndexed
98+
}
99+
accumulators[index].delegate.next(valueToAggregate.toTypedArray())
100+
}
101+
inputRecord = input.next()
102+
}
103+
104+
// No Aggregations Created // TODO: How would this be possible?
105+
if (keys.isEmpty() && aggregationMap.isEmpty()) {
106+
val record = mutableListOf<PartiQLValue>()
107+
functions.forEach { function ->
108+
val accumulator = function.delegate.accumulator()
109+
record.add(accumulator.value())
110+
}
111+
records = iterator { yield(Record.of(*record.toTypedArray())) }
112+
return
113+
}
114+
115+
records = iterator {
116+
aggregationMap.forEach { (pValue, accumulators) ->
117+
val keysEvaluated = pValue as ListValue<*>
118+
val recordValues = accumulators.map { acc -> acc.delegate.value() } + keysEvaluated.map { value -> value }
119+
yield(Record.of(*recordValues.toTypedArray()))
120+
}
121+
}
122+
}
123+
124+
override fun next(): Record? {
125+
return if (records.hasNext()) {
126+
records.next()
127+
} else {
128+
null
129+
}
130+
}
131+
132+
@OptIn(PartiQLValueExperimental::class)
133+
override fun close() {
134+
aggregationMap.clear()
135+
input.close()
136+
}
137+
}

partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@ internal class ExprCallDynamic(
3131
return candidate.eval(actualArgs)
3232
}
3333
}
34-
throw TypeCheckException()
34+
val errorString = buildString {
35+
val argString = actualArgs.joinToString(", ")
36+
append("Could not dynamically find function for arguments $argString in $candidates.")
37+
}
38+
throw TypeCheckException(errorString)
3539
}
3640

3741
/**

partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprPathSymbol.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,6 @@ internal class ExprPathSymbol(
2525
return v
2626
}
2727
}
28-
throw TypeCheckException()
28+
throw TypeCheckException("Couldn't find symbol '$symbol' in $struct.")
2929
}
3030
}

0 commit comments

Comments
 (0)