From f8edda30d94eb2e08d29b41816e811242ef08650 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 7 Mar 2025 20:25:30 +0800 Subject: [PATCH] exclude the cte table when source node checking --- .../wren/base/sqlrewrite/WrenSqlRewrite.java | 24 +++++-- .../io/wren/testing/TestMDLResourceV2.java | 62 +++++++++++++++++++ 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/wren-base/src/main/java/io/wren/base/sqlrewrite/WrenSqlRewrite.java b/wren-base/src/main/java/io/wren/base/sqlrewrite/WrenSqlRewrite.java index e80185c65..a98b5def9 100644 --- a/wren-base/src/main/java/io/wren/base/sqlrewrite/WrenSqlRewrite.java +++ b/wren-base/src/main/java/io/wren/base/sqlrewrite/WrenSqlRewrite.java @@ -34,6 +34,7 @@ import io.wren.base.dto.Model; import io.wren.base.dto.Relationable; import io.wren.base.sqlrewrite.analyzer.Analysis; +import io.wren.base.sqlrewrite.analyzer.Scope; import io.wren.base.sqlrewrite.analyzer.StatementAnalyzer; import org.jgrapht.graph.DirectedAcyclicGraph; import org.jgrapht.graph.GraphCycleProhibitedException; @@ -98,12 +99,15 @@ public Statement apply(Statement root, SessionContext sessionContext, Analysis a // Some node be applied `count(*)` which won't be collected but its source is required. analysis.getRequiredSourceNodes().forEach(node -> { - String tableName = analysis.getSourceNodeNames(node).map(QualifiedName::toString) - .orElseThrow(() -> new IllegalArgumentException(format("source node name not found: %s", node))); - if (!tableRequiredFields.containsKey(tableName)) { - Relationable relationable = wrenMDL.getRelationable(tableName) - .orElseThrow(() -> new IllegalArgumentException(format("dataset not found: %s", tableName))); - tableRequiredFields.put(tableName, relationable.getColumns().stream().filter(column -> !column.isCalculated()).map(Column::getName).collect(toImmutableSet())); + Scope scope = analysis.getScope(node); + if (tryGetTableName(node).flatMap(name -> scope.getNamedQuery(name.toString())).isEmpty()) { + String tableName = analysis.getSourceNodeNames(node).map(QualifiedName::toString) + .orElseThrow(() -> new IllegalArgumentException(format("source node name not found: %s", node))); + if (!tableRequiredFields.containsKey(tableName)) { + Relationable relationable = wrenMDL.getRelationable(tableName) + .orElseThrow(() -> new IllegalArgumentException(format("dataset not found: %s", tableName))); + tableRequiredFields.put(tableName, relationable.getColumns().stream().filter(column -> !column.isCalculated()).map(Column::getName).collect(toImmutableSet())); + } } }); @@ -142,6 +146,14 @@ public Statement apply(Statement root, SessionContext sessionContext, Analysis a } } + private Optional tryGetTableName(Node node) + { + if (node instanceof Table) { + return Optional.of(((Table) node).getName()); + } + return Optional.empty(); + } + private void addDescriptor(String name, Set requiredFields, WrenMDL wrenMDL, ImmutableList.Builder descriptorsBuilder) { if (wrenMDL.getModel(name).isPresent()) { diff --git a/wren-tests/src/test/java/io/wren/testing/TestMDLResourceV2.java b/wren-tests/src/test/java/io/wren/testing/TestMDLResourceV2.java index 1c160da67..5378ab94b 100644 --- a/wren-tests/src/test/java/io/wren/testing/TestMDLResourceV2.java +++ b/wren-tests/src/test/java/io/wren/testing/TestMDLResourceV2.java @@ -247,6 +247,68 @@ public void testSetManyToMany() } + @Test + public void testPlanCountWithClause() + { + Manifest manifest = Manifest.builder() + .setCatalog("wrenai") + .setSchema("tpch") + .setModels(List.of( + model("Customer", "SELECT * FROM tpch.customer", + List.of(column("custkey", "integer", null, false, "c_custkey"), + column("name", "varchar", null, false, "c_name"))), + model("Orders", "SELECT * FROM tpch.orders", + List.of(column("orderkey", "integer", null, false, "o_orderkey"), + column("custkey", "integer", null, false, "o_custkey"), + column("customer", "Customer", "CustomerOrders", false), + calculatedColumn("customer_name", "varchar", "customer.name")), + "orderkey"))) + .setRelationships(List.of(relationship("CustomerOrders", List.of("Customer", "Orders"), JoinType.MANY_TO_MANY, "Customer.custkey = Orders.custkey"))) + .build(); + String manifestStr = base64Encode(toJson(manifest)); + DryPlanDtoV2 dryPlanDto = new DryPlanDtoV2(manifestStr, "select count(*) from (with orders_custkey as (select custkey from \"Orders\") select * from orders_custkey) "); + String dryPlan = dryPlanV2(dryPlanDto); + assertThat(dryPlan).isEqualTo(""" + WITH + "Orders" AS ( + SELECT + "Orders"."orderkey" "orderkey" + , "Orders"."custkey" "custkey" + FROM + ( + SELECT + "Orders"."orderkey" "orderkey" + , "Orders"."custkey" "custkey" + FROM + ( + SELECT + o_orderkey "orderkey" + , o_custkey "custkey" + FROM + ( + SELECT * + FROM + tpch.orders + ) "Orders" + ) "Orders" + ) "Orders" + )\s + SELECT count(*) + FROM + ( + WITH + orders_custkey AS ( + SELECT custkey + FROM + "Orders" + )\s + SELECT * + FROM + orders_custkey + ) t + """); + } + private String toJson(Manifest manifest) { return MANIFEST_JSON_CODEC.toJson(manifest);