diff --git a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/LetTests.kt b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/LetTests.kt new file mode 100644 index 000000000..2eed5a71f --- /dev/null +++ b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/LetTests.kt @@ -0,0 +1,259 @@ +package org.partiql.eval.internal + +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.parallel.Execution +import org.junit.jupiter.api.parallel.ExecutionMode +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.partiql.spi.value.Datum +import org.partiql.spi.value.Field + +/** + * This test file exercises the `LET` clause in PartiQL. + */ +class LetTests { + + @ParameterizedTest + @MethodSource("successTestCases") + @Execution(ExecutionMode.CONCURRENT) + fun successTests(tc: SuccessTestCase) = tc.run() + + @ParameterizedTest + @MethodSource("failureTestCases") + @Execution(ExecutionMode.CONCURRENT) + fun failureTests(tc: FailureTestCase) = tc.run() + + companion object { + + @JvmStatic + fun successTestCases() = listOf( + SuccessTestCase( + name = "Basic LET usage 1", + input = """ + SELECT t.a, c + FROM <<{ 'a': 1 , 'b': 2}>> AS t + LET t.a*5 AS c + """.trimIndent(), + expected = Datum.bagVararg( + Datum.struct( + Field.of("a", Datum.integer(1)), + Field.of("c", Datum.integer(5)) + ) + ) + ), + + SuccessTestCase( + name = "Basic LET usage 2", + input = """ + SELECT t.x, t.y, t.z * 2 AS double_z + FROM ( + SELECT A AS x, B AS y, new_val AS z + FROM <<{ 'A': 1, 'B': 2, 'C': 3}>> + LET B + C AS new_val + ) AS t; + """.trimIndent(), + expected = Datum.bagVararg( + Datum.struct( + Field.of("x", Datum.integer(1)), + Field.of("y", Datum.integer(2)), + Field.of("double_z", Datum.integer(10)) + ) + ) + ), + + SuccessTestCase( + name = "LET with JOIN operation", + input = """ + SELECT t.customer_name, t.order_total + FROM ( + SELECT + c.name AS customer_name, + total AS order_total + FROM << + { 'id': 1, 'name': 'Alice' }, + { 'id': 2, 'name': 'Bob' } + >> AS c + JOIN << + { 'customer_id': 1, 'amount': 100 }, + { 'customer_id': 1, 'amount': 200 }, + { 'customer_id': 2, 'amount': 150 } + >> AS o + ON c.id = o.customer_id + LET o.amount * c.id AS total + ) AS t; + """.trimIndent(), + expected = Datum.bagVararg( + Datum.struct( + Field.of("customer_name", Datum.string("Alice")), + Field.of("order_total", Datum.integer(100)) + ), + Datum.struct( + Field.of("customer_name", Datum.string("Alice")), + Field.of("order_total", Datum.integer(200)) + ), + Datum.struct( + Field.of("customer_name", Datum.string("Bob")), + Field.of("order_total", Datum.integer(300)) + ) + ) + ), + + SuccessTestCase( + name = "LET with multiple items in data", + input = """ + SELECT t.x, t.y, t.z AS total + FROM ( + SELECT A AS x, B AS y, sum_val AS z + FROM << + { 'A': 1, 'B': 2, 'C': 3 }, + { 'A': 10, 'B': 20, 'C': 30 } + >> + LET B + C AS sum_val + ) AS t; + """.trimIndent(), + expected = Datum.bagVararg( + Datum.struct( + Field.of("x", Datum.integer(1)), + Field.of("y", Datum.integer(2)), + Field.of("total", Datum.integer(5)) + ), + Datum.struct( + Field.of("x", Datum.integer(10)), + Field.of("y", Datum.integer(20)), + Field.of("total", Datum.integer(50)) + ) + ) + ), + SuccessTestCase( + name = "LET referencing prior expressions", + input = """ + SELECT t.x, t.sum_val, t.double_sum + FROM ( + SELECT + A AS x, + sum_val, + sum_val * 2 AS double_sum + FROM << + { 'A': 3, 'B': 5 }, + { 'A': 10, 'B': 2 } + >> + LET A + B AS sum_val + ) AS t; + """.trimIndent(), + expected = Datum.bagVararg( + Datum.struct( + Field.of("x", Datum.integer(3)), + Field.of("sum_val", Datum.integer(8)), + Field.of("double_sum", Datum.integer(16)) + ), + Datum.struct( + Field.of("x", Datum.integer(10)), + Field.of("sum_val", Datum.integer(12)), + Field.of("double_sum", Datum.integer(24)) + ) + ) + ), + SuccessTestCase( + name = "LET with multiple LET clauses 1", + input = """ + SELECT t.a, b, c + FROM << { 'a': 1 }>> AS t + LET t.a + 2 AS b, t.a * 3 AS c + """.trimIndent(), + expected = Datum.bagVararg( + Datum.struct( + Field.of("a", Datum.integer(1)), + Field.of("b", Datum.integer(3)), + Field.of("c", Datum.integer(3)) + ) + ) + ), + SuccessTestCase( + name = "LET with multiple LET clauses 2", + input = """ + SELECT a, b, c, d, e + FROM << { 'a': 1 , 'b':2}>> AS t + LET t.a + 5 AS c, t.b+ 10 AS d, t.a + 15 AS e + """.trimIndent(), + expected = Datum.bagVararg( + Datum.struct( + Field.of("a", Datum.integer(1)), + Field.of("b", Datum.integer(2)), + Field.of("c", Datum.integer(6)), + Field.of("d", Datum.integer(12)), + Field.of("e", Datum.integer(16)) + ) + ) + ) + ) + + @JvmStatic + fun failureTestCases() = listOf( + FailureTestCase( + name = "LET referencing undefined variable", + input = """ + SELECT t.x + FROM ( + SELECT A AS x + FROM << { 'A': 1, 'B': 2 } >> + LET nonexistent + B AS new_val + ) AS t; + """.trimIndent() + ), + FailureTestCase( + name = "LET clause with ambiguous reference", + input = """ + SELECT t.z + FROM ( + SELECT new_val AS z + FROM << { 'A': 1, 'B': 2 } >> + -- 'new_val' references itself in LET, which is not allowed + LET new_val + B AS new_val + ) AS t; + """.trimIndent() + ), + FailureTestCase( + name = "Outside clauses referencing subquery's LET bindings", + input = """ + SELECT t.x, t.y, new_val + FROM ( + SELECT A AS x, B AS y + FROM <<{ 'A': 1, 'B': 2, 'C': 3}>> + LET B + C AS new_val + ) AS t; + """.trimIndent() + ), + FailureTestCase( + name = "LET with invalid JOIN reference", + input = """ + SELECT t.customer_name, t.calculated_total + FROM ( + SELECT + c.name AS customer_name, + total AS calculated_total + FROM << + { 'id': 1, 'name': 'Alice' }, + { 'id': 2, 'name': 'Bob' } + >> AS c + LEFT JOIN << + { 'customer_id': 1, 'amount': 100 }, + { 'customer_id': 2, 'amount': 150 } + >> AS o + ON c.id = o.customer_id + -- This should fail because we're trying to reference 'missing_field' + -- which doesn't exist in either joined table + LET missing_field + o.amount AS total + ) AS t; + """.trimIndent(), + ) + ) + } + + // Example of a test that might need special handling or a skip + @Test + @Disabled("Demonstration of a scenario needing further investigation.") + fun disabledTestExample() { + // Implementation left blank or used for demonstration + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt index 95022cd37..130eb91c5 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt @@ -31,6 +31,8 @@ import org.partiql.ast.GroupBy import org.partiql.ast.GroupByStrategy import org.partiql.ast.Identifier import org.partiql.ast.JoinType +import org.partiql.ast.Let +import org.partiql.ast.Let.Binding import org.partiql.ast.Literal.intNum import org.partiql.ast.Nulls import org.partiql.ast.Order @@ -192,6 +194,7 @@ internal object RelConverter { rel = convertOffset(rel, offset) rel = convertLimit(rel, limit) rel = convertExclude(rel, sel.exclude) + rel = convertLet(rel, sel.let) // append SQL projection if present rel = when (val projection = sel.select) { is SelectValue -> { @@ -383,6 +386,13 @@ internal object RelConverter { return binding to rex } + private fun convertLetBinding(binding: Binding): Pair { + val name = binding.asAlias.text + val rex = RexConverter.apply(binding.expr, env) + val newBinding = relBinding(name, rex.type) + return newBinding to rex + } + /** * Append [Rel.Op.Filter] only if a WHERE condition exists */ @@ -582,6 +592,29 @@ internal object RelConverter { return rel(type, op) } + /** + * Concatenate bindings in LET clause with existing env bindings from input + */ + private fun convertLet(input: Rel, let: Let?): Rel { + if (let == null) { + return input + } + val schema = input.type.schema.toMutableList() + val props = input.type.props + val projections = mutableListOf() + repeat(input.type.schema.size) { index -> + projections.add(rex(ANY, rexOpVarLocal(0, index))) + } + let.bindings.forEach { + val (newBinding, projection) = convertLetBinding(it) + schema.add(newBinding) + projections.add(projection) + } + val type = relType(schema, props) + val op = relOpProject(input, projections) + return rel(type, op) + } + /** * Append [Rel.Op.Offset] if there is an OFFSET */