@@ -23,15 +23,17 @@ use pyo3::{pyclass, pymethods, PyErr, PyResult};
23
23
use std:: collections:: hash_map:: Entry ;
24
24
use std:: collections:: HashMap ;
25
25
use std:: hash:: Hash ;
26
+ use std:: ops:: ControlFlow ;
26
27
use std:: sync:: Arc ;
28
+ use wren_core:: ast:: { visit_statements_mut, Expr , Statement , Value } ;
29
+ use wren_core:: dialect:: GenericDialect ;
27
30
use wren_core:: logical_plan:: utils:: map_data_type;
28
31
use wren_core:: mdl:: context:: create_ctx_with_mdl;
29
32
use wren_core:: mdl:: function:: {
30
33
ByPassAggregateUDF , ByPassScalarUDF , ByPassWindowFunction , FunctionType ,
31
34
RemoteFunction ,
32
35
} ;
33
36
use wren_core:: { mdl, AggregateUDF , AnalyzedWrenMDL , ScalarUDF , WindowUDF } ;
34
-
35
37
/// The Python wrapper for the Wren Core session context.
36
38
#[ pyclass( name = "SessionContext" ) ]
37
39
#[ derive( Clone ) ]
@@ -188,6 +190,43 @@ impl PySessionContext {
188
190
} ) ;
189
191
Ok ( builder. values ( ) . cloned ( ) . collect ( ) )
190
192
}
193
+
194
+ /// Push down the limit to the given SQL.
195
+ /// If the limit is None, the SQL will be returned as is.
196
+ /// If the limit is greater than the pushdown limit, the limit will be replaced with the pushdown limit.
197
+ /// Otherwise, the limit will be kept as is.
198
+ #[ pyo3( signature = ( sql, limit=None ) ) ]
199
+ pub fn pushdown_limit ( & self , sql : & str , limit : Option < usize > ) -> PyResult < String > {
200
+ if limit. is_none ( ) {
201
+ return Ok ( sql. to_string ( ) ) ;
202
+ }
203
+ let pushdown = limit. unwrap ( ) ;
204
+ let mut statements =
205
+ wren_core:: parser:: Parser :: parse_sql ( & GenericDialect { } , sql)
206
+ . map_err ( CoreError :: from) ?;
207
+ if statements. len ( ) != 1 {
208
+ return Err ( CoreError :: new ( "Only one statement is allowed" ) . into ( ) ) ;
209
+ }
210
+ visit_statements_mut ( & mut statements, |stmt| {
211
+ if let Statement :: Query ( q) = stmt {
212
+ if let Some ( limit) = & q. limit {
213
+ if let Expr :: Value ( Value :: Number ( n, is) ) = limit {
214
+ if n. parse :: < usize > ( ) . unwrap ( ) > pushdown {
215
+ q. limit = Some ( Expr :: Value ( Value :: Number (
216
+ pushdown. to_string ( ) ,
217
+ is. clone ( ) ,
218
+ ) ) ) ;
219
+ }
220
+ }
221
+ } else {
222
+ q. limit =
223
+ Some ( Expr :: Value ( Value :: Number ( pushdown. to_string ( ) , false ) ) ) ;
224
+ }
225
+ }
226
+ ControlFlow :: < ( ) > :: Continue ( ( ) )
227
+ } ) ;
228
+ Ok ( statements[ 0 ] . to_string ( ) )
229
+ }
191
230
}
192
231
193
232
impl PySessionContext {
0 commit comments