Skip to content

Commit df58bc6

Browse files
committed
fix decimal format
1 parent ac1a893 commit df58bc6

File tree

4 files changed

+37
-13
lines changed

4 files changed

+37
-13
lines changed

graphmdl-sqlrewrite/src/main/java/io/graphmdl/sqlrewrite/PreAggregationRewrite.java

+4-3
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@
4949
import static io.graphmdl.sqlrewrite.Utils.analyzeFrom;
5050
import static io.graphmdl.sqlrewrite.Utils.toCatalogSchemaTableName;
5151
import static io.trino.sql.QueryUtil.getQualifiedName;
52-
import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE;
52+
import static io.trino.sql.SqlFormatter.Dialect.DUCKDB;
53+
import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL;
5354
import static java.lang.String.format;
5455
import static java.util.Locale.ENGLISH;
5556
import static java.util.Objects.requireNonNull;
@@ -69,12 +70,12 @@ public static Optional<String> rewrite(
6970
GraphMDL graphMDL)
7071
{
7172
try {
72-
Statement statement = SQL_PARSER.createStatement(sql, new ParsingOptions(AS_DOUBLE));
73+
Statement statement = SQL_PARSER.createStatement(sql, new ParsingOptions(AS_DECIMAL));
7374
PreAggregationAnalysis aggregationAnalysis = new PreAggregationAnalysis();
7475
Statement rewritten = (Statement) new Rewriter(sessionContext, converter, graphMDL, aggregationAnalysis).process(statement, Optional.empty());
7576
if (rewritten instanceof Query
7677
&& aggregationAnalysis.onlyPreAggregationTables()) {
77-
return Optional.of(SqlFormatter.formatSql(rewritten));
78+
return Optional.of(SqlFormatter.formatSql(rewritten, DUCKDB));
7879
}
7980
}
8081
catch (Exception e) {

graphmdl-sqlrewrite/src/test/java/io/graphmdl/TestPreAggregationRewrite.java

+17-5
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
import static com.google.common.base.MoreObjects.toStringHelper;
3535
import static io.graphmdl.base.GraphMDLTypes.DATE;
36+
import static io.graphmdl.base.GraphMDLTypes.DECIMAL;
3637
import static io.graphmdl.base.GraphMDLTypes.INTEGER;
3738
import static io.graphmdl.base.GraphMDLTypes.TIMESTAMP;
3839
import static io.graphmdl.base.GraphMDLTypes.VARCHAR;
@@ -42,8 +43,9 @@
4243
import static io.graphmdl.base.dto.TimeGrain.TimeUnit.YEAR;
4344
import static io.graphmdl.base.dto.TimeGrain.timeGrain;
4445
import static io.graphmdl.testing.AbstractTestFramework.withDefaultCatalogSchema;
46+
import static io.trino.sql.SqlFormatter.Dialect.DUCKDB;
4547
import static io.trino.sql.SqlFormatter.formatSql;
46-
import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE;
48+
import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL;
4749
import static java.lang.String.format;
4850
import static org.assertj.core.api.Assertions.assertThat;
4951

@@ -95,7 +97,7 @@ public void init()
9597
List.of(
9698
column("author", VARCHAR, null, true),
9799
column("album_name", VARCHAR, null, true, "Album.name")),
98-
List.of(column("price", INTEGER, null, true, "avg(Album.price)")),
100+
List.of(column("price", DECIMAL, null, true, "avg(Album.price)")),
99101
List.of(
100102
timeGrain("p_date", "Album.publish_date", List.of(YEAR)),
101103
timeGrain("r_date", "Album.release_date", List.of(YEAR))),
@@ -300,6 +302,16 @@ public void testTableAliasScope()
300302
"with test_a as (with AvgCollection as (select * from table_Collection) select * from AvgCollection) select * from table_AvgCollection");
301303
}
302304

305+
@Test
306+
public void testDecimalRewrite()
307+
{
308+
assertRewrite(
309+
"SELECT * from AvgCollection where avg = DECIMAL '1.0'",
310+
"graphmdl",
311+
"test",
312+
"SELECT * FROM table_AvgCollection WHERE avg = 1.0");
313+
}
314+
303315
@DataProvider(name = "unexpectedStatementProvider")
304316
public Object[][] unexpectedStatementProvider()
305317
{
@@ -395,9 +407,9 @@ private void assertRewrite(
395407
defaultSchema,
396408
tableConverter).orElseThrow(() -> new AssertionError("No rewrite result"));
397409

398-
Statement expect = sqlParser.createStatement(expectSql, new ParsingOptions(AS_DOUBLE));
399-
Statement actualStatement = sqlParser.createStatement(result, new ParsingOptions(AS_DOUBLE));
400-
assertThat(result).isEqualTo(formatSql(expect));
410+
Statement expect = sqlParser.createStatement(expectSql, new ParsingOptions(AS_DECIMAL));
411+
Statement actualStatement = sqlParser.createStatement(result, new ParsingOptions(AS_DECIMAL));
412+
assertThat(result).isEqualTo(formatSql(expect, DUCKDB));
401413
assertThat(actualStatement).isEqualTo(expect);
402414
}
403415

trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java

+15-5
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115
import static io.trino.sql.RowPatternFormatter.formatPattern;
116116
import static io.trino.sql.SqlFormatter.Dialect.BIGQUERY;
117117
import static io.trino.sql.SqlFormatter.Dialect.DEFAULT;
118+
import static io.trino.sql.SqlFormatter.Dialect.DUCKDB;
118119
import static io.trino.sql.SqlFormatter.formatName;
119120
import static io.trino.sql.SqlFormatter.formatSql;
120121
import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL;
@@ -138,7 +139,7 @@ public static String formatExpression(Expression expression, Dialect dialect)
138139

139140
private static String formatIdentifier(String s, Dialect dialect)
140141
{
141-
if (dialect == DEFAULT) {
142+
if (dialect == DEFAULT || dialect == DUCKDB) {
142143
return '"' + s.replace("\"", "\"\"") + '"';
143144
}
144145
else if (dialect == BIGQUERY) {
@@ -288,7 +289,7 @@ protected String visitArrayConstructor(ArrayConstructor node, Void context)
288289
protected String visitSubscriptExpression(SubscriptExpression node, Void context)
289290
{
290291
String subscript;
291-
if (dialect == DEFAULT) {
292+
if (dialect == DEFAULT || dialect == DUCKDB) {
292293
subscript = formatSql(node.getIndex(), dialect);
293294
}
294295
else if (dialect == BIGQUERY) {
@@ -313,12 +314,18 @@ protected String visitLongLiteral(LongLiteral node, Void context)
313314
@Override
314315
protected String visitDoubleLiteral(DoubleLiteral node, Void context)
315316
{
317+
if (dialect == DUCKDB) {
318+
return String.valueOf(node.getValue());
319+
}
316320
return doubleFormatter.get().format(node.getValue());
317321
}
318322

319323
@Override
320324
protected String visitDecimalLiteral(DecimalLiteral node, Void context)
321325
{
326+
if (dialect == DUCKDB) {
327+
return node.getValue();
328+
}
322329
// TODO return node value without "DECIMAL '..'" when FeaturesConfig#parseDecimalLiteralsAsDouble switch is removed
323330
return "DECIMAL '" + node.getValue() + "'";
324331
}
@@ -338,6 +345,9 @@ protected String visitTimeLiteral(TimeLiteral node, Void context)
338345
@Override
339346
protected String visitTimestampLiteral(TimestampLiteral node, Void context)
340347
{
348+
if (dialect == DUCKDB) {
349+
return "'" + node.getValue() + "'";
350+
}
341351
return "TIMESTAMP '" + node.getValue() + "'";
342352
}
343353

@@ -612,7 +622,7 @@ protected String visitLikePredicate(LikePredicate node, Void context)
612622
.append(" LIKE ")
613623
.append(process(node.getPattern(), context));
614624

615-
if (dialect == DEFAULT) {
625+
if (dialect == DEFAULT || dialect == DUCKDB) {
616626
node.getEscape().ifPresent(escape -> builder.append(" ESCAPE ")
617627
.append(process(escape, context)));
618628
}
@@ -936,7 +946,7 @@ static String formatStringLiteral(String s, Dialect dialect)
936946
}
937947

938948
StringBuilder builder = new StringBuilder();
939-
if (dialect == DEFAULT) {
949+
if (dialect == DEFAULT || dialect == DUCKDB) {
940950
builder.append("U&");
941951
}
942952
builder.append("'");
@@ -952,7 +962,7 @@ static String formatStringLiteral(String s, Dialect dialect)
952962
builder.append(ch);
953963
}
954964
else {
955-
if (dialect == DEFAULT) {
965+
if (dialect == DEFAULT || dialect == DUCKDB) {
956966
if (codePoint <= 0xFFFF) {
957967
builder.append('\\');
958968
builder.append(format("%04X", codePoint));

trino-parser/src/main/java/io/trino/sql/SqlFormatter.java

+1
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ public enum Dialect
165165
{
166166
DEFAULT,
167167
BIGQUERY,
168+
DUCKDB,
168169
}
169170

170171
private SqlFormatter() {}

0 commit comments

Comments
 (0)