Skip to content

Commit 3a04ba7

Browse files
authored
Fix the scope analysis of the CTEs (#729)
1 parent 8325e40 commit 3a04ba7

File tree

2 files changed

+60
-7
lines changed

2 files changed

+60
-7
lines changed

wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/StatementAnalyzer.java

+8-7
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ protected Scope visitTable(Table node, Optional<Scope> scope)
160160
if (withQuery.isPresent()) {
161161
// currently we only care about the table that is actually a model instead of a alias table that use cte table
162162
// return empty scope here.
163-
Scope outputScope = createScopeForCommonTableExpression(node, withQuery.get(), scope);
163+
Optional<Scope> withScope = analysis.tryGetScope(withQuery.get().getQuery());
164+
Scope outputScope = createScopeForCommonTableExpression(node, withQuery.get(), withScope);
164165
analysis.setScope(node, outputScope);
165166
return outputScope;
166167
}
@@ -202,8 +203,7 @@ protected Scope visitTable(Table node, Optional<Scope> scope)
202203
private Scope createScopeForCommonTableExpression(Table table, WithQuery withQuery, Optional<Scope> scope)
203204
{
204205
Query query = withQuery.getQuery();
205-
Analysis analyzed = new Analysis(query);
206-
Optional<Scope> queryScope = Optional.ofNullable(analyze(analyzed, query, sessionContext, wrenMDL));
206+
Optional<Scope> queryScope = analysis.tryGetScope(query);
207207
List<Field> fields;
208208
Optional<List<Identifier>> columnNames = withQuery.getColumnNames();
209209
if (columnNames.isPresent()) {
@@ -243,7 +243,7 @@ private List<Field> createScopeForQuery(Query query, QualifiedName scopeName, Op
243243
else {
244244
SingleColumn singleColumn = (SingleColumn) selectItem;
245245
String name = singleColumn.getAlias().map(Identifier::getValue)
246-
.or(() -> Optional.ofNullable(QueryUtil.getQualifiedName(singleColumn.getExpression())).map(QualifiedName::toString))
246+
.or(() -> Optional.ofNullable(QueryUtil.getQualifiedName(singleColumn.getExpression()).getSuffix()))
247247
.orElse(singleColumn.getExpression().toString());
248248
if (scope.isPresent()) {
249249
Optional<Field> fieldOptional = scope.get().getRelationType().resolveAnyField(QueryUtil.getQualifiedName(singleColumn.getExpression()));
@@ -402,7 +402,7 @@ private Scope analyzeFrom(QuerySpecification node, Optional<Scope> scope)
402402

403403
private void analyzeWhere(Expression node, Scope scope)
404404
{
405-
ExpressionAnalysis expressionAnalysis = analyzeExpression(node, scope);
405+
analyzeExpression(node, scope);
406406
}
407407

408408
private void analyzeWindowSpecification(WindowSpecification windowSpecification, Scope scope)
@@ -561,7 +561,7 @@ protected Scope visitTableSubquery(TableSubquery node, Optional<Scope> scope)
561561
private Optional<Scope> analyzeWith(Query node, Optional<Scope> scope)
562562
{
563563
if (node.getWith().isEmpty()) {
564-
return Optional.empty();
564+
return scope.map(s -> Scope.builder().parent(Optional.of(s)).build());
565565
}
566566

567567
With with = node.getWith().get();
@@ -572,7 +572,8 @@ private Optional<Scope> analyzeWith(Query node, Optional<Scope> scope)
572572
if (withScopeBuilder.containsNamedQuery(name)) {
573573
throw new IllegalArgumentException(format("WITH query name '%s' specified more than once", name));
574574
}
575-
process(withQuery.getQuery(), withScopeBuilder.build());
575+
Scope withQueryScope = process(withQuery.getQuery(), withScopeBuilder.build());
576+
analysis.setScope(withQuery.getQuery(), withQueryScope);
576577
withScopeBuilder.namedQuery(name, withQuery);
577578
}
578579

wren-base/src/test/java/io/wren/base/sqlrewrite/analyzer/TestDecisionPointAnalyzer.java

+52
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,58 @@ WITH t1 as (SELECT * FROM customer), t2 as (SELECT * FROM orders)
676676
new ExprSource("t1.custkey", "customer", "custkey", new NodeLocation(2, 33)),
677677
new ExprSource("t2.custkey", "orders", "custkey", new NodeLocation(2, 50))));
678678
}
679+
680+
statement = parseSql("""
681+
WITH t1 as (SELECT customer.custkey FROM customer),
682+
t2 as (SELECT t1.custkey FROM t1)
683+
SELECT t2.custkey FROM t2
684+
""");
685+
result = DecisionPointAnalyzer.analyze(statement, DEFAULT_SESSION_CONTEXT, mdl);
686+
assertThat(result.size()).isEqualTo(3);
687+
QueryAnalysis queryAnalysis = result.get(2);
688+
assertThat(queryAnalysis.getSelectItems().get(0).getExprSources().size()).isEqualTo(1);
689+
690+
statement = parseSql("""
691+
WITH t1 as (SELECT customer.custkey FROM customer),
692+
t2 as (SELECT t1.custkey as custkey_alias FROM t1)
693+
SELECT custkey_alias FROM t2
694+
""");
695+
result = DecisionPointAnalyzer.analyze(statement, DEFAULT_SESSION_CONTEXT, mdl);
696+
assertThat(result.size()).isEqualTo(3);
697+
queryAnalysis = result.get(2);
698+
assertThat(queryAnalysis.getSelectItems().get(0).getExprSources().size()).isEqualTo(1);
699+
700+
statement = parseSql("""
701+
WITH t1 as (SELECT customer.custkey FROM customer),
702+
t2 as (SELECT * FROM t1)
703+
SELECT custkey FROM t2
704+
""");
705+
result = DecisionPointAnalyzer.analyze(statement, DEFAULT_SESSION_CONTEXT, mdl);
706+
assertThat(result.size()).isEqualTo(3);
707+
queryAnalysis = result.get(2);
708+
assertThat(queryAnalysis.getSelectItems().get(0).getExprSources().size()).isEqualTo(1);
709+
710+
statement = parseSql("""
711+
WITH t1 as (SELECT customer.custkey FROM customer),
712+
t2 as (SELECT * FROM t1)
713+
SELECT custkey FROM t2
714+
""");
715+
result = DecisionPointAnalyzer.analyze(statement, DEFAULT_SESSION_CONTEXT, mdl);
716+
assertThat(result.size()).isEqualTo(3);
717+
queryAnalysis = result.get(2);
718+
assertThat(queryAnalysis.getSelectItems().get(0).getExprSources().size()).isEqualTo(1);
719+
720+
721+
// we only analyze the top-level expression
722+
statement = parseSql("""
723+
WITH t1 as (SELECT customer.custkey FROM customer),
724+
t2 as (SELECT (t1.custkey + 1) as custkey_plus FROM t1)
725+
SELECT custkey_plus FROM t2
726+
""");
727+
result = DecisionPointAnalyzer.analyze(statement, DEFAULT_SESSION_CONTEXT, mdl);
728+
assertThat(result.size()).isEqualTo(3);
729+
queryAnalysis = result.get(2);
730+
assertThat(queryAnalysis.getSelectItems().get(0).getExprSources().size()).isEqualTo(0);
679731
}
680732

681733
@Test

0 commit comments

Comments
 (0)