Skip to content

Adds planning and evaluation support for CTEs #1738

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ enum class ErrorCodeString(val code: Int) {
NUMERIC_VALUE_OUT_OF_RANGE(PError.NUMERIC_VALUE_OUT_OF_RANGE),
INVALID_CHAR_VALUE_FOR_CAST(PError.INVALID_CHAR_VALUE_FOR_CAST),
DIVISION_BY_ZERO(PError.DIVISION_BY_ZERO),
DEGREE_VIOLATION_SCALAR_SUBQUERY(PError.DEGREE_VIOLATION_SCALAR_SUBQUERY)
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ object ErrorMessageFormatter {
ErrorCodeString.INVALID_CHAR_VALUE_FOR_CAST -> invalidCharValueForCast(error)
ErrorCodeString.DIVISION_BY_ZERO -> divisionByZero(error)
ErrorCodeString.TYPE_UNEXPECTED -> typeUnexpected(error)
ErrorCodeString.DEGREE_VIOLATION_SCALAR_SUBQUERY -> degreeViolationScalarSubquery(error)
ErrorCodeString.ALL -> "INTERNAL ERROR: This should never have occurred."
null -> "Unrecognized error code received: ${error.code()}"
}
Expand Down Expand Up @@ -129,6 +130,15 @@ object ErrorMessageFormatter {
return "Cannot divide$dividendStr$dividendTypeStr by zero."
}

/**
* @see PError.DEGREE_VIOLATION_SCALAR_SUBQUERY
*/
private fun degreeViolationScalarSubquery(error: PError): String {
val actualType = error.getOrNull("ACTUAL", java.lang.Integer::class.java)
val actualTypeStr = prepare(actualType.toString(), " Actual degree: ", ".")
return "Degree of scalar subquery must be 1 (one).$actualTypeStr"
}

/**
* @see PError.TYPE_UNEXPECTED
*/
Expand Down Expand Up @@ -166,7 +176,7 @@ object ErrorMessageFormatter {
val cause = error.getOrNull("CAUSE", Throwable::class.java)
val writer = StringPrintWriter()
writer.appendLine("Unexpected failure encountered. Caused by: $cause.")
cause.printStackTrace(writer)
cause?.printStackTrace(writer)
return writer.w.sb.toString()
}

Expand Down
241 changes: 241 additions & 0 deletions partiql-eval/src/test/kotlin/org/partiql/eval/internal/CteTests.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
package org.partiql.eval.internal

import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.parallel.Execution
import org.junit.jupiter.api.parallel.ExecutionMode
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource
import org.partiql.eval.Mode
import org.partiql.spi.value.Datum
import org.partiql.spi.value.Field

/**
* This test file tests Common Table Expressions.
*/
class CteTests {

@ParameterizedTest
@MethodSource("successTestCases")
@Execution(ExecutionMode.CONCURRENT)
fun successTests(tc: SuccessTestCase) = tc.run()

@ParameterizedTest
@MethodSource("failureTestCases")
@Execution(ExecutionMode.CONCURRENT)
fun failureTests(tc: FailureTestCase) = tc.run()

companion object {
@JvmStatic
fun successTestCases() = listOf(
SuccessTestCase(
name = "Simple SFW",
input = """
WITH x AS (SELECT VALUE t FROM <<1, 2, 3>> AS t) SELECT VALUE x FROM x;
""".trimIndent(),
expected = Datum.bagVararg(
Datum.integer(1),
Datum.integer(2),
Datum.integer(3)
)
),
SuccessTestCase(
name = "Multiple WITH elements and a UNION",
input = """
WITH
x AS (SELECT VALUE t FROM <<1, 2, 3>> AS t),
y AS (SELECT VALUE t FROM <<4, 5, 6>> AS t),
z AS (SELECT VALUE t FROM <<7, 8, 9>> AS t)
SELECT VALUE x FROM x UNION SELECT VALUE y FROM y UNION SELECT VALUE z FROM z;
""".trimIndent(),
expected = Datum.bagVararg(
Datum.integer(1),
Datum.integer(2),
Datum.integer(3),
Datum.integer(4),
Datum.integer(5),
Datum.integer(6),
Datum.integer(7),
Datum.integer(8),
Datum.integer(9)
)
),
SuccessTestCase(
name = "Simple SFW with repetitive cross join",
input = """
WITH x AS (SELECT VALUE t FROM <<1>> AS t) SELECT * FROM x AS s, x;
""".trimIndent(),
expected = Datum.bagVararg(
Datum.struct(
Field.of("_1", Datum.integer(1)),
Field.of("_2", Datum.integer(1))
)
)
),
SuccessTestCase(
name = "Multiple WITH elements and cross join",
input = """
WITH
x AS (SELECT VALUE t FROM <<1>> AS t),
y AS (SELECT VALUE t FROM <<2, 3>> AS t)
SELECT * FROM x, y;
""".trimIndent(),
expected = Datum.bagVararg(
Datum.struct(
Field.of("_1", Datum.integer(1)),
Field.of("_2", Datum.integer(2))
),
Datum.struct(
Field.of("_1", Datum.integer(1)),
Field.of("_2", Datum.integer(3))
)
)
),
SuccessTestCase(
name = "Nested WITH",
input = """
WITH x AS (
WITH y AS (
SELECT VALUE t FROM <<1, 2, 3>> AS t
) SELECT VALUE v * 10 FROM y AS v
) SELECT VALUE x + 5 FROM x;

""".trimIndent(),
expected = Datum.bagVararg(
Datum.integer(15),
Datum.integer(25),
Datum.integer(35)
)
),
SuccessTestCase(
name = "Handling of subqueries",
input = """
WITH x AS (
SELECT VALUE t FROM <<1>> AS t
)
SELECT VALUE y + (SELECT * FROM x) FROM <<100>> AS y;
""".trimIndent(),
mode = Mode.STRICT(),
expected = Datum.bagVararg(Datum.integer(101))
),
SuccessTestCase(
name = "Handling of subqueries with tuples",
input = """
WITH x AS (
SELECT VALUE t FROM << { 'a': 1 }>> AS t
)
SELECT VALUE y + (SELECT * FROM x) FROM <<100>> AS y;
""".trimIndent(),
mode = Mode.STRICT(),
expected = Datum.bagVararg(Datum.integer(101))
),
SuccessTestCase(
name = "Handling of subqueries with tuples and explicit attribute",
input = """
WITH x AS (
SELECT VALUE t FROM << { 'a': 1, 'b': 2 }>> AS t
)
SELECT VALUE y + (SELECT x.a FROM x) FROM <<100>> AS y;
""".trimIndent(),
mode = Mode.STRICT(),
expected = Datum.bagVararg(Datum.integer(101))
),
SuccessTestCase(
name = "Handling of subqueries with WHERE",
input = """
WITH x AS (
SELECT VALUE t FROM <<1, 2, 3, 4, 5>> AS t
)
SELECT VALUE y + (SELECT * FROM x WHERE x > 4) FROM <<100>> AS y;
""".trimIndent(),
mode = Mode.STRICT(),
expected = Datum.bagVararg(Datum.integer(105))
),
)

@JvmStatic
fun failureTestCases() = listOf(
FailureTestCase(
name = "CTE with cardinality greater than 1 used in subquery",
input = """
WITH x AS (
SELECT VALUE t FROM <<1, 2>> AS t
)
SELECT VALUE y + (SELECT * FROM x) FROM <<100>> AS y;
""".trimIndent(),
),
FailureTestCase(
name = "Attempting to reference variable outside the with-list-element",
input = """
WITH x AS (
SELECT VALUE t FROM <<1, 2>> AS t
)
SELECT * FROM t; -- t should not able to be referenced.
""".trimIndent(),
),
FailureTestCase(
name = "Attempting to reference variable from within the with-list-element",
input = """
WITH x AS (
SELECT VALUE t FROM t -- t should not able to be referenced.
)
SELECT * FROM << 1, 2, 3>> AS t, x
""".trimIndent(),
),
// TODO: Figure out if this should be allowed. In PostgreSQL, it is allowed. In SQL Spec, I'm not sure.
// As such, updating the implementation to allow for this would be a non-breaking change.
FailureTestCase(
name = "Attempting to reference another with list element",
input = """
WITH
x AS (SELECT VALUE t FROM << 1, 2, 3 >> t),
y AS (SELECT VALUE x FROM x)
SELECT * FROM y; -- y should not be able to be referenced.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-- y should not be able to be referenced.

y is fine, no? It's the x that should not be able to be referenced inside y?

""".trimIndent(),
),
FailureTestCase(
name = "Attempting to reference another with list element (2)",
input = """
WITH
x AS (SELECT VALUE t FROM << 1, 2, 3 >> t),
y AS (SELECT VALUE x FROM x)
SELECT * FROM x, y; -- x & y should not be able to be referenced
""".trimIndent(),
),
FailureTestCase(
name = "Attempting to create a recursive (non-labeled) CTE",
input = """
WITH x AS (
SELECT VALUE t FROM t -- t should not able to be referenced.
)
SELECT * FROM << 1, 2, 3>> AS t, x
""".trimIndent(),
),
)
}

// TODO: Figure out the right behavior here.
@Test
@Disabled(
"""
This _maybe_ should fail, since CTE "y" references a non-existing variable "s". In the specification, it is a bit
vague about what to do in this scenario. Currently, due to https://partiql.org/partiql-lang/#sec:schema-in-tuple-path,
the implementation does not throw an error at compile-time. It is only during the evaluation of a non-existent
variable that it throws an error. Therefore, even though we are emitting a warning when compiling the reference to
"s", it is never used at runtime (and therefore an error is never emitted).
"""
)
fun nonReferencedBadCTE() {
val tc = FailureTestCase(
name = "Attempting to reference another with list element (3)",
input = """
WITH
x AS (SELECT VALUE t FROM << 1, 2, 3 >> t),
y AS (SELECT VALUE s FROM s) -- this is rubbish!
SELECT * FROM x;
""".trimIndent(),
mode = Mode.STRICT(),
)
tc.run()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,22 @@ class Global(
}

public class SuccessTestCase(
val name: String,
val input: String,
val expected: Datum,
val mode: Mode = Mode.PERMISSIVE(),
val globals: List<Global> = emptyList(),
val jvmEquality: Boolean = false
) : PTestCase {

constructor(
input: String,
expected: Datum,
mode: Mode = Mode.PERMISSIVE(),
globals: List<Global> = emptyList(),
jvmEquality: Boolean = false
) : this("no_name", input, expected, mode, globals, jvmEquality)

constructor(
input: String,
expected: PartiQLValue,
Expand Down Expand Up @@ -96,15 +105,23 @@ public class SuccessTestCase(
}

override fun toString(): String {
return input
return "$name ($mode): $input"
}
}

public class FailureTestCase(
val name: String,
val input: String,
val mode: Mode = Mode.STRICT(), // default to run in STRICT mode
val globals: List<Global> = emptyList(),
) : PTestCase {

constructor(
input: String,
mode: Mode = Mode.STRICT(),
globals: List<Global> = emptyList()
) : this("no_name", input, mode, globals)

private val compiler = PartiQLCompiler.standard()
private val parser = PartiQLParser.standard()
private val planner = PartiQLPlanner.standard()
Expand Down Expand Up @@ -146,4 +163,8 @@ public class FailureTestCase(
error(message)
}
}

override fun toString(): String {
return "$name ($mode): $input"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,16 @@ internal object PErrors {
)
}

internal fun degreeViolationScalarSubquery(actual: Int, location: SourceLocation? = null): PError {
return PError(
PError.DEGREE_VIOLATION_SCALAR_SUBQUERY,
Severity.ERROR(),
PErrorKind.SEMANTIC(),
location,
mapOf("ACTUAL" to actual),
)
}

private fun internalError(cause: Throwable): PError = PError(
PError.INTERNAL_ERROR,
Severity.ERROR(),
Expand Down
Loading
Loading