diff --git a/README.md b/README.md index 510fad0f8c..b2f238cf51 100644 --- a/README.md +++ b/README.md @@ -97,9 +97,8 @@ let router = Router::builder(endpoint) struct Echo; impl ProtocolHandler for Echo { - fn accept(self: Arc, connecting: Connecting) -> BoxedFuture> { + fn accept(&self, connection: Connection) -> BoxedFuture> { Box::pin(async move { - let connection = connecting.await?; let (mut send, mut recv) = connection.accept_bi().await?; // Echo any bytes received back directly. diff --git a/iroh/examples/echo.rs b/iroh/examples/echo.rs index 414cd579e1..b0a5d28d49 100644 --- a/iroh/examples/echo.rs +++ b/iroh/examples/echo.rs @@ -8,7 +8,7 @@ use anyhow::Result; use iroh::{ - endpoint::Connecting, + endpoint::Connection, protocol::{ProtocolHandler, Router}, Endpoint, NodeAddr, }; @@ -73,11 +73,9 @@ impl ProtocolHandler for Echo { /// /// The returned future runs on a newly spawned tokio task, so it can run as long as /// the connection lasts. - fn accept(&self, connecting: Connecting) -> BoxFuture> { + fn accept(&self, connection: Connection) -> BoxFuture> { // We have to return a boxed future from the handler. Box::pin(async move { - // Wait for the connection to be fully established. - let connection = connecting.await?; // We can get the remote's node id from the connection. let node_id = connection.remote_node_id()?; println!("accepted connection from {node_id}"); diff --git a/iroh/examples/search.rs b/iroh/examples/search.rs index 1cd265dfa8..a715e590e4 100644 --- a/iroh/examples/search.rs +++ b/iroh/examples/search.rs @@ -34,7 +34,7 @@ use std::{collections::BTreeSet, sync::Arc}; use anyhow::Result; use clap::Parser; use iroh::{ - endpoint::Connecting, + endpoint::Connection, protocol::{ProtocolHandler, Router}, Endpoint, NodeId, }; @@ -127,12 +127,10 @@ impl ProtocolHandler for BlobSearch { /// /// The returned future runs on a newly spawned tokio task, so it can run as long as /// the connection lasts. - fn accept(&self, connecting: Connecting) -> BoxFuture> { + fn accept(&self, connection: Connection) -> BoxFuture> { let this = self.clone(); // We have to return a boxed future from the handler. Box::pin(async move { - // Wait for the connection to be fully established. - let connection = connecting.await?; // We can get the remote's node id from the connection. let node_id = connection.remote_node_id()?; println!("accepted connection from {node_id}"); diff --git a/iroh/src/protocol.rs b/iroh/src/protocol.rs index 574ec68918..d386f4d5ec 100644 --- a/iroh/src/protocol.rs +++ b/iroh/src/protocol.rs @@ -5,7 +5,7 @@ //! ```no_run //! # use anyhow::Result; //! # use futures_lite::future::Boxed as BoxedFuture; -//! # use iroh::{endpoint::Connecting, protocol::{ProtocolHandler, Router}, Endpoint, NodeAddr}; +//! # use iroh::{endpoint::Connection, protocol::{ProtocolHandler, Router}, Endpoint, NodeAddr}; //! # //! # async fn test_compile() -> Result<()> { //! let endpoint = Endpoint::builder().discovery_n0().bind().await?; @@ -22,9 +22,8 @@ //! struct Echo; //! //! impl ProtocolHandler for Echo { -//! fn accept(&self, connecting: Connecting) -> BoxedFuture> { +//! fn accept(&self, connection: Connection) -> BoxedFuture> { //! Box::pin(async move { -//! let connection = connecting.await?; //! let (mut send, mut recv) = connection.accept_bi().await?; //! //! // Echo any bytes received back directly. @@ -41,6 +40,7 @@ use std::{collections::BTreeMap, sync::Arc}; use anyhow::Result; +use iroh_base::NodeId; use n0_future::{ boxed::BoxFuture, join_all, @@ -50,7 +50,10 @@ use tokio::sync::Mutex; use tokio_util::sync::CancellationToken; use tracing::{error, info_span, trace, warn, Instrument}; -use crate::{endpoint::Connecting, Endpoint}; +use crate::{ + endpoint::{Connecting, Connection}, + Endpoint, +}; /// The built router. /// @@ -109,10 +112,20 @@ pub struct RouterBuilder { /// The protocol handler must then be registered on the node for an ALPN protocol with /// [`crate::protocol::RouterBuilder::accept`]. pub trait ProtocolHandler: Send + Sync + std::fmt::Debug + 'static { + /// Optional interception point to handle the `Connecting` state. + /// + /// This enables accepting 0-RTT data from clients, among other things. + fn on_connecting(&self, connecting: Connecting) -> BoxFuture> { + Box::pin(async move { + let conn = connecting.await?; + Ok(conn) + }) + } + /// Handle an incoming connection. /// /// This runs on a freshly spawned tokio task so this can be long-running. - fn accept(&self, conn: Connecting) -> BoxFuture>; + fn accept(&self, connection: Connection) -> BoxFuture>; /// Called when the node shuts down. fn shutdown(&self) -> BoxFuture<()> { @@ -121,7 +134,11 @@ pub trait ProtocolHandler: Send + Sync + std::fmt::Debug + 'static { } impl ProtocolHandler for Arc { - fn accept(&self, conn: Connecting) -> BoxFuture> { + fn on_connecting(&self, conn: Connecting) -> BoxFuture> { + self.as_ref().on_connecting(conn) + } + + fn accept(&self, conn: Connection) -> BoxFuture> { self.as_ref().accept(conn) } @@ -131,7 +148,11 @@ impl ProtocolHandler for Arc { } impl ProtocolHandler for Box { - fn accept(&self, conn: Connecting) -> BoxFuture> { + fn on_connecting(&self, conn: Connecting) -> BoxFuture> { + self.as_ref().on_connecting(conn) + } + + fn accept(&self, conn: Connection) -> BoxFuture> { self.as_ref().accept(conn) } @@ -350,8 +371,66 @@ async fn handle_connection(incoming: crate::endpoint::Incoming, protocols: Arc

{ + if let Err(err) = handler.accept(connection).await { + warn!("Handling incoming connection ended with error: {err}"); + } + } + Err(err) => { + warn!("Handling incoming connecting ended with error: {err}"); + } + } +} + +/// Wraps an existing protocol, limiting its access, +/// based on the provided function. +/// +/// Any refused connection will be closed with an error code of `0` and reason `not allowed`. +#[derive(derive_more::Debug, Clone)] +pub struct AccessLimit { + proto: P, + #[debug("limiter")] + limiter: Arc bool + Send + Sync + 'static>, +} + +impl AccessLimit

{ + /// Create a new `AccessLimit`. + /// + /// The function should return `true` for nodes that are allowed to + /// connect, and `false` otherwise. + pub fn new(proto: P, limiter: F) -> Self + where + F: Fn(NodeId) -> bool + Send + Sync + 'static, + { + Self { + proto, + limiter: Arc::new(limiter), + } + } +} + +impl ProtocolHandler for AccessLimit

{ + fn on_connecting(&self, conn: Connecting) -> BoxFuture> { + self.proto.on_connecting(conn) + } + + fn accept(&self, conn: Connection) -> BoxFuture> { + let this = self.clone(); + Box::pin(async move { + let remote = conn.remote_node_id()?; + let is_allowed = (this.limiter)(remote); + if !is_allowed { + conn.close(0u32.into(), b"not allowed"); + anyhow::bail!("not allowed"); + } + this.proto.accept(conn).await?; + Ok(()) + }) + } + + fn shutdown(&self) -> BoxFuture<()> { + self.proto.shutdown() } } @@ -374,4 +453,53 @@ mod tests { Ok(()) } + + // The protocol definition: + #[derive(Debug, Clone)] + struct Echo; + + const ECHO_ALPN: &[u8] = b"/iroh/echo/1"; + + impl ProtocolHandler for Echo { + fn accept(&self, connection: Connection) -> BoxFuture> { + println!("accepting echo"); + Box::pin(async move { + let (mut send, mut recv) = connection.accept_bi().await?; + + // Echo any bytes received back directly. + let _bytes_sent = tokio::io::copy(&mut recv, &mut send).await?; + + send.finish()?; + connection.closed().await; + + Ok(()) + }) + } + } + #[tokio::test] + async fn test_limiter() -> Result<()> { + let e1 = Endpoint::builder().bind().await?; + // deny all access + let proto = AccessLimit::new(Echo, |_node_id| false); + let r1 = Router::builder(e1.clone()) + .accept(ECHO_ALPN, proto) + .spawn() + .await?; + + let addr1 = r1.endpoint().node_addr().await?; + + let e2 = Endpoint::builder().bind().await?; + + println!("connecting"); + let conn = e2.connect(addr1, ECHO_ALPN).await?; + + let (_send, mut recv) = conn.open_bi().await?; + let response = recv.read_to_end(1000).await.unwrap_err(); + assert!(format!("{:#?}", response).contains("not allowed")); + + r1.shutdown().await?; + e2.close().await; + + Ok(()) + } }