diff --git a/prql-compiler/src/sql/dialect.rs b/prql-compiler/src/sql/dialect.rs index 69711d12ceea..7972526955b0 100644 --- a/prql-compiler/src/sql/dialect.rs +++ b/prql-compiler/src/sql/dialect.rs @@ -105,6 +105,12 @@ pub(super) trait DialectHandler { fn intersect_all(&self) -> bool { self.except_all() } + + /// Support for CONCAT function. + /// When not supported we fallback to use `||` as concat operator. + fn has_concat_function(&self) -> bool { + true + } } impl DialectHandler for GenericDialect {} @@ -117,6 +123,10 @@ impl DialectHandler for SQLiteDialect { fn except_all(&self) -> bool { false } + + fn has_concat_function(&self) -> bool { + false + } } impl DialectHandler for MsSqlDialect { diff --git a/prql-compiler/src/sql/gen_expr.rs b/prql-compiler/src/sql/gen_expr.rs index cd4dc4742c01..6487df6aa230 100644 --- a/prql-compiler/src/sql/gen_expr.rs +++ b/prql-compiler/src/sql/gen_expr.rs @@ -93,26 +93,7 @@ pub(super) fn translate_expr_kind(item: ExprKind, ctx: &mut Context) -> Result { - let args = f_string_items - .into_iter() - .map(|item| match item { - InterpolateItem::String(string) => { - Ok(sql_ast::Expr::Value(Value::SingleQuotedString(string))) - } - InterpolateItem::Expr(node) => translate_expr_kind(node.kind, ctx), - }) - .map(|r| r.map(|e| FunctionArg::Unnamed(FunctionArgExpr::Expr(e)))) - .collect::>>()?; - - sql_ast::Expr::Function(Function { - name: ObjectName(vec![sql_ast::Ident::new("CONCAT")]), - args, - distinct: false, - over: None, - special: false, - }) - } + ExprKind::FString(f_string_items) => translate_fstring(f_string_items, ctx)?, ExprKind::Literal(l) => translate_literal(l)?, ExprKind::Switch(mut cases) => { let default = cases @@ -352,6 +333,59 @@ pub(super) fn translate_query_sstring( ) } +fn translate_fstring_with_concat_function( + items: Vec>, + ctx: &mut Context, +) -> Result { + let args = items + .into_iter() + .map(|item| match item { + InterpolateItem::String(string) => { + Ok(sql_ast::Expr::Value(Value::SingleQuotedString(string))) + } + InterpolateItem::Expr(node) => translate_expr_kind(node.kind, ctx), + }) + .map(|r| r.map(|e| FunctionArg::Unnamed(FunctionArgExpr::Expr(e)))) + .collect::>>()?; + + Ok(sql_ast::Expr::Function(Function { + name: ObjectName(vec![sql_ast::Ident::new("CONCAT")]), + args, + distinct: false, + over: None, + special: false, + })) +} + +fn translate_fstring_with_concat_operator( + items: Vec>, + ctx: &mut Context, +) -> Result { + let string = items + .into_iter() + .map(|f_string_item| match f_string_item { + InterpolateItem::String(string) => Ok(Value::SingleQuotedString(string).to_string()), + InterpolateItem::Expr(node) => { + translate_expr_kind(node.kind, ctx).map(|expr| expr.to_string()) + } + }) + .collect::>>()? + .join("||"); + + Ok(sql_ast::Expr::Identifier(sql_ast::Ident::new(string))) +} + +pub(super) fn translate_fstring( + items: Vec>, + ctx: &mut Context, +) -> Result { + if ctx.dialect.has_concat_function() { + translate_fstring_with_concat_function(items, ctx) + } else { + translate_fstring_with_concat_operator(items, ctx) + } +} + /// Aggregate several ordered ranges into one, computing the intersection. /// /// Returns a tuple of `(start, end)`, where `end` is optional. @@ -816,6 +850,10 @@ impl SQLExpression for UnaryOperator { mod test { use super::*; use crate::ast::pl::Range; + use crate::sql::context::AnchorContext; + use crate::{ + parser::parse, semantic::resolve, sql::dialect::GenericDialect, sql::dialect::SQLiteDialect, + }; use insta::assert_yaml_snapshot; #[test] @@ -903,4 +941,50 @@ mod test { Ok(()) } + + #[test] + fn test_translate_fstring() -> Result<()> { + let mut context_with_concat_function: Context; + let mut context_without_concat_function: Context; + + { + let query = resolve(parse("from foo")?)?; + let (anchor, _) = AnchorContext::of(query); + context_with_concat_function = Context { + dialect: Box::new(GenericDialect {}), + anchor, + omit_ident_prefix: false, + pre_projection: false, + }; + } + { + let query = resolve(parse("from foo")?)?; + let (anchor, _) = AnchorContext::of(query); + context_without_concat_function = Context { + dialect: Box::new(SQLiteDialect {}), + anchor, + omit_ident_prefix: false, + pre_projection: false, + }; + } + + fn str_lit(s: &str) -> InterpolateItem { + InterpolateItem::String(s.to_string()) + } + + assert_yaml_snapshot!(translate_fstring(vec![ + str_lit("hello"), + str_lit("world"), + ], &mut context_with_concat_function)?.to_string(), @r###" + --- + "CONCAT('hello', 'world')" + "###); + + assert_yaml_snapshot!(translate_fstring(vec![str_lit("hello"), str_lit("world")], &mut context_without_concat_function)?.to_string(), @r###" + --- + "'hello'||'world'" + "###); + + Ok(()) + } } diff --git a/prql-compiler/src/test.rs b/prql-compiler/src/test.rs index 09fbc63852b4..694d447910ad 100644 --- a/prql-compiler/src/test.rs +++ b/prql-compiler/src/test.rs @@ -1516,8 +1516,8 @@ fn test_f_string() { ] "###; - let sql = compile(query).unwrap(); - assert_display_snapshot!(sql, + assert_display_snapshot!( + compile(query).unwrap(), @r###" SELECT CONCAT( @@ -1532,6 +1532,23 @@ fn test_f_string() { employees "### ); + + assert_display_snapshot!( + crate::compile( + query, + sql::Options::default() + .no_signature() + .with_dialect(sql::Dialect::SQLite) + .some() + ).unwrap(), + @r###" + SELECT + 'Hello my name is ' || prefix || first_name || ' ' || last_name, + 'and I am ' || year_born - now() || ' years old.' + FROM + employees + "### + ) } #[test]