Skip to content

Commit a635592

Browse files
Merge 0eaed7d into 43d0ea4
2 parents 43d0ea4 + 0eaed7d commit a635592

File tree

1 file changed

+46
-36
lines changed

1 file changed

+46
-36
lines changed

iroh/src/protocol.rs

Lines changed: 46 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,12 @@
4242
//! ```
4343
use std::{any::Any, collections::BTreeMap, sync::Arc};
4444

45-
use anyhow::{anyhow, Result};
45+
use anyhow::Result;
4646
use futures_buffered::join_all;
4747
use futures_lite::future::Boxed as BoxedFuture;
48-
use futures_util::{
49-
future::{MapErr, Shared},
50-
FutureExt, TryFutureExt,
51-
};
52-
use tokio::task::{JoinError, JoinSet};
48+
use tokio::{sync::Mutex, task::JoinSet};
5349
use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
54-
use tracing::{debug, error, warn};
50+
use tracing::{error, info_span, trace, warn, Instrument};
5551

5652
use crate::{endpoint::Connecting, Endpoint};
5753

@@ -92,17 +88,10 @@ pub struct Router {
9288
endpoint: Endpoint,
9389
protocols: Arc<ProtocolMap>,
9490
// `Router` needs to be `Clone + Send`, and we need to `task.await` in its `shutdown()` impl.
95-
// So we need
96-
// - `Shared` so we can `task.await` from all `Node` clones
97-
// - `MapErr` to map the `JoinError` to a `String`, because `JoinError` is `!Clone`
98-
// - `AbortOnDropHandle` to make sure that the `task` is cancelled when all `Node`s are dropped
99-
// (`Shared` acts like an `Arc` around its inner future).
100-
task: Shared<MapErr<AbortOnDropHandle<()>, JoinErrToStr>>,
91+
task: Arc<Mutex<Option<AbortOnDropHandle<()>>>>,
10192
cancel_token: CancellationToken,
10293
}
10394

104-
type JoinErrToStr = Box<dyn Fn(JoinError) -> String + Send + Sync + 'static>;
105-
10695
/// Builder for creating a [`Router`] for accepting protocols.
10796
#[derive(Debug)]
10897
pub struct RouterBuilder {
@@ -201,16 +190,32 @@ impl Router {
201190
&self.endpoint
202191
}
203192

193+
/// Checks if the router is already shutdown.
194+
pub fn is_shutdown(&self) -> bool {
195+
self.cancel_token.is_cancelled()
196+
}
197+
204198
/// Shuts down the accept loop cleanly.
205199
///
200+
/// When this function returns, all [`ProtocolHandler`]s will be shutdown and
201+
/// `Endpoint::close` will have been called.
202+
///
203+
/// If already shutdown, it returns `Ok`.
204+
///
206205
/// If some [`ProtocolHandler`] panicked in the accept loop, this will propagate
207206
/// that panic into the result here.
208-
pub async fn shutdown(self) -> Result<()> {
207+
pub async fn shutdown(&self) -> Result<()> {
208+
if self.is_shutdown() {
209+
return Ok(());
210+
}
211+
209212
// Trigger shutdown of the main run task by activating the cancel token.
210213
self.cancel_token.cancel();
211214

212215
// Wait for the main task to terminate.
213-
self.task.await.map_err(|err| anyhow!(err))?;
216+
if let Some(task) = self.task.lock().await.take() {
217+
task.await?;
218+
}
214219

215220
Ok(())
216221
}
@@ -267,25 +272,16 @@ impl RouterBuilder {
267272
let cancel_token = cancel.clone();
268273

269274
let run_loop_fut = async move {
275+
// Make sure to cancel the token, if this future ever exits.
276+
let _cancel_guard = cancel_token.clone().drop_guard();
277+
270278
let protocols = protos;
271279
loop {
272280
tokio::select! {
273281
biased;
274282
_ = cancel_token.cancelled() => {
275283
break;
276284
},
277-
// handle incoming p2p connections.
278-
incoming = endpoint.accept() => {
279-
let Some(incoming) = incoming else {
280-
break;
281-
};
282-
283-
let protocols = protocols.clone();
284-
join_set.spawn(async move {
285-
handle_connection(incoming, protocols).await;
286-
anyhow::Ok(())
287-
});
288-
},
289285
// handle task terminations and quit on panics.
290286
res = join_set.join_next(), if !join_set.is_empty() => {
291287
match res {
@@ -294,18 +290,34 @@ impl RouterBuilder {
294290
error!("Task panicked: {outer:?}");
295291
break;
296292
} else if outer.is_cancelled() {
297-
debug!("Task cancelled: {outer:?}");
293+
trace!("Task cancelled: {outer:?}");
298294
} else {
299295
error!("Task failed: {outer:?}");
300296
break;
301297
}
302298
}
303-
Some(Ok(Err(inner))) => {
304-
debug!("Task errored: {inner:?}");
299+
Some(Ok(Some(()))) => {
300+
trace!("Task finished");
301+
}
302+
Some(Ok(None)) => {
303+
trace!("Task cancelled");
305304
}
306305
_ => {}
307306
}
308307
},
308+
309+
// handle incoming p2p connections.
310+
incoming = endpoint.accept() => {
311+
let Some(incoming) = incoming else {
312+
break;
313+
};
314+
315+
let protocols = protocols.clone();
316+
let token = cancel_token.child_token();
317+
join_set.spawn(async move {
318+
token.run_until_cancelled(handle_connection(incoming, protocols)).await
319+
}.instrument(info_span!("router.accept")));
320+
},
309321
}
310322
}
311323

@@ -316,14 +328,12 @@ impl RouterBuilder {
316328
join_set.shutdown().await;
317329
};
318330
let task = tokio::task::spawn(run_loop_fut);
319-
let task = AbortOnDropHandle::new(task)
320-
.map_err(Box::new(|e: JoinError| e.to_string()) as JoinErrToStr)
321-
.shared();
331+
let task = AbortOnDropHandle::new(task);
322332

323333
Ok(Router {
324334
endpoint: self.endpoint,
325335
protocols,
326-
task,
336+
task: Arc::new(Mutex::new(Some(task))),
327337
cancel_token: cancel,
328338
})
329339
}

0 commit comments

Comments
 (0)