14
14
15
15
use std:: time:: Duration ;
16
16
17
+ use itertools:: Itertools ;
18
+ use lru:: { Iter , LruCache } ;
19
+ use risingwave_sqlparser:: ast:: Statement ;
20
+ use risingwave_sqlparser:: parser:: Parser ;
21
+
17
22
/// A RisingWave client.
18
23
pub struct RisingWave {
19
24
client : tokio_postgres:: Client ,
@@ -22,26 +27,98 @@ pub struct RisingWave {
22
27
dbname : String ,
23
28
/// The `SET` statements that have been executed on this client.
24
29
/// We need to replay them when reconnecting.
25
- set_stmts : Vec < String > ,
30
+ set_stmts : SetStmts ,
31
+ }
32
+
33
+ /// `SetStmts` stores and compacts all `SET` statements that have been executed in the client
34
+ /// history.
35
+ pub struct SetStmts {
36
+ stmts_cache : LruCache < String , String > ,
37
+ }
38
+
39
+ impl Default for SetStmts {
40
+ fn default ( ) -> Self {
41
+ Self {
42
+ stmts_cache : LruCache :: unbounded ( ) ,
43
+ }
44
+ }
45
+ }
46
+
47
+ struct SetStmtsIterator < ' a , ' b >
48
+ where
49
+ ' a : ' b ,
50
+ {
51
+ _stmts : & ' a SetStmts ,
52
+ stmts_iter : core:: iter:: Rev < Iter < ' b , String , String > > ,
53
+ }
54
+
55
+ impl < ' a , ' b > SetStmtsIterator < ' a , ' b > {
56
+ fn new ( stmts : & ' a SetStmts ) -> Self {
57
+ Self {
58
+ _stmts : stmts,
59
+ stmts_iter : stmts. stmts_cache . iter ( ) . rev ( ) ,
60
+ }
61
+ }
62
+ }
63
+
64
+ impl SetStmts {
65
+ fn push ( & mut self , sql : & str ) {
66
+ let ast = Parser :: parse_sql ( & sql) . expect ( "a set statement should be parsed successfully" ) ;
67
+ match ast
68
+ . into_iter ( )
69
+ . exactly_one ( )
70
+ . expect ( "should contain only one statement" )
71
+ {
72
+ // record `local` for variable and `SetTransaction` if supported in the future.
73
+ Statement :: SetVariable {
74
+ local : _,
75
+ variable,
76
+ value : _,
77
+ } => {
78
+ let key = variable. real_value ( ) . to_lowercase ( ) ;
79
+ // store complete sql as value.
80
+ self . stmts_cache . put ( key, sql. to_string ( ) ) ;
81
+ }
82
+ _ => unreachable ! ( ) ,
83
+ }
84
+ }
85
+ }
86
+
87
+ impl Iterator for SetStmtsIterator < ' _ , ' _ > {
88
+ type Item = String ;
89
+
90
+ fn next ( & mut self ) -> Option < Self :: Item > {
91
+ let ( _, stmt) = self . stmts_iter . next ( ) ?;
92
+ Some ( stmt. clone ( ) )
93
+ }
26
94
}
27
95
28
96
impl RisingWave {
29
97
pub async fn connect (
30
98
host : String ,
31
99
dbname : String ,
32
100
) -> Result < Self , tokio_postgres:: error:: Error > {
33
- Self :: reconnect ( host, dbname, vec ! [ ] ) . await
101
+ let set_stmts = SetStmts :: default ( ) ;
102
+ let ( client, task) = Self :: connect_inner ( & host, & dbname, & set_stmts) . await ?;
103
+ Ok ( Self {
104
+ client,
105
+ task,
106
+ host,
107
+ dbname,
108
+ set_stmts,
109
+ } )
34
110
}
35
111
36
- pub async fn reconnect (
37
- host : String ,
38
- dbname : String ,
39
- set_stmts : Vec < String > ,
40
- ) -> Result < Self , tokio_postgres:: error:: Error > {
112
+ pub async fn connect_inner (
113
+ host : & str ,
114
+ dbname : & str ,
115
+ set_stmts : & SetStmts ,
116
+ ) -> Result < ( tokio_postgres:: Client , tokio:: task:: JoinHandle < ( ) > ) , tokio_postgres:: error:: Error >
117
+ {
41
118
let ( client, connection) = tokio_postgres:: Config :: new ( )
42
- . host ( & host)
119
+ . host ( host)
43
120
. port ( 4566 )
44
- . dbname ( & dbname)
121
+ . dbname ( dbname)
45
122
. user ( "root" )
46
123
. connect_timeout ( Duration :: from_secs ( 5 ) )
47
124
. connect ( tokio_postgres:: NoTls )
@@ -64,16 +141,17 @@ impl RisingWave {
64
141
. simple_query ( "SET VISIBILITY_MODE TO checkpoint;" )
65
142
. await ?;
66
143
// replay all SET statements
67
- for stmt in & set_stmts {
68
- client. simple_query ( stmt) . await ?;
144
+ for stmt in SetStmtsIterator :: new ( & set_stmts) {
145
+ client. simple_query ( & stmt) . await ?;
69
146
}
70
- Ok ( RisingWave {
71
- client,
72
- task,
73
- host,
74
- dbname,
75
- set_stmts,
76
- } )
147
+ Ok ( ( client, task) )
148
+ }
149
+
150
+ pub async fn reconnect ( & mut self ) -> Result < ( ) , tokio_postgres:: error:: Error > {
151
+ let ( client, task) = Self :: connect_inner ( & self . host , & self . dbname , & self . set_stmts ) . await ?;
152
+ self . client = client;
153
+ self . task = task;
154
+ Ok ( ( ) )
77
155
}
78
156
79
157
/// Returns a reference of the inner Postgres client.
@@ -97,16 +175,11 @@ impl sqllogictest::AsyncDB for RisingWave {
97
175
98
176
if self . client . is_closed ( ) {
99
177
// connection error, reset the client
100
- * self = Self :: reconnect (
101
- self . host . clone ( ) ,
102
- self . dbname . clone ( ) ,
103
- self . set_stmts . clone ( ) ,
104
- )
105
- . await ?;
178
+ self . reconnect ( ) . await ?;
106
179
}
107
180
108
181
if sql. trim_start ( ) . to_lowercase ( ) . starts_with ( "set" ) {
109
- self . set_stmts . push ( sql. to_string ( ) ) ;
182
+ self . set_stmts . push ( sql) ;
110
183
}
111
184
112
185
let mut output = vec ! [ ] ;
0 commit comments