5
5
//! ```no_run
6
6
//! # use anyhow::Result;
7
7
//! # use futures_lite::future::Boxed as BoxedFuture;
8
- //! # use iroh::{endpoint::Connecting , protocol::{ProtocolHandler, Router}, Endpoint, NodeAddr};
8
+ //! # use iroh::{endpoint::Connection , protocol::{ProtocolHandler, Router}, Endpoint, NodeAddr};
9
9
//! #
10
10
//! # async fn test_compile() -> Result<()> {
11
11
//! let endpoint = Endpoint::builder().discovery_n0().bind().await?;
22
22
//! struct Echo;
23
23
//!
24
24
//! impl ProtocolHandler for Echo {
25
- //! fn accept(&self, connecting: Connecting ) -> BoxedFuture<Result<()>> {
25
+ //! fn accept(&self, connection: Connection ) -> BoxedFuture<Result<()>> {
26
26
//! Box::pin(async move {
27
- //! let connection = connecting.await?;
28
27
//! let (mut send, mut recv) = connection.accept_bi().await?;
29
28
//!
30
29
//! // Echo any bytes received back directly.
41
40
use std:: { collections:: BTreeMap , sync:: Arc } ;
42
41
43
42
use anyhow:: Result ;
43
+ use iroh_base:: NodeId ;
44
44
use n0_future:: {
45
45
boxed:: BoxFuture ,
46
46
join_all,
@@ -50,7 +50,10 @@ use tokio::sync::Mutex;
50
50
use tokio_util:: sync:: CancellationToken ;
51
51
use tracing:: { error, info_span, trace, warn, Instrument } ;
52
52
53
- use crate :: { endpoint:: Connecting , Endpoint } ;
53
+ use crate :: {
54
+ endpoint:: { Connecting , Connection } ,
55
+ Endpoint ,
56
+ } ;
54
57
55
58
/// The built router.
56
59
///
@@ -109,10 +112,20 @@ pub struct RouterBuilder {
109
112
/// The protocol handler must then be registered on the node for an ALPN protocol with
110
113
/// [`crate::protocol::RouterBuilder::accept`].
111
114
pub trait ProtocolHandler : Send + Sync + std:: fmt:: Debug + ' static {
115
+ /// Optional interception point to handle the `Connecting` state.
116
+ ///
117
+ /// This enables accepting 0-RTT data from clients, among other things.
118
+ fn on_connecting ( & self , connecting : Connecting ) -> BoxFuture < Result < Connection > > {
119
+ Box :: pin ( async move {
120
+ let conn = connecting. await ?;
121
+ Ok ( conn)
122
+ } )
123
+ }
124
+
112
125
/// Handle an incoming connection.
113
126
///
114
127
/// This runs on a freshly spawned tokio task so this can be long-running.
115
- fn accept ( & self , conn : Connecting ) -> BoxFuture < Result < ( ) > > ;
128
+ fn accept ( & self , connection : Connection ) -> BoxFuture < Result < ( ) > > ;
116
129
117
130
/// Called when the node shuts down.
118
131
fn shutdown ( & self ) -> BoxFuture < ( ) > {
@@ -121,7 +134,11 @@ pub trait ProtocolHandler: Send + Sync + std::fmt::Debug + 'static {
121
134
}
122
135
123
136
impl < T : ProtocolHandler > ProtocolHandler for Arc < T > {
124
- fn accept ( & self , conn : Connecting ) -> BoxFuture < Result < ( ) > > {
137
+ fn on_connecting ( & self , conn : Connecting ) -> BoxFuture < Result < Connection > > {
138
+ self . as_ref ( ) . on_connecting ( conn)
139
+ }
140
+
141
+ fn accept ( & self , conn : Connection ) -> BoxFuture < Result < ( ) > > {
125
142
self . as_ref ( ) . accept ( conn)
126
143
}
127
144
@@ -131,7 +148,11 @@ impl<T: ProtocolHandler> ProtocolHandler for Arc<T> {
131
148
}
132
149
133
150
impl < T : ProtocolHandler > ProtocolHandler for Box < T > {
134
- fn accept ( & self , conn : Connecting ) -> BoxFuture < Result < ( ) > > {
151
+ fn on_connecting ( & self , conn : Connecting ) -> BoxFuture < Result < Connection > > {
152
+ self . as_ref ( ) . on_connecting ( conn)
153
+ }
154
+
155
+ fn accept ( & self , conn : Connection ) -> BoxFuture < Result < ( ) > > {
135
156
self . as_ref ( ) . accept ( conn)
136
157
}
137
158
@@ -350,8 +371,66 @@ async fn handle_connection(incoming: crate::endpoint::Incoming, protocols: Arc<P
350
371
warn ! ( "Ignoring connection: unsupported ALPN protocol" ) ;
351
372
return ;
352
373
} ;
353
- if let Err ( err) = handler. accept ( connecting) . await {
354
- warn ! ( "Handling incoming connection ended with error: {err}" ) ;
374
+ match handler. on_connecting ( connecting) . await {
375
+ Ok ( connection) => {
376
+ if let Err ( err) = handler. accept ( connection) . await {
377
+ warn ! ( "Handling incoming connection ended with error: {err}" ) ;
378
+ }
379
+ }
380
+ Err ( err) => {
381
+ warn ! ( "Handling incoming connecting ended with error: {err}" ) ;
382
+ }
383
+ }
384
+ }
385
+
386
+ /// Wraps an existing protocol, limiting its access,
387
+ /// based on the provided function.
388
+ ///
389
+ /// Any refused connection will be closed with an error code of `0` and reason `not allowed`.
390
+ #[ derive( derive_more:: Debug , Clone ) ]
391
+ pub struct AccessLimit < P : ProtocolHandler + Clone > {
392
+ proto : P ,
393
+ #[ debug( "limiter" ) ]
394
+ limiter : Arc < dyn Fn ( NodeId ) -> bool + Send + Sync + ' static > ,
395
+ }
396
+
397
+ impl < P : ProtocolHandler + Clone > AccessLimit < P > {
398
+ /// Create a new `AccessLimit`.
399
+ ///
400
+ /// The function should return `true` for nodes that are allowed to
401
+ /// connect, and `false` otherwise.
402
+ pub fn new < F > ( proto : P , limiter : F ) -> Self
403
+ where
404
+ F : Fn ( NodeId ) -> bool + Send + Sync + ' static ,
405
+ {
406
+ Self {
407
+ proto,
408
+ limiter : Arc :: new ( limiter) ,
409
+ }
410
+ }
411
+ }
412
+
413
+ impl < P : ProtocolHandler + Clone > ProtocolHandler for AccessLimit < P > {
414
+ fn on_connecting ( & self , conn : Connecting ) -> BoxFuture < Result < Connection > > {
415
+ self . proto . on_connecting ( conn)
416
+ }
417
+
418
+ fn accept ( & self , conn : Connection ) -> BoxFuture < Result < ( ) > > {
419
+ let this = self . clone ( ) ;
420
+ Box :: pin ( async move {
421
+ let remote = conn. remote_node_id ( ) ?;
422
+ let is_allowed = ( this. limiter ) ( remote) ;
423
+ if !is_allowed {
424
+ conn. close ( 0u32 . into ( ) , b"not allowed" ) ;
425
+ anyhow:: bail!( "not allowed" ) ;
426
+ }
427
+ this. proto . accept ( conn) . await ?;
428
+ Ok ( ( ) )
429
+ } )
430
+ }
431
+
432
+ fn shutdown ( & self ) -> BoxFuture < ( ) > {
433
+ self . proto . shutdown ( )
355
434
}
356
435
}
357
436
@@ -374,4 +453,53 @@ mod tests {
374
453
375
454
Ok ( ( ) )
376
455
}
456
+
457
+ // The protocol definition:
458
+ #[ derive( Debug , Clone ) ]
459
+ struct Echo ;
460
+
461
+ const ECHO_ALPN : & [ u8 ] = b"/iroh/echo/1" ;
462
+
463
+ impl ProtocolHandler for Echo {
464
+ fn accept ( & self , connection : Connection ) -> BoxFuture < Result < ( ) > > {
465
+ println ! ( "accepting echo" ) ;
466
+ Box :: pin ( async move {
467
+ let ( mut send, mut recv) = connection. accept_bi ( ) . await ?;
468
+
469
+ // Echo any bytes received back directly.
470
+ let _bytes_sent = tokio:: io:: copy ( & mut recv, & mut send) . await ?;
471
+
472
+ send. finish ( ) ?;
473
+ connection. closed ( ) . await ;
474
+
475
+ Ok ( ( ) )
476
+ } )
477
+ }
478
+ }
479
+ #[ tokio:: test]
480
+ async fn test_limiter ( ) -> Result < ( ) > {
481
+ let e1 = Endpoint :: builder ( ) . bind ( ) . await ?;
482
+ // deny all access
483
+ let proto = AccessLimit :: new ( Echo , |_node_id| false ) ;
484
+ let r1 = Router :: builder ( e1. clone ( ) )
485
+ . accept ( ECHO_ALPN , proto)
486
+ . spawn ( )
487
+ . await ?;
488
+
489
+ let addr1 = r1. endpoint ( ) . node_addr ( ) . await ?;
490
+
491
+ let e2 = Endpoint :: builder ( ) . bind ( ) . await ?;
492
+
493
+ println ! ( "connecting" ) ;
494
+ let conn = e2. connect ( addr1, ECHO_ALPN ) . await ?;
495
+
496
+ let ( _send, mut recv) = conn. open_bi ( ) . await ?;
497
+ let response = recv. read_to_end ( 1000 ) . await . unwrap_err ( ) ;
498
+ assert ! ( format!( "{:#?}" , response) . contains( "not allowed" ) ) ;
499
+
500
+ r1. shutdown ( ) . await ?;
501
+ e2. close ( ) . await ;
502
+
503
+ Ok ( ( ) )
504
+ }
377
505
}
0 commit comments