Skip to content

Commit 5b86afc

Browse files
authored
Fix bag constructor parsing (#1500)
1 parent 4dd0972 commit 5b86afc

File tree

5 files changed

+54
-4
lines changed

5 files changed

+54
-4
lines changed

partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigParser.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ internal class PartiQLPigParser(val customTypes: List<CustomType> = listOf()) :
9595
val tokenStream = createTokenStream(queryStream)
9696
val parser = parserInit(tokenStream)
9797
val tree = parser.root()
98-
val visitor = PartiQLPigVisitor(customTypes, tokenStream.parameterIndexes)
98+
val visitor = PartiQLPigVisitor(tokenStream, customTypes, tokenStream.parameterIndexes)
9999
return visitor.visit(tree) as PartiqlAst.Statement
100100
}
101101

partiql-lang/src/main/kotlin/org/partiql/lang/syntax/impl/PartiQLPigVisitor.kt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import com.amazon.ionelement.api.ionNull
3333
import com.amazon.ionelement.api.ionString
3434
import com.amazon.ionelement.api.ionSymbol
3535
import com.amazon.ionelement.api.loadSingleElement
36+
import org.antlr.v4.runtime.CommonTokenStream
3637
import org.antlr.v4.runtime.ParserRuleContext
3738
import org.antlr.v4.runtime.Token
3839
import org.antlr.v4.runtime.tree.TerminalNode
@@ -62,6 +63,7 @@ import org.partiql.lang.util.getPrecisionFromTimeString
6263
import org.partiql.lang.util.unaryMinus
6364
import org.partiql.parser.internal.antlr.PartiQLParser
6465
import org.partiql.parser.internal.antlr.PartiQLParserBaseVisitor
66+
import org.partiql.parser.internal.antlr.PartiQLTokens
6567
import org.partiql.pig.runtime.SymbolPrimitive
6668
import org.partiql.value.datetime.DateTimeException
6769
import org.partiql.value.datetime.TimeZone
@@ -116,6 +118,7 @@ import java.time.format.DateTimeParseException
116118
* There could be clever ways of exploiting this, to avoid the dispatch via `visit()`.
117119
*/
118120
internal class PartiQLPigVisitor(
121+
private val tokens: CommonTokenStream,
119122
val customTypes: List<CustomType> = listOf(),
120123
private val parameterIndexes: Map<Int, Int> = mapOf(),
121124
) :
@@ -1507,6 +1510,12 @@ internal class PartiQLPigVisitor(
15071510
*/
15081511

15091512
override fun visitBag(ctx: PartiQLParser.BagContext) = PartiqlAst.build {
1513+
// Prohibit hidden characters between angle brackets
1514+
val startTokenIndex = ctx.start.tokenIndex
1515+
val endTokenIndex = ctx.stop.tokenIndex
1516+
if (tokens.getHiddenTokensToRight(startTokenIndex, PartiQLTokens.HIDDEN) != null || tokens.getHiddenTokensToLeft(endTokenIndex, PartiQLTokens.HIDDEN) != null) {
1517+
throw ParserException("Invalid bag expression", ErrorCode.PARSE_INVALID_QUERY)
1518+
}
15101519
val exprList = ctx.expr().map { visitExpr(it) }
15111520
bag(exprList, ctx.ANGLE_LEFT(0).getSourceMetaContainer())
15121521
}

partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTest.kt

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5019,4 +5019,38 @@ class PartiQLParserTest : PartiQLParserTestBase() {
50195019
lit(ionInt(1))
50205020
)
50215021
}
5022+
5023+
// regression tests for bag constructor angle bracket
5024+
@Test
5025+
fun testBagConstructor() = assertExpression("<<<<1>>>>") {
5026+
bag(
5027+
bag(
5028+
lit(ionInt(1))
5029+
)
5030+
)
5031+
}
5032+
5033+
@Test
5034+
fun testSpacesInBagConstructor() = checkInputThrowingParserException(
5035+
"< < < < 1 > > > >",
5036+
ErrorCode.PARSE_UNEXPECTED_TOKEN, // partiql-ast parser ErrorCode
5037+
expectErrorContextValues = mapOf(
5038+
Property.LINE_NUMBER to 1L,
5039+
Property.COLUMN_NUMBER to 1L,
5040+
Property.TOKEN_DESCRIPTION to PartiQLParser.ANGLE_LEFT.getAntlrDisplayString(),
5041+
Property.TOKEN_VALUE to ION.newSymbol("<")
5042+
)
5043+
)
5044+
5045+
@Test
5046+
fun testCommentsInBagConstructor() = checkInputThrowingParserException(
5047+
"</* some comment */<<<1>>>>",
5048+
ErrorCode.PARSE_UNEXPECTED_TOKEN, // partiql-ast parser ErrorCode
5049+
expectErrorContextValues = mapOf(
5050+
Property.LINE_NUMBER to 1L,
5051+
Property.COLUMN_NUMBER to 1L,
5052+
Property.TOKEN_DESCRIPTION to PartiQLParser.ANGLE_LEFT.getAntlrDisplayString(),
5053+
Property.TOKEN_VALUE to ION.newSymbol("<")
5054+
)
5055+
)
50225056
}

partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTestBase.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,15 +152,15 @@ abstract class PartiQLParserTestBase : TestBase() {
152152
input: String,
153153
errorCode: ErrorCode,
154154
expectErrorContextValues: Map<Property, Any>,
155-
targets: Array<ParserTarget> = arrayOf(ParserTarget.DEFAULT),
156155
assertContext: Boolean = true,
157156
): Unit = forEachTarget {
158157
softAssert {
159158
try {
160159
parser.parseAstStatement(input)
161160
fail("Expected ParserException but there was no Exception")
162161
} catch (ex: ParserException) {
163-
// split parser target does not use ErrorCode
162+
// NOTE: only perform error code and error context checks for `ParserTarget.EXPERIMENTAL` (partiql-ast
163+
// parser).
164164
if (assertContext && (this@forEachTarget == ParserTarget.EXPERIMENTAL)) {
165165
checkErrorAndErrorContext(errorCode, ex, expectErrorContextValues)
166166
}

partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ internal class PartiQLParserDefault : PartiQLParser {
425425
*/
426426
@OptIn(PartiQLValueExperimental::class)
427427
private class Visitor(
428+
private val tokens: CommonTokenStream,
428429
private val locations: SourceLocations.Mutable,
429430
private val parameters: Map<Int, Int> = mapOf(),
430431
) : PartiQLParserBaseVisitor<AstNode>() {
@@ -442,7 +443,7 @@ internal class PartiQLParserDefault : PartiQLParser {
442443
tree: GeneratedParser.RootContext,
443444
): PartiQLParser.Result {
444445
val locations = SourceLocations.Mutable()
445-
val visitor = Visitor(locations, tokens.parameterIndexes)
446+
val visitor = Visitor(tokens, locations, tokens.parameterIndexes)
446447
val root = visitor.visitAs<AstNode>(tree) as Statement
447448
return PartiQLParser.Result(
448449
source = source,
@@ -2022,6 +2023,12 @@ internal class PartiQLParserDefault : PartiQLParser {
20222023
*/
20232024

20242025
override fun visitBag(ctx: GeneratedParser.BagContext) = translate(ctx) {
2026+
// Prohibit hidden characters between angle brackets
2027+
val startTokenIndex = ctx.start.tokenIndex
2028+
val endTokenIndex = ctx.stop.tokenIndex
2029+
if (tokens.getHiddenTokensToRight(startTokenIndex, GeneratedLexer.HIDDEN) != null || tokens.getHiddenTokensToLeft(endTokenIndex, GeneratedLexer.HIDDEN) != null) {
2030+
throw error(ctx, "Invalid bag expression")
2031+
}
20252032
val expressions = visitOrEmpty<Expr>(ctx.expr())
20262033
exprCollection(Expr.Collection.Type.BAG, expressions)
20272034
}

0 commit comments

Comments
 (0)