Skip to content

Commit 19e4b2a

Browse files
authored
fix: update fstring implementaion for sqlite dialect (#1638)
* fix: fix fstring concat behavior in sqlite * test: add unit test for sqlite fstring * fix: style
1 parent 5280161 commit 19e4b2a

File tree

3 files changed

+133
-22
lines changed

3 files changed

+133
-22
lines changed

prql-compiler/src/sql/dialect.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@ pub(super) trait DialectHandler {
105105
fn intersect_all(&self) -> bool {
106106
self.except_all()
107107
}
108+
109+
/// Support for CONCAT function.
110+
/// When not supported we fallback to use `||` as concat operator.
111+
fn has_concat_function(&self) -> bool {
112+
true
113+
}
108114
}
109115

110116
impl DialectHandler for GenericDialect {}
@@ -117,6 +123,10 @@ impl DialectHandler for SQLiteDialect {
117123
fn except_all(&self) -> bool {
118124
false
119125
}
126+
127+
fn has_concat_function(&self) -> bool {
128+
false
129+
}
120130
}
121131

122132
impl DialectHandler for MsSqlDialect {

prql-compiler/src/sql/gen_expr.rs

Lines changed: 104 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -93,26 +93,7 @@ pub(super) fn translate_expr_kind(item: ExprKind, ctx: &mut Context) -> Result<s
9393

9494
sql_ast::Expr::Identifier(sql_ast::Ident::new(string))
9595
}
96-
ExprKind::FString(f_string_items) => {
97-
let args = f_string_items
98-
.into_iter()
99-
.map(|item| match item {
100-
InterpolateItem::String(string) => {
101-
Ok(sql_ast::Expr::Value(Value::SingleQuotedString(string)))
102-
}
103-
InterpolateItem::Expr(node) => translate_expr_kind(node.kind, ctx),
104-
})
105-
.map(|r| r.map(|e| FunctionArg::Unnamed(FunctionArgExpr::Expr(e))))
106-
.collect::<Result<Vec<_>>>()?;
107-
108-
sql_ast::Expr::Function(Function {
109-
name: ObjectName(vec![sql_ast::Ident::new("CONCAT")]),
110-
args,
111-
distinct: false,
112-
over: None,
113-
special: false,
114-
})
115-
}
96+
ExprKind::FString(f_string_items) => translate_fstring(f_string_items, ctx)?,
11697
ExprKind::Literal(l) => translate_literal(l)?,
11798
ExprKind::Switch(mut cases) => {
11899
let default = cases
@@ -352,6 +333,59 @@ pub(super) fn translate_query_sstring(
352333
)
353334
}
354335

336+
fn translate_fstring_with_concat_function(
337+
items: Vec<InterpolateItem<Expr>>,
338+
ctx: &mut Context,
339+
) -> Result<sql_ast::Expr> {
340+
let args = items
341+
.into_iter()
342+
.map(|item| match item {
343+
InterpolateItem::String(string) => {
344+
Ok(sql_ast::Expr::Value(Value::SingleQuotedString(string)))
345+
}
346+
InterpolateItem::Expr(node) => translate_expr_kind(node.kind, ctx),
347+
})
348+
.map(|r| r.map(|e| FunctionArg::Unnamed(FunctionArgExpr::Expr(e))))
349+
.collect::<Result<Vec<_>>>()?;
350+
351+
Ok(sql_ast::Expr::Function(Function {
352+
name: ObjectName(vec![sql_ast::Ident::new("CONCAT")]),
353+
args,
354+
distinct: false,
355+
over: None,
356+
special: false,
357+
}))
358+
}
359+
360+
fn translate_fstring_with_concat_operator(
361+
items: Vec<InterpolateItem<Expr>>,
362+
ctx: &mut Context,
363+
) -> Result<sql_ast::Expr> {
364+
let string = items
365+
.into_iter()
366+
.map(|f_string_item| match f_string_item {
367+
InterpolateItem::String(string) => Ok(Value::SingleQuotedString(string).to_string()),
368+
InterpolateItem::Expr(node) => {
369+
translate_expr_kind(node.kind, ctx).map(|expr| expr.to_string())
370+
}
371+
})
372+
.collect::<Result<Vec<String>>>()?
373+
.join("||");
374+
375+
Ok(sql_ast::Expr::Identifier(sql_ast::Ident::new(string)))
376+
}
377+
378+
pub(super) fn translate_fstring(
379+
items: Vec<InterpolateItem<Expr>>,
380+
ctx: &mut Context,
381+
) -> Result<sql_ast::Expr> {
382+
if ctx.dialect.has_concat_function() {
383+
translate_fstring_with_concat_function(items, ctx)
384+
} else {
385+
translate_fstring_with_concat_operator(items, ctx)
386+
}
387+
}
388+
355389
/// Aggregate several ordered ranges into one, computing the intersection.
356390
///
357391
/// Returns a tuple of `(start, end)`, where `end` is optional.
@@ -816,6 +850,10 @@ impl SQLExpression for UnaryOperator {
816850
mod test {
817851
use super::*;
818852
use crate::ast::pl::Range;
853+
use crate::sql::context::AnchorContext;
854+
use crate::{
855+
parser::parse, semantic::resolve, sql::dialect::GenericDialect, sql::dialect::SQLiteDialect,
856+
};
819857
use insta::assert_yaml_snapshot;
820858

821859
#[test]
@@ -903,4 +941,50 @@ mod test {
903941

904942
Ok(())
905943
}
944+
945+
#[test]
946+
fn test_translate_fstring() -> Result<()> {
947+
let mut context_with_concat_function: Context;
948+
let mut context_without_concat_function: Context;
949+
950+
{
951+
let query = resolve(parse("from foo")?)?;
952+
let (anchor, _) = AnchorContext::of(query);
953+
context_with_concat_function = Context {
954+
dialect: Box::new(GenericDialect {}),
955+
anchor,
956+
omit_ident_prefix: false,
957+
pre_projection: false,
958+
};
959+
}
960+
{
961+
let query = resolve(parse("from foo")?)?;
962+
let (anchor, _) = AnchorContext::of(query);
963+
context_without_concat_function = Context {
964+
dialect: Box::new(SQLiteDialect {}),
965+
anchor,
966+
omit_ident_prefix: false,
967+
pre_projection: false,
968+
};
969+
}
970+
971+
fn str_lit(s: &str) -> InterpolateItem<Expr> {
972+
InterpolateItem::String(s.to_string())
973+
}
974+
975+
assert_yaml_snapshot!(translate_fstring(vec![
976+
str_lit("hello"),
977+
str_lit("world"),
978+
], &mut context_with_concat_function)?.to_string(), @r###"
979+
---
980+
"CONCAT('hello', 'world')"
981+
"###);
982+
983+
assert_yaml_snapshot!(translate_fstring(vec![str_lit("hello"), str_lit("world")], &mut context_without_concat_function)?.to_string(), @r###"
984+
---
985+
"'hello'||'world'"
986+
"###);
987+
988+
Ok(())
989+
}
906990
}

prql-compiler/src/test.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,8 +1516,8 @@ fn test_f_string() {
15161516
]
15171517
"###;
15181518

1519-
let sql = compile(query).unwrap();
1520-
assert_display_snapshot!(sql,
1519+
assert_display_snapshot!(
1520+
compile(query).unwrap(),
15211521
@r###"
15221522
SELECT
15231523
CONCAT(
@@ -1532,6 +1532,23 @@ fn test_f_string() {
15321532
employees
15331533
"###
15341534
);
1535+
1536+
assert_display_snapshot!(
1537+
crate::compile(
1538+
query,
1539+
sql::Options::default()
1540+
.no_signature()
1541+
.with_dialect(sql::Dialect::SQLite)
1542+
.some()
1543+
).unwrap(),
1544+
@r###"
1545+
SELECT
1546+
'Hello my name is ' || prefix || first_name || ' ' || last_name,
1547+
'and I am ' || year_born - now() || ' years old.'
1548+
FROM
1549+
employees
1550+
"###
1551+
)
15351552
}
15361553

15371554
#[test]

0 commit comments

Comments
 (0)