diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelDistinct.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelDistinct.kt index a1debc2e54..7ca94f9487 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelDistinct.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelDistinct.kt @@ -3,16 +3,19 @@ package org.partiql.eval.internal.operator.rel import org.partiql.eval.internal.Environment import org.partiql.eval.internal.Record import org.partiql.eval.internal.operator.Operator +import org.partiql.value.ListValue import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.listValue +import java.util.TreeSet internal class RelDistinct( val input: Operator.Relation ) : RelPeeking() { - // TODO: Add hashcode/equals support for PQLValue. Then we can use Record directly. + // TODO: Add hashcode/equals support for Datum. Then we can use Record directly. @OptIn(PartiQLValueExperimental::class) - private val seen = mutableSetOf>() + private val seen = TreeSet>(PartiQLValue.comparator()) override fun openPeeking(env: Environment) { input.open(env) @@ -21,7 +24,7 @@ internal class RelDistinct( @OptIn(PartiQLValueExperimental::class) override fun peek(): Record? { for (next in input) { - val transformed = List(next.values.size) { next.values[it].toPartiQLValue() } + val transformed = listValue(List(next.values.size) { next.values[it].toPartiQLValue() }) if (seen.contains(transformed).not()) { seen.add(transformed) return next diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlanningProblemDetails.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlanningProblemDetails.kt index ad2bf2f775..eed98169e9 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlanningProblemDetails.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlanningProblemDetails.kt @@ -188,7 +188,7 @@ internal open class PlanningProblemDetails( data class ExpressionAlwaysReturnsMissing(val reason: String? = null) : PlanningProblemDetails( severity = ProblemSeverity.ERROR, - messageFormatter = { "Expression always returns null or missing: caused by $reason" } + messageFormatter = { "Expression always returns missing: caused by $reason" } ) data class InvalidArgumentTypeForFunction( diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt index 09f9a53416..33f2d1cae5 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt @@ -79,6 +79,17 @@ internal object RexConverter { @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") private object ToRex : AstBaseVisitor() { + private val COLL_AGG_NAMES = setOf( + "coll_any", + "coll_avg", + "coll_count", + "coll_every", + "coll_max", + "coll_min", + "coll_some", + "coll_sum", + ) + override fun defaultReturn(node: AstNode, context: Env): Rex = throw IllegalArgumentException("unsupported rex $node") @@ -465,11 +476,52 @@ internal object RexConverter { } // Args val args = node.args.map { visitExprCoerce(it, context) } - // Rex + + // Check if function is actually coll_ + if (isCollAgg(node)) { + return callToCollAgg(id, node.setq, args) + } + + if (node.setq != null) { + error("Currently, only COLL_ may use set quantifiers.") + } val op = rexOpCallUnresolved(id, args) return rex(type, op) } + /** + * @return whether call is `COLL_`. + */ + private fun isCollAgg(node: Expr.Call): Boolean { + val id = node.function as? org.partiql.ast.Identifier.Symbol ?: return false + return COLL_AGG_NAMES.contains(id.symbol.lowercase()) + } + + /** + * Converts COLL_ to the relevant function calls. For example: + * - `COLL_SUM(x)` becomes `coll_sum_all(x)` + * - `COLL_SUM(ALL x)` becomes `coll_sum_all(x)` + * - `COLL_SUM(DISTINCT x)` becomes `coll_sum_distinct(x)` + * + * It is assumed that the [id] has already been vetted by [isCollAgg]. + */ + private fun callToCollAgg(id: Identifier, setQuantifier: SetQuantifier?, args: List): Rex { + if (id.hasQualifier()) { + error("Qualified function calls are not currently supported.") + } + if (args.size != 1) { + error("Aggregate calls currently only support single arguments. Received ${args.size} arguments.") + } + val postfix = when (setQuantifier) { + SetQuantifier.DISTINCT -> "_distinct" + SetQuantifier.ALL -> "_all" + null -> "_all" + } + val newId = Identifier.regular(id.getIdentifier().getText() + postfix) + val op = Rex.Op.Call.Unresolved(newId, listOf(args[0])) + return Rex(ANY, op) + } + private fun visitExprCallTupleUnion(node: Expr.Call, context: Env): Rex { val type = (STRUCT) val args = node.args.map { visitExprCoerce(it, context) }.toMutableList() diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/CompilerType.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/CompilerType.kt index ce233a33a9..44e8eafa98 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/CompilerType.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/CompilerType.kt @@ -19,6 +19,7 @@ internal class CompilerType( // Note: This is an experimental property. internal val isMissingValue: Boolean = false ) : PType { + fun getDelegate(): PType = _delegate override fun getKind(): Kind = _delegate.kind override fun getFields(): MutableCollection { return _delegate.fields.map { field -> diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index 1f371ca337..a2ac90a106 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -101,9 +101,9 @@ internal class PlanTyper(private val env: Env) { * * TODO: Can this be merged with [anyOf]? Should we even allow this? */ - fun anyOfLiterals(types: Collection): PType? { + fun anyOfLiterals(types: Collection): PType? { // Grab unique - var unique: Collection = types.toSet() + var unique: Collection = types.map { it.getDelegate() }.toSet() if (unique.size == 0) { return null } else if (unique.size == 1) { @@ -133,7 +133,7 @@ internal class PlanTyper(private val env: Env) { } private fun collapseCollection(collections: Iterable, type: Kind): PType { - val typeParam = anyOfLiterals(collections.map { it.typeParameter })!! + val typeParam = anyOfLiterals(collections.map { it.typeParameter.toCType() })!! return when (type) { Kind.LIST -> PType.typeList(typeParam) Kind.BAG -> PType.typeList(typeParam) @@ -145,13 +145,13 @@ internal class PlanTyper(private val env: Env) { private fun collapseRows(rows: Iterable): PType { val firstFields = rows.first().fields!! val fieldNames = firstFields.map { it.name } - val fieldTypes = firstFields.map { mutableListOf(it.type) } + val fieldTypes = firstFields.map { mutableListOf(it.type.toCType()) } rows.map { struct -> val fields = struct.fields!! if (fields.map { it.name } != fieldNames) { return PType.typeStruct() } - fields.forEachIndexed { index, field -> fieldTypes[index].add(field.type) } + fields.forEachIndexed { index, field -> fieldTypes[index].add(field.type.toCType()) } } val newFields = fieldTypes.mapIndexed { i, types -> Field.of(fieldNames[i], anyOfLiterals(types)!!) } return PType.typeRow(newFields) @@ -162,7 +162,10 @@ internal class PlanTyper(private val env: Env) { return anyOf(unique) } - fun PType.toCType(): CompilerType = CompilerType(this) + fun PType.toCType(): CompilerType = when (this) { + is CompilerType -> this + else -> CompilerType(this) + } fun List.toCType(): List = this.map { it.toCType() } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/SqlBuiltins.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/SqlBuiltins.kt index 8f88966bab..9bbbb11929 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/SqlBuiltins.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/SqlBuiltins.kt @@ -48,14 +48,22 @@ internal object SqlBuiltins { Fn_CHAR_LENGTH__STRING__INT, Fn_CHAR_LENGTH__CLOB__INT, Fn_CHAR_LENGTH__SYMBOL__INT, - Fn_COLL_AGG__BAG__ANY.ANY, - Fn_COLL_AGG__BAG__ANY.AVG, - Fn_COLL_AGG__BAG__ANY.COUNT, - Fn_COLL_AGG__BAG__ANY.EVERY, - Fn_COLL_AGG__BAG__ANY.MAX, - Fn_COLL_AGG__BAG__ANY.MIN, - Fn_COLL_AGG__BAG__ANY.SOME, - Fn_COLL_AGG__BAG__ANY.SUM, + Fn_COLL_AGG__BAG__ANY.ANY_ALL, + Fn_COLL_AGG__BAG__ANY.AVG_ALL, + Fn_COLL_AGG__BAG__ANY.COUNT_ALL, + Fn_COLL_AGG__BAG__ANY.EVERY_ALL, + Fn_COLL_AGG__BAG__ANY.MAX_ALL, + Fn_COLL_AGG__BAG__ANY.MIN_ALL, + Fn_COLL_AGG__BAG__ANY.SOME_ALL, + Fn_COLL_AGG__BAG__ANY.SUM_ALL, + Fn_COLL_AGG__BAG__ANY.ANY_DISTINCT, + Fn_COLL_AGG__BAG__ANY.AVG_DISTINCT, + Fn_COLL_AGG__BAG__ANY.COUNT_DISTINCT, + Fn_COLL_AGG__BAG__ANY.EVERY_DISTINCT, + Fn_COLL_AGG__BAG__ANY.MAX_DISTINCT, + Fn_COLL_AGG__BAG__ANY.MIN_DISTINCT, + Fn_COLL_AGG__BAG__ANY.SOME_DISTINCT, + Fn_COLL_AGG__BAG__ANY.SUM_DISTINCT, Fn_CONCAT__STRING_STRING__STRING, Fn_CONCAT__CLOB_CLOB__CLOB, Fn_CONCAT__SYMBOL_SYMBOL__SYMBOL, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/builtins/FnCollAgg.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/builtins/FnCollAgg.kt index a5930241ce..7af5c4d22b 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/builtins/FnCollAgg.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/builtins/FnCollAgg.kt @@ -3,7 +3,6 @@ package org.partiql.spi.fn.builtins -import org.partiql.spi.fn.Agg import org.partiql.spi.fn.Fn import org.partiql.spi.fn.FnParameter import org.partiql.spi.fn.FnSignature @@ -11,6 +10,7 @@ import org.partiql.spi.fn.builtins.internal.Accumulator import org.partiql.spi.fn.builtins.internal.AccumulatorAnySome import org.partiql.spi.fn.builtins.internal.AccumulatorAvg import org.partiql.spi.fn.builtins.internal.AccumulatorCount +import org.partiql.spi.fn.builtins.internal.AccumulatorDistinct import org.partiql.spi.fn.builtins.internal.AccumulatorEvery import org.partiql.spi.fn.builtins.internal.AccumulatorMax import org.partiql.spi.fn.builtins.internal.AccumulatorMin @@ -22,9 +22,18 @@ import org.partiql.value.PartiQLValueType import org.partiql.value.check @OptIn(PartiQLValueExperimental::class) -internal abstract class Fn_COLL_AGG__BAG__ANY : Fn { +internal abstract class Fn_COLL_AGG__BAG__ANY( + name: String, + private val isDistinct: Boolean, + private val accumulator: () -> Accumulator, +) : Fn { - abstract fun getAccumulator(): Agg.Accumulator + private fun getAccumulator(): Accumulator = when (isDistinct) { + true -> AccumulatorDistinct(accumulator.invoke()) + false -> accumulator.invoke() + } + + override val signature: FnSignature = createSignature(name) companion object { @JvmStatic @@ -46,43 +55,35 @@ internal abstract class Fn_COLL_AGG__BAG__ANY : Fn { return accumulator.value() } - object SUM : Fn_COLL_AGG__BAG__ANY() { - override val signature = createSignature("coll_sum") - override fun getAccumulator(): Accumulator = AccumulatorSum() - } + object SUM_ALL : Fn_COLL_AGG__BAG__ANY("coll_sum_all", false, ::AccumulatorSum) - object AVG : Fn_COLL_AGG__BAG__ANY() { - override val signature = createSignature("coll_avg") - override fun getAccumulator(): Accumulator = AccumulatorAvg() - } + object SUM_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_sum_distinct", true, ::AccumulatorSum) - object MIN : Fn_COLL_AGG__BAG__ANY() { - override val signature = createSignature("coll_min") - override fun getAccumulator(): Accumulator = AccumulatorMin() - } + object AVG_ALL : Fn_COLL_AGG__BAG__ANY("coll_avg_all", false, ::AccumulatorAvg) - object MAX : Fn_COLL_AGG__BAG__ANY() { - override val signature = createSignature("coll_max") - override fun getAccumulator(): Accumulator = AccumulatorMax() - } + object AVG_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_avg_distinct", true, ::AccumulatorAvg) - object COUNT : Fn_COLL_AGG__BAG__ANY() { - override val signature = createSignature("coll_count") - override fun getAccumulator(): Accumulator = AccumulatorCount() - } + object MIN_ALL : Fn_COLL_AGG__BAG__ANY("coll_min_all", false, ::AccumulatorMin) - object EVERY : Fn_COLL_AGG__BAG__ANY() { - override val signature = createSignature("coll_every") - override fun getAccumulator(): Accumulator = AccumulatorEvery() - } + object MIN_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_min_distinct", true, ::AccumulatorMin) - object ANY : Fn_COLL_AGG__BAG__ANY() { - override val signature = createSignature("coll_any") - override fun getAccumulator(): Accumulator = AccumulatorAnySome() - } + object MAX_ALL : Fn_COLL_AGG__BAG__ANY("coll_max_all", false, ::AccumulatorMax) - object SOME : Fn_COLL_AGG__BAG__ANY() { - override val signature = createSignature("coll_some") - override fun getAccumulator(): Accumulator = AccumulatorAnySome() - } + object MAX_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_max_distinct", true, ::AccumulatorMax) + + object COUNT_ALL : Fn_COLL_AGG__BAG__ANY("coll_count_all", false, ::AccumulatorCount) + + object COUNT_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_count_distinct", true, ::AccumulatorCount) + + object EVERY_ALL : Fn_COLL_AGG__BAG__ANY("coll_every_all", false, ::AccumulatorEvery) + + object EVERY_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_every_distinct", true, ::AccumulatorEvery) + + object ANY_ALL : Fn_COLL_AGG__BAG__ANY("coll_any_all", false, ::AccumulatorAnySome) + + object ANY_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_any_distinct", true, ::AccumulatorAnySome) + + object SOME_ALL : Fn_COLL_AGG__BAG__ANY("coll_some_all", false, ::AccumulatorAnySome) + + object SOME_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_some_distinct", true, ::AccumulatorAnySome) } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/fn/builtins/internal/AccumulatorDistinct.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/builtins/internal/AccumulatorDistinct.kt new file mode 100644 index 0000000000..52858dd83a --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/fn/builtins/internal/AccumulatorDistinct.kt @@ -0,0 +1,28 @@ +package org.partiql.spi.fn.builtins.internal + +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import java.util.TreeSet + +@OptIn(PartiQLValueExperimental::class) +internal class AccumulatorDistinct( + private val _delegate: Accumulator, +) : Accumulator() { + + // TODO: Add support for a datum comparator once the accumulator passes datums instead of PartiQL values. + @OptIn(PartiQLValueExperimental::class) + private val seen = TreeSet(PartiQLValue.comparator()) + + @OptIn(PartiQLValueExperimental::class) + override fun nextValue(value: PartiQLValue) { + if (!seen.contains(value)) { + seen.add(value) + _delegate.nextValue(value) + } + } + + @OptIn(PartiQLValueExperimental::class) + override fun value(): PartiQLValue { + return _delegate.value() + } +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/value/util/NumberExtensions.kt b/partiql-spi/src/main/kotlin/org/partiql/value/util/NumberExtensions.kt index 68ec58678d..676b852843 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/value/util/NumberExtensions.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/value/util/NumberExtensions.kt @@ -45,46 +45,95 @@ internal fun bigDecimalOf(num: Number, mc: MathContext = MATH_CONTEXT): BigDecim else -> throw IllegalArgumentException("Unsupported number type: $num, ${num.javaClass}") } +/** + * This should handle Byte, Short, Int, Long, BigInteger, Float, Double, BigDecimal + */ private val CONVERSION_MAP = mapOf>, Class>( + // BYTE + setOf(Byte::class.javaObjectType, Byte::class.javaObjectType) to Byte::class.javaObjectType, + setOf(Byte::class.javaObjectType, Short::class.javaObjectType) to Short::class.javaObjectType, + setOf(Byte::class.javaObjectType, Int::class.javaObjectType) to Int::class.javaObjectType, + setOf(Byte::class.javaObjectType, Long::class.javaObjectType) to Long::class.javaObjectType, + setOf(Byte::class.javaObjectType, BigInteger::class.javaObjectType) to BigInteger::class.javaObjectType, + setOf(Byte::class.javaObjectType, Float::class.javaObjectType) to Float::class.javaObjectType, + setOf(Byte::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, + setOf(Byte::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, + // SHORT + setOf(Short::class.javaObjectType, Byte::class.javaObjectType) to Short::class.javaObjectType, + setOf(Short::class.javaObjectType, Short::class.javaObjectType) to Short::class.javaObjectType, + setOf(Short::class.javaObjectType, Int::class.javaObjectType) to Int::class.javaObjectType, + setOf(Short::class.javaObjectType, Long::class.javaObjectType) to Long::class.javaObjectType, + setOf(Short::class.javaObjectType, BigInteger::class.javaObjectType) to BigInteger::class.javaObjectType, + setOf(Short::class.javaObjectType, Float::class.javaObjectType) to Float::class.javaObjectType, + setOf(Short::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, + setOf(Short::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, + // INT + setOf(Int::class.javaObjectType, Byte::class.javaObjectType) to Int::class.javaObjectType, + setOf(Int::class.javaObjectType, Short::class.javaObjectType) to Int::class.javaObjectType, setOf(Int::class.javaObjectType, Int::class.javaObjectType) to Int::class.javaObjectType, setOf(Int::class.javaObjectType, Long::class.javaObjectType) to Long::class.javaObjectType, setOf(Int::class.javaObjectType, BigInteger::class.javaObjectType) to BigInteger::class.javaObjectType, - setOf(Long::class.javaObjectType, BigInteger::class.javaObjectType) to BigInteger::class.javaObjectType, - setOf(BigInteger::class.javaObjectType, BigInteger::class.javaObjectType) to BigInteger::class.javaObjectType, - // Int w/ Float -> Double setOf(Int::class.javaObjectType, Float::class.javaObjectType) to Double::class.javaObjectType, setOf(Int::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, setOf(Int::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, - - setOf(Float::class.javaObjectType, Float::class.javaObjectType) to Float::class.javaObjectType, - // Float w/ Long -> Double - setOf(Float::class.javaObjectType, Long::class.javaObjectType) to Double::class.javaObjectType, - setOf(Float::class.javaObjectType, BigInteger::class.javaObjectType) to Double::class.javaObjectType, - setOf(Float::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, - setOf(Float::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, - + // LONG + setOf(Long::class.javaObjectType, Byte::class.javaObjectType) to Long::class.javaObjectType, + setOf(Long::class.javaObjectType, Short::class.javaObjectType) to Long::class.javaObjectType, + setOf(Long::class.javaObjectType, Int::class.javaObjectType) to Long::class.javaObjectType, setOf(Long::class.javaObjectType, Long::class.javaObjectType) to Long::class.javaObjectType, setOf(Long::class.javaObjectType, BigInteger::class.javaObjectType) to BigInteger::class.javaObjectType, + setOf(Long::class.javaObjectType, Float::class.javaObjectType) to Double::class.javaObjectType, setOf(Long::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, setOf(Long::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, - + // FLOAT + setOf(Float::class.javaObjectType, Byte::class.javaObjectType) to Float::class.javaObjectType, + setOf(Float::class.javaObjectType, Short::class.javaObjectType) to Float::class.javaObjectType, + setOf(Float::class.javaObjectType, Int::class.javaObjectType) to Float::class.javaObjectType, + setOf(Float::class.javaObjectType, Long::class.javaObjectType) to Double::class.javaObjectType, + setOf(Float::class.javaObjectType, BigInteger::class.javaObjectType) to Double::class.javaObjectType, + setOf(Float::class.javaObjectType, Float::class.javaObjectType) to Float::class.javaObjectType, + setOf(Float::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, + setOf(Float::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, + // DOUBLE + setOf(Double::class.javaObjectType, Byte::class.javaObjectType) to Double::class.javaObjectType, + setOf(Double::class.javaObjectType, Short::class.javaObjectType) to Double::class.javaObjectType, + setOf(Double::class.javaObjectType, Int::class.javaObjectType) to Double::class.javaObjectType, + setOf(Double::class.javaObjectType, Long::class.javaObjectType) to Double::class.javaObjectType, + setOf(Double::class.javaObjectType, BigInteger::class.javaObjectType) to Double::class.javaObjectType, + setOf(Double::class.javaObjectType, Float::class.javaObjectType) to Double::class.javaObjectType, + setOf(Double::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, + setOf(Double::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, + // BIG INTEGER + setOf(BigInteger::class.javaObjectType, Byte::class.javaObjectType) to BigInteger::class.javaObjectType, + setOf(BigInteger::class.javaObjectType, Short::class.javaObjectType) to BigInteger::class.javaObjectType, + setOf(BigInteger::class.javaObjectType, Int::class.javaObjectType) to BigInteger::class.javaObjectType, + setOf(BigInteger::class.javaObjectType, Long::class.javaObjectType) to BigInteger::class.javaObjectType, setOf(BigInteger::class.javaObjectType, BigInteger::class.javaObjectType) to BigInteger::class.javaObjectType, + setOf(BigInteger::class.javaObjectType, Float::class.javaObjectType) to Double::class.javaObjectType, setOf(BigInteger::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, setOf(BigInteger::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, - - setOf(Double::class.javaObjectType, Double::class.javaObjectType) to Double::class.javaObjectType, - setOf(Double::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, - + // BIG DECIMAL + setOf(BigDecimal::class.javaObjectType, Byte::class.javaObjectType) to BigDecimal::class.javaObjectType, + setOf(BigDecimal::class.javaObjectType, Short::class.javaObjectType) to BigDecimal::class.javaObjectType, + setOf(BigDecimal::class.javaObjectType, Int::class.javaObjectType) to BigDecimal::class.javaObjectType, + setOf(BigDecimal::class.javaObjectType, Long::class.javaObjectType) to BigDecimal::class.javaObjectType, + setOf(BigDecimal::class.javaObjectType, BigInteger::class.javaObjectType) to BigDecimal::class.javaObjectType, + setOf(BigDecimal::class.javaObjectType, Float::class.javaObjectType) to BigDecimal::class.javaObjectType, + setOf(BigDecimal::class.javaObjectType, Double::class.javaObjectType) to BigDecimal::class.javaObjectType, setOf(BigDecimal::class.javaObjectType, BigDecimal::class.javaObjectType) to BigDecimal::class.javaObjectType, ) private val CONVERTERS = mapOf, (Number) -> Number>( + Byte::class.javaObjectType to Number::toByte, + Short::class.javaObjectType to Number::toShort, Int::class.javaObjectType to Number::toInt, Long::class.javaObjectType to Number::toLong, Float::class.javaObjectType to Number::toFloat, Double::class.javaObjectType to Number::toDouble, BigInteger::class.javaObjectType to { num -> when (num) { + is Byte -> num.toInt().toBigInteger() + is Short -> num.toInt().toBigInteger() is Int -> num.toBigInteger() is Long -> num.toBigInteger() is BigInteger -> num @@ -95,6 +144,8 @@ private val CONVERTERS = mapOf, (Number) -> Number>( }, BigDecimal::class.java to { num -> when (num) { + is Byte -> bigDecimalOf(num) + is Short -> bigDecimalOf(num) is Int -> bigDecimalOf(num) is Long -> bigDecimalOf(num) is Float -> bigDecimalOf(num) @@ -109,13 +160,15 @@ private val CONVERTERS = mapOf, (Number) -> Number>( ) internal fun Number.isZero() = when (this) { + is Byte -> this.toInt() == 0 + is Short -> this.toInt() == 0 is Int -> this == 0 is Long -> this == 0L is Float -> this == 0.0f || this == -0.0f is Double -> this == 0.0 || this == -0.0 is BigDecimal -> this.signum() == 0 is BigInteger -> this.signum() == 0 - else -> throw IllegalStateException("$this") + else -> throw IllegalStateException("$this (${this.javaClass.simpleName})") } @Suppress("UNCHECKED_CAST") @@ -148,6 +201,8 @@ public fun coerceNumbers(first: Number, second: Number): Pair { internal operator fun Number.compareTo(other: Number): Int { val (first, second) = coerceNumbers(this, other) return when (first) { + is Byte -> first.compareTo(second as Byte) + is Short -> first.compareTo(second as Short) is Int -> first.compareTo(second as Int) is Long -> first.compareTo(second as Long) is Float -> first.compareTo(second as Float)