Commit adc24f5

change: pass all logs to apply_entry_to_state_machine(), not just Normal logs.
Pass the whole `Entry<D>` to `apply_entry_to_state_machine()`, not just the payload of `EntryPayload::Normal` entries. The state machine is thus able to save membership changes if it prefers to.

Why: in practice, a snapshot contains info about all applied logs, including the membership-config logs. Before this change the state machine never received membership logs, so when making a snapshot one had to walk through all applied logs to find the last membership config covered by the state machine. By letting the state machine remember the membership logs it has applied, snapshot creation becomes more convenient and intuitive: it no longer needs to scan the applied logs.
1 parent 8e0cca5 commit adc24f5
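
The heart of the change is widening the two `RaftStorage` state-machine hooks from per-payload data to whole log entries. The before/after signatures, copied from the async-raft/src/storage.rs diff below:

    // Before: the state machine only ever saw the data of Normal entries.
    async fn apply_entry_to_state_machine(&self, index: &LogId, data: &D) -> Result<R>;
    async fn replicate_to_state_machine(&self, entries: &[(&LogId, &D)]) -> Result<()>;

    // After: every committed entry, membership changes included, reaches the state machine.
    async fn apply_entry_to_state_machine(&self, data: &Entry<D>) -> Result<R>;
    async fn replicate_to_state_machine(&self, entries: &[&Entry<D>]) -> Result<()>;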

File tree

7 files changed, +202 -71 lines changed

async-raft/src/core/append_entries.rs

+3 -12

@@ -225,10 +225,7 @@ impl<D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>> Ra
             .filter_map(|idx| {
                 if let Some(entry) = self.entries_cache.remove(&idx) {
                     last_entry_seen = Some(entry.log_id);
-                    match entry.payload {
-                        EntryPayload::Normal(inner) => Some((entry.log_id, inner.data)),
-                        _ => None,
-                    }
+                    Some(entry)
                 } else {
                     None
                 }
@@ -251,7 +248,7 @@ impl<D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>> Ra
         let handle = tokio::spawn(async move {
             // Create a new vector of references to the entries data ... might have to change this
             // interface a bit before 1.0.
-            let entries_refs: Vec<_> = entries.iter().map(|(k, v)| (k, v)).collect();
+            let entries_refs: Vec<_> = entries.iter().collect();
             storage.replicate_to_state_machine(&entries_refs).await?;
             Ok(last_entry_seen)
         });
@@ -280,13 +277,7 @@ impl<D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>> Ra
         if let Some(entry) = entries.last() {
             new_last_applied = Some(entry.log_id);
         }
-        let data_entries: Vec<_> = entries
-            .iter()
-            .filter_map(|entry| match &entry.payload {
-                EntryPayload::Normal(inner) => Some((&entry.log_id, &inner.data)),
-                _ => None,
-            })
-            .collect();
+        let data_entries: Vec<_> = entries.iter().collect();
         if data_entries.is_empty() {
             return Ok(new_last_applied);
         }

async-raft/src/core/client.rs

+25 -24

@@ -307,22 +307,22 @@ impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>
     /// Handle the post-commit logic for a client request.
     #[tracing::instrument(level = "trace", skip(self, req))]
     pub(super) async fn client_request_post_commit(&mut self, req: ClientRequestEntry<D, R>) {
+        let entry = &req.entry;
+
         match req.tx {
             ClientOrInternalResponseTx::Client(tx) => {
-                match &req.entry.payload {
-                    EntryPayload::Normal(inner) => {
-                        match self.apply_entry_to_state_machine(&req.entry.log_id, &inner.data).await {
-                            Ok(data) => {
-                                let _ = tx.send(Ok(ClientWriteResponse {
-                                    index: req.entry.log_id.index,
-                                    data,
-                                }));
-                            }
-                            Err(err) => {
-                                let _ = tx.send(Err(ClientWriteError::RaftError(err)));
-                            }
-                        }
-                    }
+                match &entry.payload {
+                    EntryPayload::Normal(_) => match self.apply_entry_to_state_machine(&entry).await {
+                        Ok(data) => {
+                            let _ = tx.send(Ok(ClientWriteResponse {
+                                index: req.entry.log_id.index,
+                                data,
+                            }));
+                        }
+                        Err(err) => {
+                            let _ = tx.send(Err(ClientWriteError::RaftError(err)));
+                        }
+                    },
                     _ => {
                         // Why is this a bug, and why are we shutting down? This is because we can not easily
                         // encode these constraints in the type system, and client requests should be the only
@@ -334,9 +334,15 @@ impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>
                 }
             }
             ClientOrInternalResponseTx::Internal(tx) => {
-                self.core.last_applied = req.entry.log_id;
+                // TODO(xp): copied from above, need refactor.
+                let res = match self.apply_entry_to_state_machine(&entry).await {
+                    Ok(_data) => Ok(entry.log_id.index),
+                    Err(err) => Err(err),
+                };
+
+                self.core.last_applied = entry.log_id;
                 self.leader_report_metrics();
-                let _ = tx.send(Ok(req.entry.log_id.index));
+                let _ = tx.send(res);
             }
         }
@@ -346,13 +352,14 @@ impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>

     /// Apply the given log entry to the state machine.
     #[tracing::instrument(level = "trace", skip(self, entry))]
-    pub(super) async fn apply_entry_to_state_machine(&mut self, log_id: &LogId, entry: &D) -> RaftResult<R> {
+    pub(super) async fn apply_entry_to_state_machine(&mut self, entry: &Entry<D>) -> RaftResult<R> {
         // First, we just ensure that we apply any outstanding up to, but not including, the index
         // of the given entry. We need to be able to return the data response from applying this
         // entry to the state machine.
         //
         // Note that this would only ever happen if a node had unapplied logs from before becoming leader.

+        let log_id = &entry.log_id;
         let index = log_id.index;

         let expected_next_index = self.core.last_applied.index + 1;
@@ -368,13 +375,7 @@ impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>
             self.core.last_applied = entry.log_id;
         }

-        let data_entries: Vec<_> = entries
-            .iter()
-            .filter_map(|entry| match &entry.payload {
-                EntryPayload::Normal(inner) => Some((&entry.log_id, &inner.data)),
-                _ => None,
-            })
-            .collect();
+        let data_entries: Vec<_> = entries.iter().collect();
         if !data_entries.is_empty() {
             self.core
                 .storage
@@ -393,7 +394,7 @@ impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>
             }
         }
         // Apply this entry to the state machine and return its data response.
-        let res = self.core.storage.apply_entry_to_state_machine(&log_id, entry).await.map_err(|err| {
+        let res = self.core.storage.apply_entry_to_state_machine(entry).await.map_err(|err| {
             if err.downcast_ref::<S::ShutdownError>().is_some() {
                 // If this is an instance of the storage impl's shutdown error, then trigger shutdown.
                 self.core.map_fatal_storage_error(err)

async-raft/src/storage.rs

+7 -2

@@ -178,6 +178,11 @@ where
     /// specific transaction is being started, or perhaps committed. This may be where a key/value
     /// is being stored. This may be where an entry is being appended to an immutable log.
     ///
+    /// An implementation should:
+    /// - Handle the EntryPayload::Normal() log, which carries the business-logic data.
+    /// - Optionally handle the EntryPayload::ConfigChange or EntryPayload::SnapshotPointer logs
+    ///   if it is concerned with them, e.g. when it needs to track membership changes.
+    ///
     /// Error handling for this method is noteworthy. If an error is returned from a call to this
     /// method, the error will be inspected, and if the error is an instance of
     /// `RaftStorage::ShutdownError`, then Raft will go into shutdown in order to preserve the
@@ -186,15 +191,15 @@
     ///
     /// It is important to note that even in cases where an application specific error is returned,
     /// implementations should still record that the entry has been applied to the state machine.
-    async fn apply_entry_to_state_machine(&self, index: &LogId, data: &D) -> Result<R>;
+    async fn apply_entry_to_state_machine(&self, data: &Entry<D>) -> Result<R>;

     /// Apply the given payload of entries to the state machine, as part of replication.
     ///
     /// The Raft protocol guarantees that only logs which have been _committed_, that is, logs which
     /// have been replicated to a majority of the cluster, will be applied to the state machine.
     ///
     /// Errors returned from this method will cause Raft to go into shutdown.
-    async fn replicate_to_state_machine(&self, entries: &[(&LogId, &D)]) -> Result<()>;
+    async fn replicate_to_state_machine(&self, entries: &[&Entry<D>]) -> Result<()>;

     /// Perform log compaction, returning a handle to the generated snapshot.
     ///
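
To make the new contract concrete, here is a minimal sketch of an implementation following the doc comment above. `MyStore`, `MyRequest`, `MyResponse`, the `sm` lock, and the state machine's `apply()` helper and `last_membership` field are hypothetical stand-ins; the real version of this pattern is in the memstore diff further down.

    use anyhow::Result;
    use async_raft::raft::{Entry, EntryPayload};

    impl MyStore {
        // Sketch only: applies one committed entry to a hypothetical state machine.
        async fn apply_entry_to_state_machine(&self, entry: &Entry<MyRequest>) -> Result<MyResponse> {
            let mut sm = self.sm.write().await;
            // Record progress for every entry, whatever its payload.
            sm.last_applied_log = entry.log_id;
            match entry.payload {
                // Business-logic entries carry the application data.
                EntryPayload::Normal(ref norm) => sm.apply(&norm.data),
                // Membership entries: remember the latest config so that snapshot
                // creation can read it back instead of scanning the applied logs.
                EntryPayload::ConfigChange(ref cc) => {
                    sm.last_membership = Some(cc.membership.clone());
                    Ok(MyResponse::default())
                }
                // Blank and snapshot-pointer entries carry no application data.
                _ => Ok(MyResponse::default()),
            }
        }
    }
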
async-raft/tests/state_machine_apply_membership.rs (new file)

+94 -0

@@ -0,0 +1,94 @@
mod fixtures;

use std::sync::Arc;

use anyhow::Result;
use async_raft::raft::MembershipConfig;
use async_raft::Config;
use async_raft::State;
use fixtures::RaftRouter;
use futures::stream::StreamExt;
use maplit::hashset;

/// All logs should be applied to the state machine.
///
/// What does this test do?
///
/// - brings up a cluster with 3 voters and 2 non-voters.
/// - checks last_membership in the state machine.
///
/// RUST_LOG=async_raft,memstore,state_machine_apply_membership=trace cargo test -p async-raft --test
/// state_machine_apply_membership
#[tokio::test(flavor = "multi_thread", worker_threads = 6)]
async fn state_machine_apply_membership() -> Result<()> {
    fixtures::init_tracing();

    // Setup test dependencies.
    let config = Arc::new(Config::build("test".into()).validate().expect("failed to build Raft config"));
    let router = Arc::new(RaftRouter::new(config.clone()));
    router.new_raft_node(0).await;

    let mut want = 0;

    // Assert all nodes are in non-voter state & have no entries.
    router.wait_for_log(&hashset![0], want, None, "empty").await?;
    router.wait_for_state(&hashset![0], State::NonVoter, None, "empty").await?;
    router.assert_pristine_cluster().await;

    // Initialize the cluster, then assert that a stable cluster was formed & held.
    tracing::info!("--- initializing cluster");
    router.initialize_from_single_node(0).await?;
    want += 1;

    router.wait_for_log(&hashset![0], want, None, "init").await?;
    router.assert_stable_cluster(Some(1), Some(want)).await;

    for i in 0..=0 {
        let sto = router.get_storage_handle(&i).await?;
        let sm = sto.get_state_machine().await;
        assert_eq!(
            Some(MembershipConfig {
                members: hashset![0],
                members_after_consensus: None
            }),
            sm.last_membership
        );
    }

    // Sync some new nodes.
    router.new_raft_node(1).await;
    router.new_raft_node(2).await;
    router.new_raft_node(3).await;
    router.new_raft_node(4).await;

    tracing::info!("--- adding new nodes to cluster");
    let mut new_nodes = futures::stream::FuturesUnordered::new();
    new_nodes.push(router.add_non_voter(0, 1));
    new_nodes.push(router.add_non_voter(0, 2));
    new_nodes.push(router.add_non_voter(0, 3));
    new_nodes.push(router.add_non_voter(0, 4));
    while let Some(inner) = new_nodes.next().await {
        inner?;
    }

    tracing::info!("--- changing cluster config");
    router.change_membership(0, hashset![0, 1, 2]).await?;
    want += 2;

    router.wait_for_log(&hashset![0, 1, 2, 3, 4], want, None, "cluster of 5 candidates").await?;

    tracing::info!("--- check applied membership config");
    for i in 0..5 {
        let sto = router.get_storage_handle(&i).await?;
        let sm = sto.get_state_machine().await;
        assert_eq!(
            Some(MembershipConfig {
                members: hashset![0, 1, 2],
                members_after_consensus: None
            }),
            sm.last_membership
        );
    }

    Ok(())
}

async-raft/tests/stepdown.rs

+1 -1

@@ -128,7 +128,7 @@ async fn stepdown() -> Result<()> {
     assert!(metrics.current_term >= 2, "term incr when leader changes");
     router.assert_stable_cluster(Some(metrics.current_term), Some(want)).await;
     router
-        .assert_storage_state(metrics.current_term, want, None, LogId { term: 0, index: 0 }, None)
+        .assert_storage_state(metrics.current_term, want, None, LogId { term: 2, index: 4 }, None)
         .await;
     // ----------------------------------- ^^^ this is `0` instead of `4` because blank payloads from new leaders
     // and config change entries are never applied to the state machine.

memstore/src/lib.rs

+49 -23

@@ -74,6 +74,9 @@ pub struct MemStoreSnapshot {
 #[derive(Serialize, Deserialize, Debug, Default, Clone)]
 pub struct MemStoreStateMachine {
     pub last_applied_log: LogId,
+
+    pub last_membership: Option<MembershipConfig>,
+
     /// A mapping of client IDs to their state info.
     pub client_serial_responses: HashMap<String, (u64, Option<String>)>,
     /// The current status of a client by ID.
@@ -233,7 +236,7 @@ impl RaftStorage<ClientRequest, ClientResponse> for MemStore {
     async fn get_log_entries(&self, start: u64, stop: u64) -> Result<Vec<Entry<ClientRequest>>> {
         // Invalid request, return empty vec.
         if start > stop {
-            tracing::error!("invalid request, start > stop");
+            tracing::error!("get_log_entries: invalid request, start({}) > stop({})", start, stop);
             return Ok(vec![]);
         }
         let log = self.log.read().await;
@@ -243,7 +246,7 @@
     #[tracing::instrument(level = "trace", skip(self))]
     async fn delete_logs_from(&self, start: u64, stop: Option<u64>) -> Result<()> {
         if stop.as_ref().map(|stop| &start > stop).unwrap_or(false) {
-            tracing::error!("invalid request, start > stop");
+            tracing::error!("delete_logs_from: invalid request, start({}) > stop({:?})", start, stop);
             return Ok(());
         }
         let mut log = self.log.write().await;
@@ -276,50 +279,73 @@
         Ok(())
     }

-    #[tracing::instrument(level = "trace", skip(self, data))]
-    async fn apply_entry_to_state_machine(&self, index: &LogId, data: &ClientRequest) -> Result<ClientResponse> {
+    #[tracing::instrument(level = "trace", skip(self, entry))]
+    async fn apply_entry_to_state_machine(&self, entry: &Entry<ClientRequest>) -> Result<ClientResponse> {
         let mut sm = self.sm.write().await;
-        sm.last_applied_log = *index;
-        if let Some((serial, res)) = sm.client_serial_responses.get(&data.client) {
-            if serial == &data.serial {
-                return Ok(ClientResponse(res.clone()));
+        sm.last_applied_log = entry.log_id;
+
+        return match entry.payload {
+            EntryPayload::Blank => return Ok(ClientResponse(None)),
+            EntryPayload::SnapshotPointer(_) => return Ok(ClientResponse(None)),
+            EntryPayload::Normal(ref norm) => {
+                let data = &norm.data;
+                if let Some((serial, res)) = sm.client_serial_responses.get(&data.client) {
+                    if serial == &data.serial {
+                        return Ok(ClientResponse(res.clone()));
+                    }
+                }
+                let previous = sm.client_status.insert(data.client.clone(), data.status.clone());
+                sm.client_serial_responses.insert(data.client.clone(), (data.serial, previous.clone()));
+                Ok(ClientResponse(previous))
             }
-        }
-        let previous = sm.client_status.insert(data.client.clone(), data.status.clone());
-        sm.client_serial_responses.insert(data.client.clone(), (data.serial, previous.clone()));
-        Ok(ClientResponse(previous))
+            EntryPayload::ConfigChange(ref mem) => {
+                sm.last_membership = Some(mem.membership.clone());
+                return Ok(ClientResponse(None));
+            }
+        };
     }

     #[tracing::instrument(level = "trace", skip(self, entries))]
-    async fn replicate_to_state_machine(&self, entries: &[(&LogId, &ClientRequest)]) -> Result<()> {
+    async fn replicate_to_state_machine(&self, entries: &[&Entry<ClientRequest>]) -> Result<()> {
         let mut sm = self.sm.write().await;
-        for (index, data) in entries {
-            sm.last_applied_log = **index;
-            if let Some((serial, _)) = sm.client_serial_responses.get(&data.client) {
-                if serial == &data.serial {
-                    continue;
+        for entry in entries {
+            sm.last_applied_log = entry.log_id;
+
+            match entry.payload {
+                EntryPayload::Blank => {}
+                EntryPayload::SnapshotPointer(_) => {}
+                EntryPayload::Normal(ref norm) => {
+                    let data = &norm.data;
+                    if let Some((serial, _)) = sm.client_serial_responses.get(&data.client) {
+                        if serial == &data.serial {
+                            continue;
+                        }
+                    }
+                    let previous = sm.client_status.insert(data.client.clone(), data.status.clone());
+                    sm.client_serial_responses.insert(data.client.clone(), (data.serial, previous.clone()));
                 }
-            }
-            let previous = sm.client_status.insert(data.client.clone(), data.status.clone());
-            sm.client_serial_responses.insert(data.client.clone(), (data.serial, previous.clone()));
+                EntryPayload::ConfigChange(ref mem) => {
+                    sm.last_membership = Some(mem.membership.clone());
+                }
+            };
         }
         Ok(())
     }

     #[tracing::instrument(level = "trace", skip(self))]
     async fn do_log_compaction(&self) -> Result<CurrentSnapshotData<Self::Snapshot>> {
         let (data, last_applied_log);
+        let membership_config;
         {
             // Serialize the data of the state machine.
             let sm = self.sm.read().await;
             data = serde_json::to_vec(&*sm)?;
             last_applied_log = sm.last_applied_log;
+            membership_config = sm.last_membership.clone().unwrap_or_else(|| MembershipConfig::new_initial(self.id));
         } // Release state machine read lock.

         let snapshot_size = data.len();

-        let membership_config = self.get_membership_from_log(Some(last_applied_log.index)).await?;
-
         let snapshot_idx = {
             let mut l = self.snapshot_idx.lock().unwrap();
             *l += 1;