Skip to content

Commit 7985180

Browse files
authored
Merge 25d56dd into 609f8b8
2 parents 609f8b8 + 25d56dd commit 7985180

File tree

31 files changed

+553
-280
lines changed

31 files changed

+553
-280
lines changed

partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,11 @@ internal class Compiler(
204204
}
205205
}
206206

207-
@OptIn(FnExperimental::class, PartiQLValueExperimental::class)
207+
@OptIn(FnExperimental::class)
208208
override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: StaticType?): Operator {
209209
val args = node.args.map { visitRex(it, ctx).modeHandled() }.toTypedArray()
210-
val candidates = node.candidates.map { candidate ->
210+
val candidates = Array(node.candidates.size) {
211+
val candidate = node.candidates[it]
211212
val fn = symbols.getFn(candidate.fn)
212213
val coercions = candidate.coercions.toTypedArray()
213214
ExprCallDynamic.Candidate(fn, coercions)

partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt

Lines changed: 152 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,21 @@ import org.partiql.value.PartiQLValueType
2121
*/
2222
@OptIn(PartiQLValueExperimental::class, FnExperimental::class)
2323
internal class ExprCallDynamic(
24-
private val candidates: List<Candidate>,
24+
candidates: Array<Candidate>,
2525
private val args: Array<Operator.Expr>
2626
) : Operator.Expr {
2727

28+
private val candidateIndex = CandidateIndex.All(candidates)
29+
2830
override fun eval(env: Environment): PartiQLValue {
2931
val actualArgs = args.map { it.eval(env) }.toTypedArray()
30-
candidates.forEach { candidate ->
31-
if (candidate.matches(actualArgs)) {
32-
return candidate.eval(actualArgs, env)
33-
}
32+
val actualTypes = actualArgs.map { it.type }
33+
candidateIndex.get(actualTypes)?.let {
34+
return it.eval(actualArgs, env)
3435
}
3536
val errorString = buildString {
3637
val argString = actualArgs.joinToString(", ")
37-
append("Could not dynamically find function for arguments $argString in $candidates.")
38+
append("Could not dynamically find function (${candidateIndex.name}) for arguments $argString.")
3839
}
3940
throw TypeCheckException(errorString)
4041
}
@@ -47,13 +48,11 @@ internal class ExprCallDynamic(
4748
*
4849
* @see ExprCallDynamic
4950
*/
50-
internal class Candidate(
51+
data class Candidate(
5152
val fn: Fn,
5253
val coercions: Array<Ref.Cast?>
5354
) {
5455

55-
private val signatureParameters = fn.signature.parameters.map { it.type }.toTypedArray()
56-
5756
fun eval(originalArgs: Array<PartiQLValue>, env: Environment): PartiQLValue {
5857
val args = originalArgs.mapIndexed { i, arg ->
5958
when (val c = coercions[i]) {
@@ -63,32 +62,156 @@ internal class ExprCallDynamic(
6362
}.toTypedArray()
6463
return fn.invoke(args)
6564
}
65+
}
66+
67+
private sealed interface CandidateIndex {
68+
69+
public fun get(args: List<PartiQLValueType>): Candidate?
70+
71+
/**
72+
* Preserves the original ordering of the passed-in candidates while making it faster to lookup matching
73+
* functions. Utilizes both [Direct] and [Indirect].
74+
*
75+
* Say a user passes in the following ordered candidates:
76+
* [
77+
* foo(int16, int16) -> int16,
78+
* foo(int32, int32) -> int32,
79+
* foo(int64, int64) -> int64,
80+
* foo(string, string) -> string,
81+
* foo(struct, struct) -> struct,
82+
* foo(numeric, numeric) -> numeric,
83+
* foo(int64, dynamic) -> dynamic,
84+
* foo(struct, dynamic) -> dynamic,
85+
* foo(bool, bool) -> bool
86+
* ]
87+
*
88+
* With the above candidates, the [CandidateIndex.All] will maintain the original ordering by utilizing:
89+
* - [CandidateIndex.Direct] to match hashable runtime types
90+
* - [CandidateIndex.Indirect] to match the dynamic type
91+
*
92+
* For the above example, the internal representation of [CandidateIndex.All] is a list of
93+
* [CandidateIndex.Direct] and [CandidateIndex.Indirect] that looks like:
94+
* ALL listOf(
95+
* DIRECT hashMap(
96+
* [int16, int16] --> foo(int16, int16) -> int16,
97+
* [int32, int32] --> foo(int32, int32) -> int32,
98+
* [int64, int64] --> foo(int64, int64) -> int64
99+
* [string, string] --> foo(string, string) -> string,
100+
* [struct, struct] --> foo(struct, struct) -> struct,
101+
* [numeric, numeric] --> foo(numeric, numeric) -> numeric
102+
* ),
103+
* INDIRECT listOf(
104+
* foo(int64, dynamic) -> dynamic,
105+
* foo(struct, dynamic) -> dynamic
106+
* ),
107+
* DIRECT hashMap(
108+
* [bool, bool] --> foo(bool, bool) -> bool
109+
* )
110+
* )
111+
*
112+
* @param candidates
113+
*/
114+
class All(
115+
candidates: Array<Candidate>,
116+
) : CandidateIndex {
117+
118+
private val lookups: List<CandidateIndex>
119+
internal val name: String = candidates.first().fn.signature.name
66120

67-
internal fun matches(inputs: Array<PartiQLValue>): Boolean {
68-
for (i in inputs.indices) {
69-
val inputType = inputs[i].type
70-
val parameterType = signatureParameters[i]
71-
val c = coercions[i]
72-
when (c) {
73-
// coercion might be null if one of the following is true
74-
// Function parameter is ANY,
75-
// Input type is null
76-
// input type is the same as function parameter
77-
null -> {
78-
if (!(inputType == parameterType || inputType == PartiQLValueType.NULL || parameterType == PartiQLValueType.ANY)) {
79-
return false
121+
init {
122+
val lookupsMutable = mutableListOf<CandidateIndex>()
123+
val accumulator = mutableListOf<Pair<List<PartiQLValueType>, Candidate>>()
124+
125+
// Indicates that we are currently processing dynamic candidates that accept ANY.
126+
var activelyProcessingAny = true
127+
128+
candidates.forEach { candidate ->
129+
// Gather the input types to the dynamic invocation
130+
val lookupTypes = candidate.coercions.mapIndexed { index, cast ->
131+
when (cast) {
132+
null -> candidate.fn.signature.parameters[index].type
133+
else -> cast.input
80134
}
81135
}
82-
else -> {
83-
// checking the input type is expected by the coercion
84-
if (inputType != c.input) return false
85-
// checking the result is expected by the function signature
86-
// this should branch should never be reached, but leave it here for clarity
87-
if (c.target != parameterType) error("Internal Error: Cast Target does not match Function Parameter")
136+
val parametersIncludeAny = lookupTypes.any { it == PartiQLValueType.ANY }
137+
// A way to simplify logic further below. If it's empty, add something and set the processing type.
138+
if (accumulator.isEmpty()) {
139+
activelyProcessingAny = parametersIncludeAny
140+
accumulator.add(lookupTypes to candidate)
141+
return@forEach
142+
}
143+
when (parametersIncludeAny) {
144+
true -> when (activelyProcessingAny) {
145+
true -> accumulator.add(lookupTypes to candidate)
146+
false -> {
147+
activelyProcessingAny = true
148+
lookupsMutable.add(Direct.of(accumulator.toList()))
149+
accumulator.clear()
150+
accumulator.add(lookupTypes to candidate)
151+
}
152+
}
153+
false -> when (activelyProcessingAny) {
154+
false -> accumulator.add(lookupTypes to candidate)
155+
true -> {
156+
activelyProcessingAny = false
157+
lookupsMutable.add(Indirect(accumulator.toList()))
158+
accumulator.clear()
159+
accumulator.add(lookupTypes to candidate)
160+
}
161+
}
162+
}
163+
}
164+
// Add any remaining candidates (that we didn't submit due to not ending while switching)
165+
when (accumulator.isEmpty()) {
166+
true -> { /* Do nothing! */ }
167+
false -> when (activelyProcessingAny) {
168+
true -> lookupsMutable.add(Indirect(accumulator.toList()))
169+
false -> lookupsMutable.add(Direct.of(accumulator.toList()))
170+
}
171+
}
172+
this.lookups = lookupsMutable
173+
}
174+
175+
override fun get(args: List<PartiQLValueType>): Candidate? {
176+
return this.lookups.firstNotNullOfOrNull { it.get(args) }
177+
}
178+
}
179+
180+
/**
181+
* An O(1) structure to quickly find directly matching dynamic candidates. This is specifically used for runtime
182+
* types that can be matched directly. AKA int32, int64, etc. This does NOT include [PartiQLValueType.ANY].
183+
*/
184+
data class Direct private constructor(val directCandidates: HashMap<List<PartiQLValueType>, Candidate>) : CandidateIndex {
185+
186+
companion object {
187+
internal fun of(candidates: List<Pair<List<PartiQLValueType>, Candidate>>): Direct {
188+
val candidateMap = java.util.HashMap<List<PartiQLValueType>, Candidate>()
189+
candidateMap.putAll(candidates)
190+
return Direct(candidateMap)
191+
}
192+
}
193+
194+
override fun get(args: List<PartiQLValueType>): Candidate? {
195+
return directCandidates[args]
196+
}
197+
}
198+
199+
/**
200+
* Holds all candidates that expect a [PartiQLValueType.ANY] on input. This maintains the original
201+
* precedence order.
202+
*/
203+
data class Indirect(private val candidates: List<Pair<List<PartiQLValueType>, Candidate>>) : CandidateIndex {
204+
override fun get(args: List<PartiQLValueType>): Candidate? {
205+
candidates.forEach { (types, candidate) ->
206+
for (i in args.indices) {
207+
if (args[i] != types[i] && types[i] != PartiQLValueType.ANY) {
208+
return@forEach
209+
}
88210
}
211+
return candidate
89212
}
213+
return null
90214
}
91-
return true
92215
}
93216
}
94217
}

partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCast.kt

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.partiql.value.Int64Value
2121
import org.partiql.value.Int8Value
2222
import org.partiql.value.IntValue
2323
import org.partiql.value.ListValue
24+
import org.partiql.value.NullValue
2425
import org.partiql.value.NumericValue
2526
import org.partiql.value.PartiQLValue
2627
import org.partiql.value.PartiQLValueExperimental
@@ -30,7 +31,13 @@ import org.partiql.value.StringValue
3031
import org.partiql.value.SymbolValue
3132
import org.partiql.value.TextValue
3233
import org.partiql.value.bagValue
34+
import org.partiql.value.binaryValue
35+
import org.partiql.value.blobValue
3336
import org.partiql.value.boolValue
37+
import org.partiql.value.byteValue
38+
import org.partiql.value.charValue
39+
import org.partiql.value.clobValue
40+
import org.partiql.value.dateValue
3441
import org.partiql.value.decimalValue
3542
import org.partiql.value.float32Value
3643
import org.partiql.value.float64Value
@@ -40,9 +47,13 @@ import org.partiql.value.int64Value
4047
import org.partiql.value.int8Value
4148
import org.partiql.value.intValue
4249
import org.partiql.value.listValue
50+
import org.partiql.value.missingValue
4351
import org.partiql.value.sexpValue
4452
import org.partiql.value.stringValue
53+
import org.partiql.value.structValue
4554
import org.partiql.value.symbolValue
55+
import org.partiql.value.timeValue
56+
import org.partiql.value.timestampValue
4657
import java.math.BigDecimal
4758
import java.math.BigInteger
4859

@@ -79,14 +90,48 @@ internal class ExprCast(val arg: Operator.Expr, val cast: Ref.Cast) : Operator.E
7990
PartiQLValueType.LIST -> castFromCollection(arg as ListValue<*>, cast.target)
8091
PartiQLValueType.SEXP -> castFromCollection(arg as SexpValue<*>, cast.target)
8192
PartiQLValueType.STRUCT -> TODO("CAST FROM STRUCT not yet implemented")
82-
PartiQLValueType.NULL -> error("cast from NULL should be handled by Typer")
93+
PartiQLValueType.NULL -> castFromNull(arg as NullValue, cast.target)
8394
PartiQLValueType.MISSING -> error("cast from MISSING should be handled by Typer")
8495
}
8596
} catch (e: DataException) {
8697
throw TypeCheckException()
8798
}
8899
}
89100

101+
@OptIn(PartiQLValueExperimental::class)
102+
private fun castFromNull(value: NullValue, t: PartiQLValueType): PartiQLValue {
103+
return when (t) {
104+
PartiQLValueType.ANY -> value
105+
PartiQLValueType.BOOL -> boolValue(null)
106+
PartiQLValueType.CHAR -> charValue(null)
107+
PartiQLValueType.STRING -> stringValue(null)
108+
PartiQLValueType.SYMBOL -> symbolValue(null)
109+
PartiQLValueType.BINARY -> binaryValue(null)
110+
PartiQLValueType.BYTE -> byteValue(null)
111+
PartiQLValueType.BLOB -> blobValue(null)
112+
PartiQLValueType.CLOB -> clobValue(null)
113+
PartiQLValueType.DATE -> dateValue(null)
114+
PartiQLValueType.TIME -> timeValue(null)
115+
PartiQLValueType.TIMESTAMP -> timestampValue(null)
116+
PartiQLValueType.INTERVAL -> TODO("Not yet supported")
117+
PartiQLValueType.BAG -> bagValue<PartiQLValue>(null)
118+
PartiQLValueType.LIST -> listValue<PartiQLValue>(null)
119+
PartiQLValueType.SEXP -> sexpValue<PartiQLValue>(null)
120+
PartiQLValueType.STRUCT -> structValue<PartiQLValue>(null)
121+
PartiQLValueType.NULL -> value
122+
PartiQLValueType.MISSING -> missingValue() // TODO: Os this allowed
123+
PartiQLValueType.INT8 -> int8Value(null)
124+
PartiQLValueType.INT16 -> int16Value(null)
125+
PartiQLValueType.INT32 -> int32Value(null)
126+
PartiQLValueType.INT64 -> int64Value(null)
127+
PartiQLValueType.INT -> intValue(null)
128+
PartiQLValueType.DECIMAL -> decimalValue(null)
129+
PartiQLValueType.DECIMAL_ARBITRARY -> decimalValue(null)
130+
PartiQLValueType.FLOAT32 -> float32Value(null)
131+
PartiQLValueType.FLOAT64 -> float64Value(null)
132+
}
133+
}
134+
90135
@OptIn(PartiQLValueExperimental::class)
91136
private fun castFromBool(value: BoolValue, t: PartiQLValueType): PartiQLValue {
92137
val v = value.value

partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEngineDefaultTest.kt

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ import org.partiql.value.structValue
3636
import java.io.ByteArrayOutputStream
3737
import java.math.BigDecimal
3838
import java.math.BigInteger
39-
import kotlin.test.assertEquals
4039
import kotlin.test.assertNotNull
4140

4241
/**
@@ -1253,10 +1252,12 @@ class PartiQLEngineDefaultTest {
12531252

12541253
internal fun assert() {
12551254
val permissiveResult = run(mode = PartiQLEngine.Mode.PERMISSIVE)
1256-
assertEquals(expectedPermissive, permissiveResult, comparisonString(expectedPermissive, permissiveResult))
1255+
assert(expectedPermissive == permissiveResult.first) {
1256+
comparisonString(expectedPermissive, permissiveResult.first, permissiveResult.second)
1257+
}
12571258
var error: Throwable? = null
12581259
try {
1259-
when (val result = run(mode = PartiQLEngine.Mode.STRICT)) {
1260+
when (val result = run(mode = PartiQLEngine.Mode.STRICT).first) {
12601261
is CollectionValue<*> -> result.toList()
12611262
else -> result
12621263
}
@@ -1266,7 +1267,7 @@ class PartiQLEngineDefaultTest {
12661267
assertNotNull(error)
12671268
}
12681269

1269-
private fun run(mode: PartiQLEngine.Mode): PartiQLValue {
1270+
private fun run(mode: PartiQLEngine.Mode): Pair<PartiQLValue, PartiQLPlan> {
12701271
val statement = parser.parse(input).root
12711272
val catalog = MemoryCatalog.PartiQL().name("memory").build()
12721273
val connector = MemoryConnector(catalog)
@@ -1283,17 +1284,18 @@ class PartiQLEngineDefaultTest {
12831284
val plan = planner.plan(statement, session)
12841285
val prepared = engine.prepare(plan.plan, PartiQLEngine.Session(mapOf("memory" to connector), mode = mode))
12851286
when (val result = engine.execute(prepared)) {
1286-
is PartiQLResult.Value -> return result.value
1287+
is PartiQLResult.Value -> return result.value to plan.plan
12871288
is PartiQLResult.Error -> throw result.cause
12881289
}
12891290
}
12901291

12911292
@OptIn(PartiQLValueExperimental::class)
1292-
private fun comparisonString(expected: PartiQLValue, actual: PartiQLValue): String {
1293+
private fun comparisonString(expected: PartiQLValue, actual: PartiQLValue, plan: PartiQLPlan): String {
12931294
val expectedBuffer = ByteArrayOutputStream()
12941295
val expectedWriter = PartiQLValueIonWriterBuilder.standardIonTextBuilder().build(expectedBuffer)
12951296
expectedWriter.append(expected)
12961297
return buildString {
1298+
PlanPrinter.append(this, plan)
12971299
appendLine("Expected : $expectedBuffer")
12981300
expectedBuffer.reset()
12991301
expectedWriter.append(actual)
@@ -1444,6 +1446,7 @@ class PartiQLEngineDefaultTest {
14441446
).assert()
14451447

14461448
@Test
1449+
@Disabled("This broke in its introduction to the codebase on merge. See 5fb9a1ccbc7e630b0df62aa8b161d319c763c1f6.")
14471450
// TODO: Add to conformance tests
14481451
fun wildCard() =
14491452
SuccessTestCase(
@@ -1487,6 +1490,7 @@ class PartiQLEngineDefaultTest {
14871490
).assert()
14881491

14891492
@Test
1493+
@Disabled("This broke in its introduction to the codebase on merge. See 5fb9a1ccbc7e630b0df62aa8b161d319c763c1f6.")
14901494
// TODO: add to conformance tests
14911495
// Note that the existing pipeline produced identical result when supplying with
14921496
// SELECT VALUE v2.name FROM e as v0, v0.books as v1, unpivot v1.authors as v2;

0 commit comments

Comments
 (0)