Skip to content

Commit dcd18c5

Browse files
committed
Improve: getting a snapshot does not block RaftCore task
`RaftCore` no longer blocks on receiving a snapshot from the state-machine worker while replicating a snapshot. Instead, it sends the `Receiver` to the replication task and the replication task blocks on receiving the snapshot.
1 parent 840f207 commit dcd18c5

File tree

3 files changed

+50
-27
lines changed

3 files changed

+50
-27
lines changed

openraft/src/core/raft_core.rs

+8-12
Original file line numberDiff line numberDiff line change
@@ -1476,25 +1476,21 @@ where
14761476
let _ = node.tx_repl.send(Replicate::logs(id, log_id_range));
14771477
}
14781478
Inflight::Snapshot { id, last_log_id } => {
1479-
// TODO(2): move to another task.
1479+
let _ = last_log_id;
1480+
1481+
// Create a channel to let state machine worker to send the snapshot and the replication
1482+
// worker to receive it.
14801483
let (tx, rx) = oneshot::channel();
14811484

14821485
let cmd = sm::Command::get_snapshot(0, tx);
14831486
self.sm_handle
14841487
.send(cmd)
14851488
.map_err(|e| StorageIOError::read_snapshot(None, AnyError::error(e)))?;
14861489

1487-
let snapshot =
1488-
rx.await.map_err(|e| StorageIOError::read_snapshot(None, AnyError::error(e)))?;
1489-
1490-
tracing::debug!("snapshot: {}", snapshot.as_ref().map(|x| &x.meta).summary());
1491-
1492-
if let Some(snapshot) = snapshot {
1493-
debug_assert_eq!(last_log_id, snapshot.meta.last_log_id);
1494-
let _ = node.tx_repl.send(Replicate::snapshot(id, snapshot));
1495-
} else {
1496-
unreachable!("No snapshot");
1497-
}
1490+
// unwrap: The replication channel must not be dropped or it is a bug.
1491+
node.tx_repl.send(Replicate::snapshot(id, rx)).map_err(|_e| {
1492+
StorageIOError::read_snapshot(None, AnyError::error("replication channel closed"))
1493+
})?;
14981494
}
14991495
}
15001496
} else {

openraft/src/core/sm/mod.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ where
219219

220220
#[tracing::instrument(level = "info", skip_all)]
221221
async fn build_snapshot(&mut self) -> Result<SnapshotMeta<C::NodeId, C::Node>, StorageError<C::NodeId>> {
222-
// TODO: move it to another task
222+
// TODO(3): move it to another task
223223
// use futures::future::abortable;
224224
// let (fu, abort_handle) = abortable(async move { builder.build_snapshot().await });
225225

@@ -240,7 +240,10 @@ where
240240

241241
let snapshot = self.state_machine.get_current_snapshot().await?;
242242

243-
tracing::info!("sending back snapshot");
243+
tracing::info!(
244+
"sending back snapshot: meta: {:?}",
245+
snapshot.as_ref().map(|s| s.meta.summary())
246+
);
244247
let _ = tx.send(snapshot);
245248
Ok(())
246249
}

openraft/src/replication/mod.rs

+37-13
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ use std::fmt::Formatter;
77
use std::io::SeekFrom;
88
use std::sync::Arc;
99

10+
use anyerror::AnyError;
1011
use futures::future::FutureExt;
1112
pub(crate) use replication_session_id::ReplicationSessionId;
1213
use tokio::io::AsyncRead;
1314
use tokio::io::AsyncReadExt;
1415
use tokio::io::AsyncSeek;
1516
use tokio::io::AsyncSeekExt;
1617
use tokio::sync::mpsc;
18+
use tokio::sync::oneshot;
1719
use tokio::task::JoinHandle;
1820
use tokio::time::sleep;
1921
use tokio::time::timeout;
@@ -45,6 +47,8 @@ use crate::RPCTypes;
4547
use crate::RaftNetwork;
4648
use crate::RaftNetworkFactory;
4749
use crate::RaftTypeConfig;
50+
use crate::StorageError;
51+
use crate::StorageIOError;
4852
use crate::ToStorageResult;
4953

5054
/// The handle to a spawned replication stream.
@@ -169,7 +173,7 @@ where
169173
repl_id = id;
170174
match r_action {
171175
Payload::Logs(log_id_range) => self.send_log_entries(id, log_id_range).await,
172-
Payload::Snapshot(snapshot) => self.stream_snapshot(id, snapshot).await,
176+
Payload::Snapshot(snapshot_rx) => self.stream_snapshot(id, snapshot_rx).await,
173177
}
174178
}
175179
};
@@ -447,13 +451,33 @@ where
447451
}
448452
}
449453

