Skip to content

Add implementation of LET clause and LetTests #1745

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 259 additions & 0 deletions partiql-eval/src/test/kotlin/org/partiql/eval/internal/LetTests.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
package org.partiql.eval.internal

import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.parallel.Execution
import org.junit.jupiter.api.parallel.ExecutionMode
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource
import org.partiql.spi.value.Datum
import org.partiql.spi.value.Field

/**
* This test file exercises the `LET` clause in PartiQL.
*/
class LetTests {

@ParameterizedTest
@MethodSource("successTestCases")
@Execution(ExecutionMode.CONCURRENT)
fun successTests(tc: SuccessTestCase) = tc.run()

@ParameterizedTest
@MethodSource("failureTestCases")
@Execution(ExecutionMode.CONCURRENT)
fun failureTests(tc: FailureTestCase) = tc.run()

companion object {

@JvmStatic
fun successTestCases() = listOf(
SuccessTestCase(
name = "Basic LET usage 1",
input = """
SELECT t.a, c
FROM <<{ 'a': 1 , 'b': 2}>> AS t
LET t.a*5 AS c
""".trimIndent(),
expected = Datum.bagVararg(
Datum.struct(
Field.of("a", Datum.integer(1)),
Field.of("c", Datum.integer(5))
)
)
),

SuccessTestCase(
name = "Basic LET usage 2",
input = """
SELECT t.x, t.y, t.z * 2 AS double_z
FROM (
SELECT A AS x, B AS y, new_val AS z
FROM <<{ 'A': 1, 'B': 2, 'C': 3}>>
LET B + C AS new_val
) AS t;
""".trimIndent(),
expected = Datum.bagVararg(
Datum.struct(
Field.of("x", Datum.integer(1)),
Field.of("y", Datum.integer(2)),
Field.of("double_z", Datum.integer(10))
)
)
),

SuccessTestCase(
name = "LET with JOIN operation",
input = """
SELECT t.customer_name, t.order_total
FROM (
SELECT
c.name AS customer_name,
total AS order_total
FROM <<
{ 'id': 1, 'name': 'Alice' },
{ 'id': 2, 'name': 'Bob' }
>> AS c
JOIN <<
{ 'customer_id': 1, 'amount': 100 },
{ 'customer_id': 1, 'amount': 200 },
{ 'customer_id': 2, 'amount': 150 }
>> AS o
ON c.id = o.customer_id
LET o.amount * c.id AS total
) AS t;
""".trimIndent(),
expected = Datum.bagVararg(
Datum.struct(
Field.of("customer_name", Datum.string("Alice")),
Field.of("order_total", Datum.integer(100))
),
Datum.struct(
Field.of("customer_name", Datum.string("Alice")),
Field.of("order_total", Datum.integer(200))
),
Datum.struct(
Field.of("customer_name", Datum.string("Bob")),
Field.of("order_total", Datum.integer(300))
)
)
),

SuccessTestCase(
name = "LET with multiple items in data",
input = """
SELECT t.x, t.y, t.z AS total
FROM (
SELECT A AS x, B AS y, sum_val AS z
FROM <<
{ 'A': 1, 'B': 2, 'C': 3 },
{ 'A': 10, 'B': 20, 'C': 30 }
>>
LET B + C AS sum_val
) AS t;
""".trimIndent(),
expected = Datum.bagVararg(
Datum.struct(
Field.of("x", Datum.integer(1)),
Field.of("y", Datum.integer(2)),
Field.of("total", Datum.integer(5))
),
Datum.struct(
Field.of("x", Datum.integer(10)),
Field.of("y", Datum.integer(20)),
Field.of("total", Datum.integer(50))
)
)
),
SuccessTestCase(
name = "LET referencing prior expressions",
input = """
SELECT t.x, t.sum_val, t.double_sum
FROM (
SELECT
A AS x,
sum_val,
sum_val * 2 AS double_sum
FROM <<
{ 'A': 3, 'B': 5 },
{ 'A': 10, 'B': 2 }
>>
LET A + B AS sum_val
) AS t;
""".trimIndent(),
expected = Datum.bagVararg(
Datum.struct(
Field.of("x", Datum.integer(3)),
Field.of("sum_val", Datum.integer(8)),
Field.of("double_sum", Datum.integer(16))
),
Datum.struct(
Field.of("x", Datum.integer(10)),
Field.of("sum_val", Datum.integer(12)),
Field.of("double_sum", Datum.integer(24))
)
)
),
SuccessTestCase(
name = "LET with multiple LET clauses 1",
input = """
SELECT t.a, b, c
FROM << { 'a': 1 }>> AS t
LET t.a + 2 AS b, t.a * 3 AS c
""".trimIndent(),
expected = Datum.bagVararg(
Datum.struct(
Field.of("a", Datum.integer(1)),
Field.of("b", Datum.integer(3)),
Field.of("c", Datum.integer(3))
)
)
),
SuccessTestCase(
name = "LET with multiple LET clauses 2",
input = """
SELECT a, b, c, d, e
FROM << { 'a': 1 , 'b':2}>> AS t
LET t.a + 5 AS c, t.b+ 10 AS d, t.a + 15 AS e
""".trimIndent(),
expected = Datum.bagVararg(
Datum.struct(
Field.of("a", Datum.integer(1)),
Field.of("b", Datum.integer(2)),
Field.of("c", Datum.integer(6)),
Field.of("d", Datum.integer(12)),
Field.of("e", Datum.integer(16))
)
)
)
)

@JvmStatic
fun failureTestCases() = listOf(
FailureTestCase(
name = "LET referencing undefined variable",
input = """
SELECT t.x
FROM (
SELECT A AS x
FROM << { 'A': 1, 'B': 2 } >>
LET nonexistent + B AS new_val
) AS t;
""".trimIndent()
),
FailureTestCase(
name = "LET clause with ambiguous reference",
input = """
SELECT t.z
FROM (
SELECT new_val AS z
FROM << { 'A': 1, 'B': 2 } >>
-- 'new_val' references itself in LET, which is not allowed
LET new_val + B AS new_val
) AS t;
""".trimIndent()
),
FailureTestCase(
name = "Outside clauses referencing subquery's LET bindings",
input = """
SELECT t.x, t.y, new_val
FROM (
SELECT A AS x, B AS y
FROM <<{ 'A': 1, 'B': 2, 'C': 3}>>
LET B + C AS new_val
) AS t;
""".trimIndent()
),
FailureTestCase(
name = "LET with invalid JOIN reference",
input = """
SELECT t.customer_name, t.calculated_total
FROM (
SELECT
c.name AS customer_name,
total AS calculated_total
FROM <<
{ 'id': 1, 'name': 'Alice' },
{ 'id': 2, 'name': 'Bob' }
>> AS c
LEFT JOIN <<
{ 'customer_id': 1, 'amount': 100 },
{ 'customer_id': 2, 'amount': 150 }
>> AS o
ON c.id = o.customer_id
-- This should fail because we're trying to reference 'missing_field'
-- which doesn't exist in either joined table
LET missing_field + o.amount AS total
) AS t;
""".trimIndent(),
)
)
}

// Example of a test that might need special handling or a skip
@Test
@Disabled("Demonstration of a scenario needing further investigation.")
fun disabledTestExample() {
// Implementation left blank or used for demonstration
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import org.partiql.ast.GroupBy
import org.partiql.ast.GroupByStrategy
import org.partiql.ast.Identifier
import org.partiql.ast.JoinType
import org.partiql.ast.Let
import org.partiql.ast.Let.Binding
import org.partiql.ast.Literal.intNum
import org.partiql.ast.Nulls
import org.partiql.ast.Order
Expand Down Expand Up @@ -192,6 +194,7 @@ internal object RelConverter {
rel = convertOffset(rel, offset)
rel = convertLimit(rel, limit)
rel = convertExclude(rel, sel.exclude)
rel = convertLet(rel, sel.let)
// append SQL projection if present
rel = when (val projection = sel.select) {
is SelectValue -> {
Expand Down Expand Up @@ -383,6 +386,13 @@ internal object RelConverter {
return binding to rex
}

private fun convertLetBinding(binding: Binding): Pair<Rel.Binding, Rex> {
val name = binding.asAlias.text
val rex = RexConverter.apply(binding.expr, env)
val newBinding = relBinding(name, rex.type)
return newBinding to rex
}

/**
* Append [Rel.Op.Filter] only if a WHERE condition exists
*/
Expand Down Expand Up @@ -582,6 +592,29 @@ internal object RelConverter {
return rel(type, op)
}

/**
* Concatenate bindings in LET clause with existing env bindings from input
*/
private fun convertLet(input: Rel, let: Let?): Rel {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👏

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will add tests in amendment

if (let == null) {
return input
}
val schema = input.type.schema.toMutableList()
val props = input.type.props
val projections = mutableListOf<Rex>()
repeat(input.type.schema.size) { index ->
projections.add(rex(ANY, rexOpVarLocal(0, index)))
}
let.bindings.forEach {
val (newBinding, projection) = convertLetBinding(it)
schema.add(newBinding)
projections.add(projection)
}
val type = relType(schema, props)
val op = relOpProject(input, projections)
return rel(type, op)
}

/**
* Append [Rel.Op.Offset] if there is an OFFSET
*/
Expand Down
Loading