Skip to content

Commit 64d80d2

Browse files
authored
feat(test): optimize set stmts in simulation to avoid duplicate replay (risingwavelabs#8420)
1 parent b6244d7 commit 64d80d2

File tree

3 files changed

+102
-25
lines changed

3 files changed

+102
-25
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/tests/simulation/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ etcd-client = { version = "0.2.17", package = "madsim-etcd-client" }
2020
futures = { version = "0.3", default-features = false, features = ["alloc"] }
2121
glob = "0.3"
2222
itertools = "0.10"
23+
lru = { git = "https://github.com/risingwavelabs/lru-rs.git", branch = "evict_by_timestamp" }
2324
madsim = "0.2.17"
2425
paste = "1"
2526
pretty_assertions = "1"
@@ -32,6 +33,7 @@ risingwave_ctl = { path = "../../ctl" }
3233
risingwave_frontend = { path = "../../frontend" }
3334
risingwave_meta = { path = "../../meta" }
3435
risingwave_pb = { path = "../../prost" }
36+
risingwave_sqlparser = { path = "../../sqlparser" }
3537
risingwave_sqlsmith = { path = "../sqlsmith" }
3638
serde = "1.0.152"
3739
serde_derive = "1.0.152"

src/tests/simulation/src/client.rs

Lines changed: 98 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414

1515
use std::time::Duration;
1616

17+
use itertools::Itertools;
18+
use lru::{Iter, LruCache};
19+
use risingwave_sqlparser::ast::Statement;
20+
use risingwave_sqlparser::parser::Parser;
21+
1722
/// A RisingWave client.
1823
pub struct RisingWave {
1924
client: tokio_postgres::Client,
@@ -22,26 +27,98 @@ pub struct RisingWave {
2227
dbname: String,
2328
/// The `SET` statements that have been executed on this client.
2429
/// We need to replay them when reconnecting.
25-
set_stmts: Vec<String>,
30+
set_stmts: SetStmts,
31+
}
32+
33+
/// `SetStmts` stores and compacts all `SET` statements that have been executed in the client
34+
/// history.
35+
pub struct SetStmts {
36+
stmts_cache: LruCache<String, String>,
37+
}
38+
39+
impl Default for SetStmts {
40+
fn default() -> Self {
41+
Self {
42+
stmts_cache: LruCache::unbounded(),
43+
}
44+
}
45+
}
46+
47+
struct SetStmtsIterator<'a, 'b>
48+
where
49+
'a: 'b,
50+
{
51+
_stmts: &'a SetStmts,
52+
stmts_iter: core::iter::Rev<Iter<'b, String, String>>,
53+
}
54+
55+
impl<'a, 'b> SetStmtsIterator<'a, 'b> {
56+
fn new(stmts: &'a SetStmts) -> Self {
57+
Self {
58+
_stmts: stmts,
59+
stmts_iter: stmts.stmts_cache.iter().rev(),
60+
}
61+
}
62+
}
63+
64+
impl SetStmts {
65+
fn push(&mut self, sql: &str) {
66+
let ast = Parser::parse_sql(&sql).expect("a set statement should be parsed successfully");
67+
match ast
68+
.into_iter()
69+
.exactly_one()
70+
.expect("should contain only one statement")
71+
{
72+
// record `local` for variable and `SetTransaction` if supported in the future.
73+
Statement::SetVariable {
74+
local: _,
75+
variable,
76+
value: _,
77+
} => {
78+
let key = variable.real_value().to_lowercase();
79+
// store complete sql as value.
80+
self.stmts_cache.put(key, sql.to_string());
81+
}
82+
_ => unreachable!(),
83+
}
84+
}
85+
}
86+
87+
impl Iterator for SetStmtsIterator<'_, '_> {
88+
type Item = String;
89+
90+
fn next(&mut self) -> Option<Self::Item> {
91+
let (_, stmt) = self.stmts_iter.next()?;
92+
Some(stmt.clone())
93+
}
2694
}
2795

2896
impl RisingWave {
2997
pub async fn connect(
3098
host: String,
3199
dbname: String,
32100
) -> Result<Self, tokio_postgres::error::Error> {
33-
Self::reconnect(host, dbname, vec![]).await
101+
let set_stmts = SetStmts::default();
102+
let (client, task) = Self::connect_inner(&host, &dbname, &set_stmts).await?;
103+
Ok(Self {
104+
client,
105+
task,
106+
host,
107+
dbname,
108+
set_stmts,
109+
})
34110
}
35111

36-
pub async fn reconnect(
37-
host: String,
38-
dbname: String,
39-
set_stmts: Vec<String>,
40-
) -> Result<Self, tokio_postgres::error::Error> {
112+
pub async fn connect_inner(
113+
host: &str,
114+
dbname: &str,
115+
set_stmts: &SetStmts,
116+
) -> Result<(tokio_postgres::Client, tokio::task::JoinHandle<()>), tokio_postgres::error::Error>
117+
{
41118
let (client, connection) = tokio_postgres::Config::new()
42-
.host(&host)
119+
.host(host)
43120
.port(4566)
44-
.dbname(&dbname)
121+
.dbname(dbname)
45122
.user("root")
46123
.connect_timeout(Duration::from_secs(5))
47124
.connect(tokio_postgres::NoTls)
@@ -64,16 +141,17 @@ impl RisingWave {
64141
.simple_query("SET VISIBILITY_MODE TO checkpoint;")
65142
.await?;
66143
// replay all SET statements
67-
for stmt in &set_stmts {
68-
client.simple_query(stmt).await?;
144+
for stmt in SetStmtsIterator::new(&set_stmts) {
145+
client.simple_query(&stmt).await?;
69146
}
70-
Ok(RisingWave {
71-
client,
72-
task,
73-
host,
74-
dbname,
75-
set_stmts,
76-
})
147+
Ok((client, task))
148+
}
149+
150+
pub async fn reconnect(&mut self) -> Result<(), tokio_postgres::error::Error> {
151+
let (client, task) = Self::connect_inner(&self.host, &self.dbname, &self.set_stmts).await?;
152+
self.client = client;
153+
self.task = task;
154+
Ok(())
77155
}
78156

79157
/// Returns a reference of the inner Postgres client.
@@ -97,16 +175,11 @@ impl sqllogictest::AsyncDB for RisingWave {
97175

98176
if self.client.is_closed() {
99177
// connection error, reset the client
100-
*self = Self::reconnect(
101-
self.host.clone(),
102-
self.dbname.clone(),
103-
self.set_stmts.clone(),
104-
)
105-
.await?;
178+
self.reconnect().await?;
106179
}
107180

108181
if sql.trim_start().to_lowercase().starts_with("set") {
109-
self.set_stmts.push(sql.to_string());
182+
self.set_stmts.push(sql);
110183
}
111184

112185
let mut output = vec![];

0 commit comments

Comments
 (0)