450-
#[tracing::instrument(level = "trace", skip(self, snapshot))]
454+
#[tracing::instrument(level = "info", skip_all)]
451455
async fn stream_snapshot(
452456
&mut self,
453457
id: u64,
454-
mut snapshot: Snapshot<C::NodeId, C::Node, SM::SnapshotData>,
458+
rx: oneshot::Receiver<Option<Snapshot<C::NodeId, C::Node, SM::SnapshotData>>>,
455459
) -> Result<(), ReplicationError<C::NodeId, C::Node>> {
456-
tracing::debug!(id = display(id), snapshot = debug(&snapshot.meta), "stream_snapshot",);
460+
tracing::info!(id = display(id), "{}", func_name!());
461+
462+
let snapshot = rx.await.map_err(|e| {
463+
let io_err = StorageIOError::read_snapshot(None, AnyError::error(e));
464+
StorageError::IO { source: io_err }
465+
})?;
466+
467+
tracing::info!(
468+
"received snapshot: id={}; meta:{}",
469+
id,
470+
snapshot.as_ref().map(|x| &x.meta).summary()
471+
);
472+
473+
let mut snapshot = match snapshot {
474+
None => {
475+
let io_err = StorageIOError::read_snapshot(None, AnyError::error("snapshot not found"));
476+
let sto_err = StorageError::IO { source: io_err };
477+
return Err(ReplicationError::StorageError(sto_err));
478+
}
479+
Some(x) => x,
480+
};
457481

458482
let err_x = || (ErrorSubject::Snapshot(Some(snapshot.meta.signature())), ErrorVerb::Read);
459483

@@ -572,8 +596,8 @@ where
572596
Payload::Logs(log_id_range) => {
573597
format!("Logs{{id={}, {}}}", self.id, log_id_range)
574598
}
575-
Payload::Snapshot(snapshot) => {
576-
format!("Snapshot{{id={}, {}}}", self.id, snapshot.meta.summary())
599+
Payload::Snapshot(_) => {
600+
format!("Snapshot{{id={}}}", self.id)
577601
}
578602
}
579603
}
@@ -592,10 +616,10 @@ where
592616
}
593617
}
594618

595-
fn new_snapshot(id: u64, snapshot: Snapshot<NID, N, SD>) -> Self {
619+
fn new_snapshot(id: u64, snapshot_rx: oneshot::Receiver<Option<Snapshot<NID, N, SD>>>) -> Self {
596620
Self {
597621
id,
598-
payload: Payload::Snapshot(snapshot),
622+
payload: Payload::Snapshot(snapshot_rx),
599623
}
600624
}
601625
}
@@ -610,7 +634,7 @@ where
610634
SD: AsyncRead + AsyncSeek + Send + Unpin + 'static,
611635
{
612636
Logs(LogIdRange<NID>),
613-
Snapshot(Snapshot<NID, N, SD>),
637+
Snapshot(oneshot::Receiver<Option<Snapshot<NID, N, SD>>>),
614638
}
615639

616640
impl<NID, N, SD> Debug for Payload<NID, N, SD>
@@ -624,8 +648,8 @@ where
624648
Self::Logs(log_id_range) => {
625649
write!(f, "Logs({})", log_id_range)
626650
}
627-
Self::Snapshot(snapshot) => {
628-
write!(f, "Snapshot({:?})", snapshot.meta)
651+
Self::Snapshot(_) => {
652+
write!(f, "Snapshot()")
629653
}
630654
}
631655
}
@@ -665,8 +689,8 @@ where
665689
Self::Data(Data::new_logs(id, log_id_range))
666690
}
667691

668-
pub(crate) fn snapshot(id: u64, snapshot: Snapshot<NID, N, SD>) -> Self {
669-
Self::Data(Data::new_snapshot(id, snapshot))
692+
pub(crate) fn snapshot(id: u64, snapshot_rx: oneshot::Receiver<Option<Snapshot<NID, N, SD>>>) -> Self {
693+
Self::Data(Data::new_snapshot(id, snapshot_rx))
670694
}
671695
}
672696

0 commit comments

Comments
 (0)