Skip to content

fix(core): fix the field resolving for the relationship column #820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ protected Void visitDereferenceExpression(DereferenceExpression node, Void conte
{
QualifiedName qualifiedName = getQualifiedName(node);
if (qualifiedName != null) {
scope.getRelationType().getFields().stream()
.filter(field -> field.canResolve(qualifiedName))
.findAny()
scope.getRelationType().resolveAnyField(qualifiedName)
.ifPresent(field -> referenceFields.put(NodeRef.of(node), field));
}
else {
Expand All @@ -97,9 +95,7 @@ protected Void visitDereferenceExpression(DereferenceExpression node, Void conte
protected Void visitIdentifier(Identifier node, Void context)
{
QualifiedName qualifiedName = QualifiedName.of(ImmutableList.of(node));
scope.getRelationType().getFields().stream()
.filter(field -> field.canResolve(qualifiedName))
.findAny()
scope.getRelationType().resolveAnyField(qualifiedName)
.ifPresent(field -> referenceFields.put(NodeRef.of(node), field));
return null;
}
Expand All @@ -108,9 +104,7 @@ protected Void visitIdentifier(Identifier node, Void context)
protected Void visitSubscriptExpression(SubscriptExpression node, Void context)
{
QualifiedName qualifiedName = getQualifiedName(node.getBase());
scope.getRelationType().getFields().stream()
.filter(field -> field.canResolve(qualifiedName))
.findAny()
scope.getRelationType().resolveAnyField(qualifiedName)
.ifPresent(field -> referenceFields.put(NodeRef.of(node), field));
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,13 @@ public class Field
// e.g. select table.col_1 from select * from table; => is this legal ? this is false
private final Optional<QualifiedName> relationAlias;
private final CatalogSchemaTableName tableName;
// the name of the column in the table
private final String columnName;
// the name of the dataset where the column comes from
private final Optional<String> sourceDatasetName;
// the name of the column in the dataset where the column comes from
private final Optional<Column> sourceColumn;
// the name of the column in the query (If the column is aliased, this is the alias, otherwise it's the column name)
private final Optional<String> name;

private Field(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,23 @@
import io.wren.base.SessionContext;
import io.wren.base.WrenTypes;
import io.wren.base.dto.Column;
import io.wren.base.dto.JoinType;
import io.wren.base.dto.Manifest;
import io.wren.base.dto.Metric;
import io.wren.base.dto.Model;
import io.wren.base.sqlrewrite.AbstractTestFramework;
import org.assertj.core.api.Assertions;
import org.testng.annotations.Test;

import java.util.List;
import java.util.Optional;
import java.util.function.Function;

import static io.wren.base.CatalogSchemaTableName.catalogSchemaTableName;
import static io.wren.base.WrenMDL.EMPTY;
import static io.wren.base.WrenMDL.fromManifest;
import static io.wren.base.dto.Column.relationshipColumn;
import static io.wren.base.dto.Relationship.relationship;
import static io.wren.base.sqlrewrite.Utils.parseSql;
import static io.wren.base.sqlrewrite.analyzer.StatementAnalyzer.analyze;
import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -209,6 +213,28 @@ WITH t1 as (SELECT "c1", "c2" FROM (select * from test.test.foo) table_1) select
assertThat(scope.get().getRelationType().getFields().get(1).getName().get()).isEqualTo("c2");
}

@Test
public void testScopeWithRelationship()
{
SessionContext sessionContext = SessionContext.builder().setCatalog("test").setSchema("test").build();
Manifest manifest = Manifest.builder()
.setCatalog("test")
.setSchema("test")
.setModels(ImmutableList.of(
Model.model("table_1", "SELECT * FROM foo", ImmutableList.of(
varcharColumn("c1"), varcharColumn("c2"), relationshipColumn("table_2", "table_2", "relationship_1_2"))),
Model.model("table_2", "SELECT * FROM bar", ImmutableList.of(varcharColumn("c1"), varcharColumn("c2")))))
.setRelationships(ImmutableList.of(
relationship("relationship_1_2", List.of("table_1", "table_2"), JoinType.ONE_TO_ONE, "table_1.c1 = table_2.c1")))
.build();

Statement statement = parseSql("SELECT table_2.c2 FROM table_1 JOIN table_2 ON table_1.c1 = table_2.c1");
Analysis analysis = new Analysis(statement);
analyze(analysis, statement, sessionContext, fromManifest(manifest));
assertThat(analysis.getCollectedColumns().get(catalogSchemaTableName("test", "test", "table_1"))).containsExactly("c1");
assertThat(analysis.getCollectedColumns().get(catalogSchemaTableName("test", "test", "table_2"))).containsExactly("c1", "c2");
}

private static Column varcharColumn(String name)
{
return Column.column(name, "VARCHAR", null, false, null);
Expand Down
Loading