@@ -9,15 +9,15 @@ use anyhow::{anyhow, Result};
9
9
use tokio:: sync:: RwLock ;
10
10
use tracing_subscriber:: prelude:: * ;
11
11
12
- use async_raft:: { Config , NodeId , Raft , RaftMetrics , RaftNetwork , State } ;
13
12
use async_raft:: async_trait:: async_trait;
14
13
use async_raft:: error:: { ChangeConfigError , ClientReadError , ClientWriteError } ;
14
+ use async_raft:: raft:: ClientWriteRequest ;
15
+ use async_raft:: raft:: MembershipConfig ;
15
16
use async_raft:: raft:: { AppendEntriesRequest , AppendEntriesResponse } ;
16
17
use async_raft:: raft:: { InstallSnapshotRequest , InstallSnapshotResponse } ;
17
18
use async_raft:: raft:: { VoteRequest , VoteResponse } ;
18
- use async_raft:: raft:: ClientWriteRequest ;
19
- use async_raft:: raft:: MembershipConfig ;
20
19
use async_raft:: storage:: RaftStorage ;
20
+ use async_raft:: { Config , NodeId , Raft , RaftMetrics , RaftNetwork , State } ;
21
21
use memstore:: { ClientRequest as MemClientRequest , ClientResponse as MemClientResponse , MemStore } ;
22
22
23
23
/// A concrete Raft type used during testing.
@@ -119,16 +119,18 @@ impl RaftRouter {
119
119
/// Get a handle to the storage backend for the target node.
120
120
pub async fn get_storage_handle ( & self , node_id : & NodeId ) -> Result < Arc < MemStore > > {
121
121
let rt = self . routing_table . read ( ) . await ;
122
- let addr = rt. get ( node_id) . ok_or_else ( ||anyhow:: anyhow!( "could not find node {} in routing table" , node_id) ) ?;
122
+ let addr = rt
123
+ . get ( node_id)
124
+ . ok_or_else ( || anyhow:: anyhow!( "could not find node {} in routing table" , node_id) ) ?;
123
125
let sto = addr. clone ( ) . 1 ;
124
126
Ok ( sto)
125
127
}
126
128
127
129
/// Wait for metrics until it satisfies some condition.
128
130
#[ tracing:: instrument( level = "info" , skip( self , func) ) ]
129
131
pub async fn wait_for_metrics < T > ( & self , node_id : & NodeId , func : T , timeout : tokio:: time:: Duration , msg : & str ) -> Result < RaftMetrics >
130
- where
131
- T : Fn ( & RaftMetrics ) -> bool ,
132
+ where
133
+ T : Fn ( & RaftMetrics ) -> bool ,
132
134
{
133
135
let rt = self . routing_table . read ( ) . await ;
134
136
let node = rt. get ( node_id) . ok_or_else ( || anyhow:: anyhow!( "node {} not found" , node_id) ) ?;
@@ -157,30 +159,55 @@ impl RaftRouter {
157
159
}
158
160
}
159
161
162
+ /// Same as wait_for_log() but provides additional timeout argument.
163
+ #[ tracing:: instrument( level = "info" , skip( self ) ) ]
164
+ pub async fn wait_for_log_timeout ( & self , node_ids : & HashSet < u64 > , want_log : u64 , timeout : tokio:: time:: Duration , msg : & str ) -> Result < ( ) > {
165
+ for i in node_ids. iter ( ) {
166
+ self . wait_for_metrics (
167
+ & i,
168
+ |x| x. last_log_index == want_log,
169
+ timeout,
170
+ & format ! ( "{} n{}.last_log_index -> {}" , msg, i, want_log) ,
171
+ )
172
+ . await ?;
173
+ self . wait_for_metrics (
174
+ & i,
175
+ |x| x. last_applied == want_log,
176
+ timeout,
177
+ & format ! ( "{} n{}.last_applied -> {}" , msg, i, want_log) ,
178
+ )
179
+ . await ?;
180
+ }
181
+ Ok ( ( ) )
182
+ }
183
+
160
184
/// Wait for specified nodes until they applied upto `want_log`(inclusive) logs.
161
185
#[ tracing:: instrument( level = "info" , skip( self ) ) ]
162
- pub async fn wait_for_nodes_log ( & self , node_ids : & HashSet < u64 > , want_log : u64 , timeout : tokio:: time:: Duration , msg : & str ) -> Result < ( ) > {
186
+ pub async fn wait_for_log ( & self , node_ids : & HashSet < u64 > , want_log : u64 , msg : & str ) -> Result < ( ) > {
187
+ let timeout = tokio:: time:: Duration :: from_millis ( 500 ) ;
188
+ self . wait_for_log_timeout ( node_ids, want_log, timeout, msg) . await
189
+ }
190
+
191
+ /// Same as wait_for_state() but provides additional timeout argument.
192
+ #[ tracing:: instrument( level = "info" , skip( self ) ) ]
193
+ pub async fn wait_for_state_timeout ( & self , node_ids : & HashSet < u64 > , want_state : State , timeout : tokio:: time:: Duration , msg : & str ) -> Result < ( ) > {
163
194
for i in node_ids. iter ( ) {
164
- self
165
- . wait_for_metrics (
166
- & i,
167
- |x| x. last_log_index == want_log,
168
- timeout,
169
- & format ! ( "{} n{}.last_log_index -> {}" , msg, i, want_log) ,
170
- )
171
- . await ?;
172
- self
173
- . wait_for_metrics (
174
- & i,
175
- |x| x. last_applied == want_log,
176
- timeout,
177
- & format ! ( "{} n{}.last_applied -> {}" , msg, i, want_log) ,
178
- )
179
- . await ?;
195
+ self . wait_for_metrics (
196
+ & i,
197
+ |x| x. state == want_state,
198
+ timeout,
199
+ & format ! ( "{} n{}.state -> {:?}" , msg, i, want_state) ,
200
+ )
201
+ . await ?;
180
202
}
181
203
Ok ( ( ) )
182
204
}
183
205
206
+ pub async fn wait_for_state ( & self , node_ids : & HashSet < u64 > , want_state : State , msg : & str ) -> Result < ( ) > {
207
+ let timeout = tokio:: time:: Duration :: from_millis ( 500 ) ;
208
+ self . wait_for_state_timeout ( node_ids, want_state, timeout, msg) . await
209
+ }
210
+
184
211
/// Get the ID of the current leader.
185
212
pub async fn leader ( & self ) -> Option < NodeId > {
186
213
let isolated = self . isolated_nodes . read ( ) . await ;
0 commit comments