42
42
//! ```
43
43
use std:: { any:: Any , collections:: BTreeMap , sync:: Arc } ;
44
44
45
- use anyhow:: { anyhow , Result } ;
45
+ use anyhow:: Result ;
46
46
use futures_buffered:: join_all;
47
47
use futures_lite:: future:: Boxed as BoxedFuture ;
48
- use futures_util:: {
49
- future:: { MapErr , Shared } ,
50
- FutureExt , TryFutureExt ,
51
- } ;
52
- use tokio:: task:: { JoinError , JoinSet } ;
48
+ use tokio:: { sync:: Mutex , task:: JoinSet } ;
53
49
use tokio_util:: { sync:: CancellationToken , task:: AbortOnDropHandle } ;
54
- use tracing:: { debug , error , warn} ;
50
+ use tracing:: { error , info_span , trace , warn, Instrument } ;
55
51
56
52
use crate :: { endpoint:: Connecting , Endpoint } ;
57
53
@@ -92,17 +88,10 @@ pub struct Router {
92
88
endpoint : Endpoint ,
93
89
protocols : Arc < ProtocolMap > ,
94
90
// `Router` needs to be `Clone + Send`, and we need to `task.await` in its `shutdown()` impl.
95
- // So we need
96
- // - `Shared` so we can `task.await` from all `Node` clones
97
- // - `MapErr` to map the `JoinError` to a `String`, because `JoinError` is `!Clone`
98
- // - `AbortOnDropHandle` to make sure that the `task` is cancelled when all `Node`s are dropped
99
- // (`Shared` acts like an `Arc` around its inner future).
100
- task : Shared < MapErr < AbortOnDropHandle < ( ) > , JoinErrToStr > > ,
91
+ task : Arc < Mutex < Option < AbortOnDropHandle < ( ) > > > > ,
101
92
cancel_token : CancellationToken ,
102
93
}
103
94
104
- type JoinErrToStr = Box < dyn Fn ( JoinError ) -> String + Send + Sync + ' static > ;
105
-
106
95
/// Builder for creating a [`Router`] for accepting protocols.
107
96
#[ derive( Debug ) ]
108
97
pub struct RouterBuilder {
@@ -201,16 +190,32 @@ impl Router {
201
190
& self . endpoint
202
191
}
203
192
193
+ /// Checks if the router is already shutdown.
194
+ pub fn is_shutdown ( & self ) -> bool {
195
+ self . cancel_token . is_cancelled ( )
196
+ }
197
+
204
198
/// Shuts down the accept loop cleanly.
205
199
///
200
+ /// When this function returns, all [`ProtocolHandler`]s will be shutdown and
201
+ /// `Endpoint::close` will have been called.
202
+ ///
203
+ /// If already shutdown, it returns `Ok`.
204
+ ///
206
205
/// If some [`ProtocolHandler`] panicked in the accept loop, this will propagate
207
206
/// that panic into the result here.
208
- pub async fn shutdown ( self ) -> Result < ( ) > {
207
+ pub async fn shutdown ( & self ) -> Result < ( ) > {
208
+ if self . is_shutdown ( ) {
209
+ return Ok ( ( ) ) ;
210
+ }
211
+
209
212
// Trigger shutdown of the main run task by activating the cancel token.
210
213
self . cancel_token . cancel ( ) ;
211
214
212
215
// Wait for the main task to terminate.
213
- self . task . await . map_err ( |err| anyhow ! ( err) ) ?;
216
+ if let Some ( task) = self . task . lock ( ) . await . take ( ) {
217
+ task. await ?;
218
+ }
214
219
215
220
Ok ( ( ) )
216
221
}
@@ -267,25 +272,16 @@ impl RouterBuilder {
267
272
let cancel_token = cancel. clone ( ) ;
268
273
269
274
let run_loop_fut = async move {
275
+ // Make sure to cancel the token, if this future ever exits.
276
+ let _cancel_guard = cancel_token. clone ( ) . drop_guard ( ) ;
277
+
270
278
let protocols = protos;
271
279
loop {
272
280
tokio:: select! {
273
281
biased;
274
282
_ = cancel_token. cancelled( ) => {
275
283
break ;
276
284
} ,
277
- // handle incoming p2p connections.
278
- incoming = endpoint. accept( ) => {
279
- let Some ( incoming) = incoming else {
280
- break ;
281
- } ;
282
-
283
- let protocols = protocols. clone( ) ;
284
- join_set. spawn( async move {
285
- handle_connection( incoming, protocols) . await ;
286
- anyhow:: Ok ( ( ) )
287
- } ) ;
288
- } ,
289
285
// handle task terminations and quit on panics.
290
286
res = join_set. join_next( ) , if !join_set. is_empty( ) => {
291
287
match res {
@@ -294,18 +290,34 @@ impl RouterBuilder {
294
290
error!( "Task panicked: {outer:?}" ) ;
295
291
break ;
296
292
} else if outer. is_cancelled( ) {
297
- debug !( "Task cancelled: {outer:?}" ) ;
293
+ trace !( "Task cancelled: {outer:?}" ) ;
298
294
} else {
299
295
error!( "Task failed: {outer:?}" ) ;
300
296
break ;
301
297
}
302
298
}
303
- Some ( Ok ( Err ( inner) ) ) => {
304
- debug!( "Task errored: {inner:?}" ) ;
299
+ Some ( Ok ( Some ( ( ) ) ) ) => {
300
+ trace!( "Task finished" ) ;
301
+ }
302
+ Some ( Ok ( None ) ) => {
303
+ trace!( "Task cancelled" ) ;
305
304
}
306
305
_ => { }
307
306
}
308
307
} ,
308
+
309
+ // handle incoming p2p connections.
310
+ incoming = endpoint. accept( ) => {
311
+ let Some ( incoming) = incoming else {
312
+ break ;
313
+ } ;
314
+
315
+ let protocols = protocols. clone( ) ;
316
+ let token = cancel_token. child_token( ) ;
317
+ join_set. spawn( async move {
318
+ token. run_until_cancelled( handle_connection( incoming, protocols) ) . await
319
+ } . instrument( info_span!( "router.accept" ) ) ) ;
320
+ } ,
309
321
}
310
322
}
311
323
@@ -316,14 +328,12 @@ impl RouterBuilder {
316
328
join_set. shutdown ( ) . await ;
317
329
} ;
318
330
let task = tokio:: task:: spawn ( run_loop_fut) ;
319
- let task = AbortOnDropHandle :: new ( task)
320
- . map_err ( Box :: new ( |e : JoinError | e. to_string ( ) ) as JoinErrToStr )
321
- . shared ( ) ;
331
+ let task = AbortOnDropHandle :: new ( task) ;
322
332
323
333
Ok ( Router {
324
334
endpoint : self . endpoint ,
325
335
protocols,
326
- task,
336
+ task : Arc :: new ( Mutex :: new ( Some ( task ) ) ) ,
327
337
cancel_token : cancel,
328
338
} )
329
339
}
0 commit comments