Skip to content

Commit 3e16848

Browse files
feat(iroh)!: allow for limiting incoming connections on the router (#3157)
## Description Changes the `ProtocolHandler` trait to allow for intercepting `connecting` and `connection` states explicitly ## Breaking Changes - changed: `iroh::protocol::ProtocolHandler::accept` now takes `Connection` instead of `Connecting` ## Notes & open questions The new limiter could also just be part of the test code/an example, instead of giving users an API, unclear to me ## Change checklist - [ ] Self-review. - [ ] Documentation updates following the [style guide](https://rust-lang.github.io/rfcs/1574-more-api-documentation-conventions.html#appendix-a-full-conventions-text), if relevant. - [ ] Tests if relevant. - [ ] All breaking changes documented.
1 parent f6b5f5c commit 3e16848

File tree

4 files changed

+142
-19
lines changed

4 files changed

+142
-19
lines changed

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,8 @@ let router = Router::builder(endpoint)
9797
struct Echo;
9898

9999
impl ProtocolHandler for Echo {
100-
fn accept(self: Arc<Self>, connecting: Connecting) -> BoxedFuture<Result<()>> {
100+
fn accept(&self, connection: Connection) -> BoxedFuture<Result<()>> {
101101
Box::pin(async move {
102-
let connection = connecting.await?;
103102
let (mut send, mut recv) = connection.accept_bi().await?;
104103

105104
// Echo any bytes received back directly.

iroh/examples/echo.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
99
use anyhow::Result;
1010
use iroh::{
11-
endpoint::Connecting,
11+
endpoint::Connection,
1212
protocol::{ProtocolHandler, Router},
1313
Endpoint, NodeAddr,
1414
};
@@ -73,11 +73,9 @@ impl ProtocolHandler for Echo {
7373
///
7474
/// The returned future runs on a newly spawned tokio task, so it can run as long as
7575
/// the connection lasts.
76-
fn accept(&self, connecting: Connecting) -> BoxFuture<Result<()>> {
76+
fn accept(&self, connection: Connection) -> BoxFuture<Result<()>> {
7777
// We have to return a boxed future from the handler.
7878
Box::pin(async move {
79-
// Wait for the connection to be fully established.
80-
let connection = connecting.await?;
8179
// We can get the remote's node id from the connection.
8280
let node_id = connection.remote_node_id()?;
8381
println!("accepted connection from {node_id}");

iroh/examples/search.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use std::{collections::BTreeSet, sync::Arc};
3434
use anyhow::Result;
3535
use clap::Parser;
3636
use iroh::{
37-
endpoint::Connecting,
37+
endpoint::Connection,
3838
protocol::{ProtocolHandler, Router},
3939
Endpoint, NodeId,
4040
};
@@ -127,12 +127,10 @@ impl ProtocolHandler for BlobSearch {
127127
///
128128
/// The returned future runs on a newly spawned tokio task, so it can run as long as
129129
/// the connection lasts.
130-
fn accept(&self, connecting: Connecting) -> BoxFuture<Result<()>> {
130+
fn accept(&self, connection: Connection) -> BoxFuture<Result<()>> {
131131
let this = self.clone();
132132
// We have to return a boxed future from the handler.
133133
Box::pin(async move {
134-
// Wait for the connection to be fully established.
135-
let connection = connecting.await?;
136134
// We can get the remote's node id from the connection.
137135
let node_id = connection.remote_node_id()?;
138136
println!("accepted connection from {node_id}");

iroh/src/protocol.rs

Lines changed: 137 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
//! ```no_run
66
//! # use anyhow::Result;
77
//! # use futures_lite::future::Boxed as BoxedFuture;
8-
//! # use iroh::{endpoint::Connecting, protocol::{ProtocolHandler, Router}, Endpoint, NodeAddr};
8+
//! # use iroh::{endpoint::Connection, protocol::{ProtocolHandler, Router}, Endpoint, NodeAddr};
99
//! #
1010
//! # async fn test_compile() -> Result<()> {
1111
//! let endpoint = Endpoint::builder().discovery_n0().bind().await?;
@@ -22,9 +22,8 @@
2222
//! struct Echo;
2323
//!
2424
//! impl ProtocolHandler for Echo {
25-
//! fn accept(&self, connecting: Connecting) -> BoxedFuture<Result<()>> {
25+
//! fn accept(&self, connection: Connection) -> BoxedFuture<Result<()>> {
2626
//! Box::pin(async move {
27-
//! let connection = connecting.await?;
2827
//! let (mut send, mut recv) = connection.accept_bi().await?;
2928
//!
3029
//! // Echo any bytes received back directly.
@@ -41,6 +40,7 @@
4140
use std::{collections::BTreeMap, sync::Arc};
4241

4342
use anyhow::Result;
43+
use iroh_base::NodeId;
4444
use n0_future::{
4545
boxed::BoxFuture,
4646
join_all,
@@ -50,7 +50,10 @@ use tokio::sync::Mutex;
5050
use tokio_util::sync::CancellationToken;
5151
use tracing::{error, info_span, trace, warn, Instrument};
5252

53-
use crate::{endpoint::Connecting, Endpoint};
53+
use crate::{
54+
endpoint::{Connecting, Connection},
55+
Endpoint,
56+
};
5457

5558
/// The built router.
5659
///
@@ -109,10 +112,20 @@ pub struct RouterBuilder {
109112
/// The protocol handler must then be registered on the node for an ALPN protocol with
110113
/// [`crate::protocol::RouterBuilder::accept`].
111114
pub trait ProtocolHandler: Send + Sync + std::fmt::Debug + 'static {
115+
/// Optional interception point to handle the `Connecting` state.
116+
///
117+
/// This enables accepting 0-RTT data from clients, among other things.
118+
fn on_connecting(&self, connecting: Connecting) -> BoxFuture<Result<Connection>> {
119+
Box::pin(async move {
120+
let conn = connecting.await?;
121+
Ok(conn)
122+
})
123+
}
124+
112125
/// Handle an incoming connection.
113126
///
114127
/// This runs on a freshly spawned tokio task so this can be long-running.
115-
fn accept(&self, conn: Connecting) -> BoxFuture<Result<()>>;
128+
fn accept(&self, connection: Connection) -> BoxFuture<Result<()>>;
116129

117130
/// Called when the node shuts down.
118131
fn shutdown(&self) -> BoxFuture<()> {
@@ -121,7 +134,11 @@ pub trait ProtocolHandler: Send + Sync + std::fmt::Debug + 'static {
121134
}
122135

123136
impl<T: ProtocolHandler> ProtocolHandler for Arc<T> {
124-
fn accept(&self, conn: Connecting) -> BoxFuture<Result<()>> {
137+
fn on_connecting(&self, conn: Connecting) -> BoxFuture<Result<Connection>> {
138+
self.as_ref().on_connecting(conn)
139+
}
140+
141+
fn accept(&self, conn: Connection) -> BoxFuture<Result<()>> {
125142
self.as_ref().accept(conn)
126143
}
127144

@@ -131,7 +148,11 @@ impl<T: ProtocolHandler> ProtocolHandler for Arc<T> {
131148
}
132149

133150
impl<T: ProtocolHandler> ProtocolHandler for Box<T> {
134-
fn accept(&self, conn: Connecting) -> BoxFuture<Result<()>> {
151+
fn on_connecting(&self, conn: Connecting) -> BoxFuture<Result<Connection>> {
152+
self.as_ref().on_connecting(conn)
153+
}
154+
155+
fn accept(&self, conn: Connection) -> BoxFuture<Result<()>> {
135156
self.as_ref().accept(conn)
136157
}
137158

@@ -350,8 +371,66 @@ async fn handle_connection(incoming: crate::endpoint::Incoming, protocols: Arc<P
350371
warn!("Ignoring connection: unsupported ALPN protocol");
351372
return;
352373
};
353-
if let Err(err) = handler.accept(connecting).await {
354-
warn!("Handling incoming connection ended with error: {err}");
374+
match handler.on_connecting(connecting).await {
375+
Ok(connection) => {
376+
if let Err(err) = handler.accept(connection).await {
377+
warn!("Handling incoming connection ended with error: {err}");
378+
}
379+
}
380+
Err(err) => {
381+
warn!("Handling incoming connecting ended with error: {err}");
382+
}
383+
}
384+
}
385+
386+
/// Wraps an existing protocol, limiting its access,
387+
/// based on the provided function.
388+
///
389+
/// Any refused connection will be closed with an error code of `0` and reason `not allowed`.
390+
#[derive(derive_more::Debug, Clone)]
391+
pub struct AccessLimit<P: ProtocolHandler + Clone> {
392+
proto: P,
393+
#[debug("limiter")]
394+
limiter: Arc<dyn Fn(NodeId) -> bool + Send + Sync + 'static>,
395+
}
396+
397+
impl<P: ProtocolHandler + Clone> AccessLimit<P> {
398+
/// Create a new `AccessLimit`.
399+
///
400+
/// The function should return `true` for nodes that are allowed to
401+
/// connect, and `false` otherwise.
402+
pub fn new<F>(proto: P, limiter: F) -> Self
403+
where
404+
F: Fn(NodeId) -> bool + Send + Sync + 'static,
405+
{
406+
Self {
407+
proto,
408+
limiter: Arc::new(limiter),
409+
}
410+
}
411+
}
412+
413+
impl<P: ProtocolHandler + Clone> ProtocolHandler for AccessLimit<P> {
414+
fn on_connecting(&self, conn: Connecting) -> BoxFuture<Result<Connection>> {
415+
self.proto.on_connecting(conn)
416+
}
417+
418+
fn accept(&self, conn: Connection) -> BoxFuture<Result<()>> {
419+
let this = self.clone();
420+
Box::pin(async move {
421+
let remote = conn.remote_node_id()?;
422+
let is_allowed = (this.limiter)(remote);
423+
if !is_allowed {
424+
conn.close(0u32.into(), b"not allowed");
425+
anyhow::bail!("not allowed");
426+
}
427+
this.proto.accept(conn).await?;
428+
Ok(())
429+
})
430+
}
431+
432+
fn shutdown(&self) -> BoxFuture<()> {
433+
self.proto.shutdown()
355434
}
356435
}
357436

@@ -374,4 +453,53 @@ mod tests {
374453

375454
Ok(())
376455
}
456+
457+
// The protocol definition:
458+
#[derive(Debug, Clone)]
459+
struct Echo;
460+
461+
const ECHO_ALPN: &[u8] = b"/iroh/echo/1";
462+
463+
impl ProtocolHandler for Echo {
464+
fn accept(&self, connection: Connection) -> BoxFuture<Result<()>> {
465+
println!("accepting echo");
466+
Box::pin(async move {
467+
let (mut send, mut recv) = connection.accept_bi().await?;
468+
469+
// Echo any bytes received back directly.
470+
let _bytes_sent = tokio::io::copy(&mut recv, &mut send).await?;
471+
472+
send.finish()?;
473+
connection.closed().await;
474+
475+
Ok(())
476+
})
477+
}
478+
}
479+
#[tokio::test]
480+
async fn test_limiter() -> Result<()> {
481+
let e1 = Endpoint::builder().bind().await?;
482+
// deny all access
483+
let proto = AccessLimit::new(Echo, |_node_id| false);
484+
let r1 = Router::builder(e1.clone())
485+
.accept(ECHO_ALPN, proto)
486+
.spawn()
487+
.await?;
488+
489+
let addr1 = r1.endpoint().node_addr().await?;
490+
491+
let e2 = Endpoint::builder().bind().await?;
492+
493+
println!("connecting");
494+
let conn = e2.connect(addr1, ECHO_ALPN).await?;
495+
496+
let (_send, mut recv) = conn.open_bi().await?;
497+
let response = recv.read_to_end(1000).await.unwrap_err();
498+
assert!(format!("{:#?}", response).contains("not allowed"));
499+
500+
r1.shutdown().await?;
501+
e2.close().await;
502+
503+
Ok(())
504+
}
377505
}

0 commit comments

Comments
 (0)