From decb10441c7ea11981048221cf3636bf3c5005a3 Mon Sep 17 00:00:00 2001 From: John Ed Quinn Date: Tue, 16 Jul 2024 12:16:22 -0700 Subject: [PATCH 1/4] Fixes majority of DISTINCT conformance tests --- .../eval/internal/operator/rel/RelDistinct.kt | 9 ++- .../internal/transforms/RexConverter.kt | 67 +++++++++++++++++++ 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelDistinct.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelDistinct.kt index a1debc2e54..7ca94f9487 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelDistinct.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelDistinct.kt @@ -3,16 +3,19 @@ package org.partiql.eval.internal.operator.rel import org.partiql.eval.internal.Environment import org.partiql.eval.internal.Record import org.partiql.eval.internal.operator.Operator +import org.partiql.value.ListValue import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.listValue +import java.util.TreeSet internal class RelDistinct( val input: Operator.Relation ) : RelPeeking() { - // TODO: Add hashcode/equals support for PQLValue. Then we can use Record directly. + // TODO: Add hashcode/equals support for Datum. Then we can use Record directly. @OptIn(PartiQLValueExperimental::class) - private val seen = mutableSetOf>() + private val seen = TreeSet>(PartiQLValue.comparator()) override fun openPeeking(env: Environment) { input.open(env) @@ -21,7 +24,7 @@ internal class RelDistinct( @OptIn(PartiQLValueExperimental::class) override fun peek(): Record? { for (next in input) { - val transformed = List(next.values.size) { next.values[it].toPartiQLValue() } + val transformed = listValue(List(next.values.size) { next.values[it].toPartiQLValue() }) if (seen.contains(transformed).not()) { seen.add(transformed) return next diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt index 09f9a53416..99ed742f85 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt @@ -466,10 +466,77 @@ internal object RexConverter { // Args val args = node.args.map { visitExprCoerce(it, context) } // Rex + if (node.setq != null) { + if (isCollAgg(node)) { + return callToCollAgg(id, node.setq!!, args) + } else { + error("Currently, only COLL_ may use set quantifiers.") + } + } val op = rexOpCallUnresolved(id, args) return rex(type, op) } + /** + * @return whether call is `COLL_`. + */ + private fun isCollAgg(node: Expr.Call): Boolean { + val id = node.function as? org.partiql.ast.Identifier.Symbol ?: return false + return id.symbol.lowercase().startsWith("coll_") + } + + /** + * Converts inputs to `COLL_` when DISTINCT is used. + * + * Converts AST `COLL_MAX(DISTINCT x)` to PLAN: + * ``` + * Call (COLL_MAX(Var(0))) + * - Select (Var(0)) + * - Distinct (Var(0)) + * - Scan (x) + * ``` + * + * For the case where there is no set quantifier (or, if ALL is specified), i.e. `COLL_MAX(x)` or + * `COLL_MAX(ALL x)`, the plan is equivalent to `COLL_MAX(x)`. + */ + private fun callToCollAgg(id: Identifier, setQuantifier: SetQuantifier, args: List): Rex { + if (args.size != 1) { + error("Aggregate calls currently only support single arguments. Received ${args.size} arguments.") + } + if (setQuantifier == SetQuantifier.ALL) { + return Rex(ANY, Rex.Op.Call.Unresolved(id, args)) + } + val input = Rel( + type = Rel.Type( + schema = listOf(Rel.Binding(name = "_input", type = ANY)), + props = emptySet() + ), + op = Rel.Op.Scan(rex = args[0]) + ) + val distinct = Rel( + type = Rel.Type( + schema = listOf(Rel.Binding(name = "_input", type = BOOL)), + props = emptySet() + ), + op = Rel.Op.Distinct(input) + ) + val rex = Rex( + type = ANY, + op = Rex.Op.Select( + constructor = Rex( + type = PType.typeDynamic().toCType(), + op = Rex.Op.Var.Unresolved( + identifier = Identifier.Symbol("_input", Identifier.CaseSensitivity.SENSITIVE), + scope = Rex.Op.Var.Scope.LOCAL + ) + ), + rel = distinct + ) + ) + val op = Rex.Op.Call.Unresolved(id, listOf(rex)) + return Rex(ANY, op) + } + private fun visitExprCallTupleUnion(node: Expr.Call, context: Env): Rex { val type = (STRUCT) val args = node.args.map { visitExprCoerce(it, context) }.toMutableList() From 4730205d7b5984bef918597bd552a2ed59519bcf Mon Sep 17 00:00:00 2001 From: John Ed Quinn Date: Thu, 25 Jul 2024 11:25:10 -0700 Subject: [PATCH 2/4] Fixes remaining DISTINCT tests --- .../internal/PlanningProblemDetails.kt | 2 +- .../planner/internal/typer/CompilerType.kt | 1 + .../planner/internal/typer/PlanTyper.kt | 15 ++-- .../partiql/value/util/NumberExtensions.kt | 89 +++++++++++++++---- 4 files changed, 83 insertions(+), 24 deletions(-) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlanningProblemDetails.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlanningProblemDetails.kt index ad2bf2f775..eed98169e9 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlanningProblemDetails.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlanningProblemDetails.kt @@ -188,7 +188,7 @@ internal open class PlanningProblemDetails( data class ExpressionAlwaysReturnsMissing(val reason: String? = null) : PlanningProblemDetails( severity = ProblemSeverity.ERROR, - messageFormatter = { "Expression always returns null or missing: caused by $reason" } + messageFormatter = { "Expression always returns missing: caused by $reason" } ) data class InvalidArgumentTypeForFunction( diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/CompilerType.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/CompilerType.kt index ce233a33a9..da5013d056 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/CompilerType.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/CompilerType.kt @@ -19,6 +19,7 @@ internal class CompilerType( // Note: This is an experimental property. internal val isMissingValue: Boolean = false ) : PType { + public fun getDelegate(): PType = _delegate override fun getKind(): Kind = _delegate.kind override fun getFields(): MutableCollection { return _delegate.fields.map { field -> diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index 1f371ca337..a2ac90a106 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -101,9 +101,9 @@ internal class PlanTyper(private val env: Env) { * * TODO: Can this be merged with [anyOf]? Should we even allow this? */ - fun anyOfLiterals(types: Collection): PType? { + fun anyOfLiterals(types: Collection): PType? { // Grab unique - var unique: Collection = types.toSet() + var unique: Collection = types.map { it.getDelegate() }.toSet() if (unique.size == 0) { return null } else if (unique.size == 1) { @@ -133,7 +133,7 @@ internal class PlanTyper(private val env: Env) { } private fun collapseCollection(collections: Iterable, type: Kind): PType { - val typeParam = anyOfLiterals(collections.map { it.typeParameter })!! + val typeParam = anyOfLiterals(collections.map { it.typeParameter.toCType() })!! return when (type) { Kind.LIST -> PType.typeList(typeParam) Kind.BAG -> PType.typeList(typeParam) @@ -145,13 +145,13 @@ internal class PlanTyper(private val env: Env) { private fun collapseRows(rows: Iterable): PType { val firstFields = rows.first().fields!! val fieldNames = firstFields.map { it.name } - val fieldTypes = firstFields.map { mutableListOf(it.type) } + val fieldTypes = firstFields.map { mutableListOf(it.type.toCType()) } rows.map { struct -> val fields = struct.fields!! if (fields.map { it.name } != fieldNames) { return PType.typeStruct() } - fields.forEachIndexed { index, field -> fieldTypes[index].add(field.type) } + fields.forEachIndexed { index, field -> fieldTypes[index].add(field.type.toCType()) } } val newFields = fieldTypes.mapIndexed { i, types -> Field.of(fieldNames[i], anyOfLiterals(types)!!) } return PType.typeRow(newFields) @@ -162,7 +162,10 @@ internal class PlanTyper(private val env: Env) { return anyOf(unique) } - fun PType.toCType(): CompilerType = CompilerType(this) + fun PType.toCType(): CompilerType = when (this) { + is CompilerType -> this + else -> CompilerType(this) + } fun List.toCType(): List = this.map { it.toCType() } diff --git a/partiql-spi/src/main/kotlin/org/partiql/value/util/NumberExtensions.kt b/partiql-spi/src/main/kotlin/org/partiql/value/util/NumberExtensions.kt index 68ec58678d..676b852843 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/value/util/NumberExtensions.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/value/util/NumberExtensions.kt @@ -45,46 +45,95 @@ internal fun bigDecimalOf(num: Number, mc: MathContext = MATH_CONTEXT): BigDecim else -> throw IllegalArgumentException("Unsupported number type: $num, ${num.javaClass}") } +/** + * This should handle Byte, Short, Int, Long, BigInteger, Float, Double, BigDecimal + */ private val CONVERSION_MAP = mapOf>, Class>( + // BYTE + setOf(Byte::class.javaObjectType, Byte::class.javaObjectType) to Byte::class.javaObjectType, + setOf(Byte::class.javaObjectType, Short::class.javaObjectType) to Short::class.javaObjectType, + setOf(Byte::class.javaObjectType, Int::class.javaObjectType) to Int::class.javaObjectType, + setOf(Byte::class.javaObjectType, Long::class.javaObjectType) to Long::class.javaObjectType, + setOf(Byte::class.javaObjectType, BigInteger::class.javaObjectType) to BigInteger::class.javaObjectType, + setOf(Byte::class.javaObjectType, Float::class.javaObjectType) to Float::class.javaObjectType, + setOf(Byte::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, + setOf(Byte::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, + // SHORT + setOf(Short::class.javaObjectType, Byte::class.javaObjectType) to Short::class.javaObjectType, + setOf(Short::class.javaObjectType, Short::class.javaObjectType) to Short::class.javaObjectType, + setOf(Short::class.javaObjectType, Int::class.javaObjectType) to Int::class.javaObjectType, + setOf(Short::class.javaObjectType, Long::class.javaObjectType) to Long::class.javaObjectType, + setOf(Short::class.javaObjectType, BigInteger::class.javaObjectType) to BigInteger::class.javaObjectType, + setOf(Short::class.javaObjectType, Float::class.javaObjectType) to Float::class.javaObjectType, + setOf(Short::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, + setOf(Short::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, + // INT + setOf(Int::class.javaObjectType, Byte::class.javaObjectType) to Int::class.javaObjectType, + setOf(Int::class.javaObjectType, Short::class.javaObjectType) to Int::class.javaObjectType, setOf(Int::class.javaObjectType, Int::class.javaObjectType) to Int::class.javaObjectType, setOf(Int::class.javaObjectType, Long::class.javaObjectType) to Long::class.javaObjectType, setOf(Int::class.javaObjectType, BigInteger::class.javaObjectType) to BigInteger::class.javaObjectType, - setOf(Long::class.javaObjectType, BigInteger::class.javaObjectType) to BigInteger::class.javaObjectType, - setOf(BigInteger::class.javaObjectType, BigInteger::class.javaObjectType) to BigInteger::class.javaObjectType, - // Int w/ Float -> Double setOf(Int::class.javaObjectType, Float::class.javaObjectType) to Double::class.javaObjectType, setOf(Int::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, setOf(Int::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, - - setOf(Float::class.javaObjectType, Float::class.javaObjectType) to Float::class.javaObjectType, - // Float w/ Long -> Double - setOf(Float::class.javaObjectType, Long::class.javaObjectType) to Double::class.javaObjectType, - setOf(Float::class.javaObjectType, BigInteger::class.javaObjectType) to Double::class.javaObjectType, - setOf(Float::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, - setOf(Float::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, - + // LONG + setOf(Long::class.javaObjectType, Byte::class.javaObjectType) to Long::class.javaObjectType, + setOf(Long::class.javaObjectType, Short::class.javaObjectType) to Long::class.javaObjectType, + setOf(Long::class.javaObjectType, Int::class.javaObjectType) to Long::class.javaObjectType, setOf(Long::class.javaObjectType, Long::class.javaObjectType) to Long::class.javaObjectType, setOf(Long::class.javaObjectType, BigInteger::class.javaObjectType) to BigInteger::class.javaObjectType, + setOf(Long::class.javaObjectType, Float::class.javaObjectType) to Double::class.javaObjectType, setOf(Long::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, setOf(Long::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, - + // FLOAT + setOf(Float::class.javaObjectType, Byte::class.javaObjectType) to Float::class.javaObjectType, + setOf(Float::class.javaObjectType, Short::class.javaObjectType) to Float::class.javaObjectType, + setOf(Float::class.javaObjectType, Int::class.javaObjectType) to Float::class.javaObjectType, + setOf(Float::class.javaObjectType, Long::class.javaObjectType) to Double::class.javaObjectType, + setOf(Float::class.javaObjectType, BigInteger::class.javaObjectType) to Double::class.javaObjectType, + setOf(Float::class.javaObjectType, Float::class.javaObjectType) to Float::class.javaObjectType, + setOf(Float::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, + setOf(Float::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, + // DOUBLE + setOf(Double::class.javaObjectType, Byte::class.javaObjectType) to Double::class.javaObjectType, + setOf(Double::class.javaObjectType, Short::class.javaObjectType) to Double::class.javaObjectType, + setOf(Double::class.javaObjectType, Int::class.javaObjectType) to Double::class.javaObjectType, + setOf(Double::class.javaObjectType, Long::class.javaObjectType) to Double::class.javaObjectType, + setOf(Double::class.javaObjectType, BigInteger::class.javaObjectType) to Double::class.javaObjectType, + setOf(Double::class.javaObjectType, Float::class.javaObjectType) to Double::class.javaObjectType, + setOf(Double::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, + setOf(Double::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, + // BIG INTEGER + setOf(BigInteger::class.javaObjectType, Byte::class.javaObjectType) to BigInteger::class.javaObjectType, + setOf(BigInteger::class.javaObjectType, Short::class.javaObjectType) to BigInteger::class.javaObjectType, + setOf(BigInteger::class.javaObjectType, Int::class.javaObjectType) to BigInteger::class.javaObjectType, + setOf(BigInteger::class.javaObjectType, Long::class.javaObjectType) to BigInteger::class.javaObjectType, setOf(BigInteger::class.javaObjectType, BigInteger::class.javaObjectType) to BigInteger::class.javaObjectType, + setOf(BigInteger::class.javaObjectType, Float::class.javaObjectType) to Double::class.javaObjectType, setOf(BigInteger::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, setOf(BigInteger::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, - - setOf(Double::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, - setOf(Double::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, - + // BIG DECIMAL + setOf(BigDecimal::class.javaObjectType, Byte::class.javaObjectType) to BigDecimal::class.javaObjectType, + setOf(BigDecimal::class.javaObjectType, Short::class.javaObjectType) to BigDecimal::class.javaObjectType, + setOf(BigDecimal::class.javaObjectType, Int::class.javaObjectType) to BigDecimal::class.javaObjectType, + setOf(BigDecimal::class.javaObjectType, Long::class.javaObjectType) to BigDecimal::class.javaObjectType, + setOf(BigDecimal::class.javaObjectType, BigInteger::class.javaObjectType) to BigDecimal::class.javaObjectType, + setOf(BigDecimal::class.javaObjectType, Float::class.javaObjectType) to BigDecimal::class.javaObjectType, + setOf(BigDecimal::class.javaObjectType, Double::class.javaObjectType) to BigDecimal::class.javaObjectType, setOf(BigDecimal::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, ) private val CONVERTERS = mapOf, (Number) -> Number>( + Byte::class.javaObjectType to Number::toByte, + Short::class.javaObjectType to Number::toShort, Int::class.javaObjectType to Number::toInt, Long::class.javaObjectType to Number::toLong, Float::class.javaObjectType to Number::toFloat, Double::class.javaObjectType to Number::toDouble, BigInteger::class.javaObjectType to { num -> when (num) { + is Byte -> num.toInt().toBigInteger() + is Short -> num.toInt().toBigInteger() is Int -> num.toBigInteger() is Long -> num.toBigInteger() is BigInteger -> num @@ -95,6 +144,8 @@ private val CONVERTERS = mapOf, (Number) -> Number>( }, BigDecimal::class.java to { num -> when (num) { + is Byte -> bigDecimalOf(num) + is Short -> bigDecimalOf(num) is Int -> bigDecimalOf(num) is Long -> bigDecimalOf(num) is Float -> bigDecimalOf(num) @@ -109,13 +160,15 @@ private val CONVERTERS = mapOf, (Number) -> Number>( ) internal fun Number.isZero() = when (this) { + is Byte -> this.toInt() == 0 + is Short -> this.toInt() == 0 is Int -> this == 0 is Long -> this == 0L is Float -> this == 0.0f || this == -0.0f is Double -> this == 0.0 || this == -0.0 is BigDecimal -> this.signum() == 0 is BigInteger -> this.signum() == 0 - else -> throw IllegalStateException("$this") + else -> throw IllegalStateException("$this (${this.javaClass.simpleName})") } @Suppress("UNCHECKED_CAST") @@ -148,6 +201,8 @@ public fun coerceNumbers(first: Number, second: Number): Pair { internal operator fun Number.compareTo(other: Number): Int { val (first, second) = coerceNumbers(this, other) return when (first) { + is Byte -> first.compareTo(second as Byte) + is Short -> first.compareTo(second as Short) is Int -> first.compareTo(second as Int) is Long -> first.compareTo(second as Long) is Float -> first.compareTo(second as Float) From 8360e0ed5764c15b6b5ea4b6242aa9ff33018511 Mon Sep 17 00:00:00 2001 From: John Ed Quinn Date: Tue, 6 Aug 2024 13:17:00 -0700 Subject: [PATCH 3/4] Addresses PR feedback --- .../internal/transforms/RexConverter.kt | 67 ++++++++--------- .../kotlin/org/partiql/spi/fn/SqlBuiltins.kt | 24 ++++--- .../org/partiql/spi/fn/builtins/FnCollAgg.kt | 71 ++++++++++--------- .../builtins/internal/AccumulatorDistinct.kt | 28 ++++++++ 4 files changed, 109 insertions(+), 81 deletions(-) create mode 100644 partiql-spi/src/main/kotlin/org/partiql/spi/fn/builtins/internal/AccumulatorDistinct.kt diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt index 99ed742f85..214f78555f 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt @@ -79,6 +79,17 @@ internal object RexConverter { @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") private object ToRex : AstBaseVisitor() { + private val COLL_AGG_NAMES = setOf( + "coll_any", + "coll_avg", + "coll_count", + "coll_every", + "coll_max", + "coll_min", + "coll_some", + "coll_sum", + ) + override fun defaultReturn(node: AstNode, context: Env): Rex = throw IllegalArgumentException("unsupported rex $node") @@ -465,13 +476,14 @@ internal object RexConverter { } // Args val args = node.args.map { visitExprCoerce(it, context) } - // Rex + + // Check if function is actually coll_ + if (isCollAgg(node)) { + return callToCollAgg(id, node.setq, args) + } + if (node.setq != null) { - if (isCollAgg(node)) { - return callToCollAgg(id, node.setq!!, args) - } else { - error("Currently, only COLL_ may use set quantifiers.") - } + error("Currently, only COLL_ may use set quantifiers.") } val op = rexOpCallUnresolved(id, args) return rex(type, op) @@ -482,7 +494,7 @@ internal object RexConverter { */ private fun isCollAgg(node: Expr.Call): Boolean { val id = node.function as? org.partiql.ast.Identifier.Symbol ?: return false - return id.symbol.lowercase().startsWith("coll_") + return COLL_AGG_NAMES.contains(id.symbol.lowercase()) } /** @@ -499,41 +511,20 @@ internal object RexConverter { * For the case where there is no set quantifier (or, if ALL is specified), i.e. `COLL_MAX(x)` or * `COLL_MAX(ALL x)`, the plan is equivalent to `COLL_MAX(x)`. */ - private fun callToCollAgg(id: Identifier, setQuantifier: SetQuantifier, args: List): Rex { + private fun callToCollAgg(id: Identifier, setQuantifier: SetQuantifier?, args: List): Rex { + if (id.hasQualifier()) { + error("Qualified function calls are not currently supported.") + } if (args.size != 1) { error("Aggregate calls currently only support single arguments. Received ${args.size} arguments.") } - if (setQuantifier == SetQuantifier.ALL) { - return Rex(ANY, Rex.Op.Call.Unresolved(id, args)) + val postfix = when (setQuantifier) { + SetQuantifier.DISTINCT -> "_distinct" + SetQuantifier.ALL -> "_all" + null -> "_all" } - val input = Rel( - type = Rel.Type( - schema = listOf(Rel.Binding(name = "_input", type = ANY)), - props = emptySet() - ), - op = Rel.Op.Scan(rex = args[0]) - ) - val distinct = Rel( - type = Rel.Type( - schema = listOf(Rel.Binding(name = "_input", type = BOOL)), - props = emptySet() - ), - op = Rel.Op.Distinct(input) - ) - val rex = Rex( - type = ANY, - op = Rex.Op.Select( - constructor = Rex( - type = PType.typeDynamic().toCType(), - op = Rex.Op.Var.Unresolved( - identifier = Identifier.Symbol("_input", Identifier.CaseSensitivity.SENSITIVE), - scope = Rex.Op.Var.Scope.LOCAL - ) - ), - rel = distinct - ) - ) - val op = Rex.Op.Call.Unresolved(id, listOf(rex)) + val newId = Identifier.regular(id.getIdentifier().getText() + postfix) + val op = Rex.Op.Call.Unresolved(newId, listOf(args[0])) return Rex(ANY, op) } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/SqlBuiltins.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/SqlBuiltins.kt index 8f88966bab..9bbbb11929 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/SqlBuiltins.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/SqlBuiltins.kt @@ -48,14 +48,22 @@ internal object SqlBuiltins { Fn_CHAR_LENGTH__STRING__INT, Fn_CHAR_LENGTH__CLOB__INT, Fn_CHAR_LENGTH__SYMBOL__INT, - Fn_COLL_AGG__BAG__ANY.ANY, - Fn_COLL_AGG__BAG__ANY.AVG, - Fn_COLL_AGG__BAG__ANY.COUNT, - Fn_COLL_AGG__BAG__ANY.EVERY, - Fn_COLL_AGG__BAG__ANY.MAX, - Fn_COLL_AGG__BAG__ANY.MIN, - Fn_COLL_AGG__BAG__ANY.SOME, - Fn_COLL_AGG__BAG__ANY.SUM, + Fn_COLL_AGG__BAG__ANY.ANY_ALL, + Fn_COLL_AGG__BAG__ANY.AVG_ALL, + Fn_COLL_AGG__BAG__ANY.COUNT_ALL, + Fn_COLL_AGG__BAG__ANY.EVERY_ALL, + Fn_COLL_AGG__BAG__ANY.MAX_ALL, + Fn_COLL_AGG__BAG__ANY.MIN_ALL, + Fn_COLL_AGG__BAG__ANY.SOME_ALL, + Fn_COLL_AGG__BAG__ANY.SUM_ALL, + Fn_COLL_AGG__BAG__ANY.ANY_DISTINCT, + Fn_COLL_AGG__BAG__ANY.AVG_DISTINCT, + Fn_COLL_AGG__BAG__ANY.COUNT_DISTINCT, + Fn_COLL_AGG__BAG__ANY.EVERY_DISTINCT, + Fn_COLL_AGG__BAG__ANY.MAX_DISTINCT, + Fn_COLL_AGG__BAG__ANY.MIN_DISTINCT, + Fn_COLL_AGG__BAG__ANY.SOME_DISTINCT, + Fn_COLL_AGG__BAG__ANY.SUM_DISTINCT, Fn_CONCAT__STRING_STRING__STRING, Fn_CONCAT__CLOB_CLOB__CLOB, Fn_CONCAT__SYMBOL_SYMBOL__SYMBOL, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/builtins/FnCollAgg.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/builtins/FnCollAgg.kt index a5930241ce..7af5c4d22b 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/builtins/FnCollAgg.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/builtins/FnCollAgg.kt @@ -3,7 +3,6 @@ package org.partiql.spi.fn.builtins -import org.partiql.spi.fn.Agg import org.partiql.spi.fn.Fn import org.partiql.spi.fn.FnParameter import org.partiql.spi.fn.FnSignature @@ -11,6 +10,7 @@ import org.partiql.spi.fn.builtins.internal.Accumulator import org.partiql.spi.fn.builtins.internal.AccumulatorAnySome import org.partiql.spi.fn.builtins.internal.AccumulatorAvg import org.partiql.spi.fn.builtins.internal.AccumulatorCount +import org.partiql.spi.fn.builtins.internal.AccumulatorDistinct import org.partiql.spi.fn.builtins.internal.AccumulatorEvery import org.partiql.spi.fn.builtins.internal.AccumulatorMax import org.partiql.spi.fn.builtins.internal.AccumulatorMin @@ -22,9 +22,18 @@ import org.partiql.value.PartiQLValueType import org.partiql.value.check @OptIn(PartiQLValueExperimental::class) -internal abstract class Fn_COLL_AGG__BAG__ANY : Fn { +internal abstract class Fn_COLL_AGG__BAG__ANY( + name: String, + private val isDistinct: Boolean, + private val accumulator: () -> Accumulator, +) : Fn { - abstract fun getAccumulator(): Agg.Accumulator + private fun getAccumulator(): Accumulator = when (isDistinct) { + true -> AccumulatorDistinct(accumulator.invoke()) + false -> accumulator.invoke() + } + + override val signature: FnSignature = createSignature(name) companion object { @JvmStatic @@ -46,43 +55,35 @@ internal abstract class Fn_COLL_AGG__BAG__ANY : Fn { return accumulator.value() } - object SUM : Fn_COLL_AGG__BAG__ANY() { - override val signature = createSignature("coll_sum") - override fun getAccumulator(): Accumulator = AccumulatorSum() - } + object SUM_ALL : Fn_COLL_AGG__BAG__ANY("coll_sum_all", false, ::AccumulatorSum) - object AVG : Fn_COLL_AGG__BAG__ANY() { - override val signature = createSignature("coll_avg") - override fun getAccumulator(): Accumulator = AccumulatorAvg() - } + object SUM_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_sum_distinct", true, ::AccumulatorSum) - object MIN : Fn_COLL_AGG__BAG__ANY() { - override val signature = createSignature("coll_min") - override fun getAccumulator(): Accumulator = AccumulatorMin() - } + object AVG_ALL : Fn_COLL_AGG__BAG__ANY("coll_avg_all", false, ::AccumulatorAvg) - object MAX : Fn_COLL_AGG__BAG__ANY() { - override val signature = createSignature("coll_max") - override fun getAccumulator(): Accumulator = AccumulatorMax() - } + object AVG_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_avg_distinct", true, ::AccumulatorAvg) - object COUNT : Fn_COLL_AGG__BAG__ANY() { - override val signature = createSignature("coll_count") - override fun getAccumulator(): Accumulator = AccumulatorCount() - } + object MIN_ALL : Fn_COLL_AGG__BAG__ANY("coll_min_all", false, ::AccumulatorMin) - object EVERY : Fn_COLL_AGG__BAG__ANY() { - override val signature = createSignature("coll_every") - override fun getAccumulator(): Accumulator = AccumulatorEvery() - } + object MIN_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_min_distinct", true, ::AccumulatorMin) - object ANY : Fn_COLL_AGG__BAG__ANY() { - override val signature = createSignature("coll_any") - override fun getAccumulator(): Accumulator = AccumulatorAnySome() - } + object MAX_ALL : Fn_COLL_AGG__BAG__ANY("coll_max_all", false, ::AccumulatorMax) - object SOME : Fn_COLL_AGG__BAG__ANY() { - override val signature = createSignature("coll_some") - override fun getAccumulator(): Accumulator = AccumulatorAnySome() - } + object MAX_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_max_distinct", true, ::AccumulatorMax) + + object COUNT_ALL : Fn_COLL_AGG__BAG__ANY("coll_count_all", false, ::AccumulatorCount) + + object COUNT_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_count_distinct", true, ::AccumulatorCount) + + object EVERY_ALL : Fn_COLL_AGG__BAG__ANY("coll_every_all", false, ::AccumulatorEvery) + + object EVERY_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_every_distinct", true, ::AccumulatorEvery) + + object ANY_ALL : Fn_COLL_AGG__BAG__ANY("coll_any_all", false, ::AccumulatorAnySome) + + object ANY_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_any_distinct", true, ::AccumulatorAnySome) + + object SOME_ALL : Fn_COLL_AGG__BAG__ANY("coll_some_all", false, ::AccumulatorAnySome) + + object SOME_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_some_distinct", true, ::AccumulatorAnySome) } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/builtins/internal/AccumulatorDistinct.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/builtins/internal/AccumulatorDistinct.kt new file mode 100644 index 0000000000..52858dd83a --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/builtins/internal/AccumulatorDistinct.kt @@ -0,0 +1,28 @@ +package org.partiql.spi.fn.builtins.internal + +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import java.util.TreeSet + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorDistinct( + private val _delegate: Accumulator, +) : Accumulator() { + + // TODO: Add support for a datum comparator once the accumulator passes datums instead of PartiQL values. + @OptIn(PartiQLValueExperimental::class) + private val seen = TreeSet(PartiQLValue.comparator()) + + @OptIn(PartiQLValueExperimental::class) + override fun nextValue(value: PartiQLValue) { + if (!seen.contains(value)) { + seen.add(value) + _delegate.nextValue(value) + } + } + + @OptIn(PartiQLValueExperimental::class) + override fun value(): PartiQLValue { + return _delegate.value() + } +} From cda355737a27d69d4b31e60d590b91c56b1717d8 Mon Sep 17 00:00:00 2001 From: John Ed Quinn Date: Wed, 7 Aug 2024 11:23:21 -0700 Subject: [PATCH 4/4] Updates KDoc and visibility modifier --- .../planner/internal/transforms/RexConverter.kt | 16 +++++----------- .../planner/internal/typer/CompilerType.kt | 2 +- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt index 214f78555f..33f2d1cae5 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt @@ -498,18 +498,12 @@ internal object RexConverter { } /** - * Converts inputs to `COLL_` when DISTINCT is used. + * Converts COLL_ to the relevant function calls. For example: + * - `COLL_SUM(x)` becomes `coll_sum_all(x)` + * - `COLL_SUM(ALL x)` becomes `coll_sum_all(x)` + * - `COLL_SUM(DISTINCT x)` becomes `coll_sum_distinct(x)` * - * Converts AST `COLL_MAX(DISTINCT x)` to PLAN: - * ``` - * Call (COLL_MAX(Var(0))) - * - Select (Var(0)) - * - Distinct (Var(0)) - * - Scan (x) - * ``` - * - * For the case where there is no set quantifier (or, if ALL is specified), i.e. `COLL_MAX(x)` or - * `COLL_MAX(ALL x)`, the plan is equivalent to `COLL_MAX(x)`. + * It is assumed that the [id] has already been vetted by [isCollAgg]. */ private fun callToCollAgg(id: Identifier, setQuantifier: SetQuantifier?, args: List): Rex { if (id.hasQualifier()) { diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/CompilerType.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/CompilerType.kt index da5013d056..44e8eafa98 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/CompilerType.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/CompilerType.kt @@ -19,7 +19,7 @@ internal class CompilerType( // Note: This is an experimental property. internal val isMissingValue: Boolean = false ) : PType { - public fun getDelegate(): PType = _delegate + fun getDelegate(): PType = _delegate override fun getKind(): Kind = _delegate.kind override fun getFields(): MutableCollection { return _delegate.fields.map { field ->