Skip to content

Commit 13ded84

Browse files
authored
refactor(iroh): Allow to register custom protocols (#2358)
## Description * feat(iroh-net): Allow to set the list of accepted ALPN protocols at runtime * refactor(iroh): Spawning a node can now be performed in two stages: First `Builder::build()` is called, which returns a Future that resolves to a new type `ProtocolBuilder`. The `ProtocolBuilder` is then spawned into the actual, running `Node`. If the intermediate step is not needed, `Builder::spawn` performs both in one call, therefore this change is not breaking. * feat(iroh): Allow to accept custom ALPN protocols in an Iroh node. Introduce a `ProtocolHandler` trait for accepting incoming connections and adds `ProtocolBuilder::accept` to register these handlers per ALPN. * refactor(iroh): Move towards more structured concurrency by spawning tasks into a `JoinSet` * refactor(iroh): Improve shutdown flow and perform more things concurently. originally based on #2357 but now directly to main ## Breaking Changes * `iroh_net::endpoint::make_server_config` takes `Arc<quinn::TransportConfig>` instead of `Option<quinn::TransportConfig>`. If you used the `None` case, replace with `quinn::TransportConfig::default()`. ## Notes & open questions <!-- Any notes, remarks or open questions you have to make about the PR. --> ## Change checklist - [x] Self-review. - [x] Documentation updates if relevant. - [ ] Tests if relevant. - [x] All breaking changes documented.
1 parent 96081e5 commit 13ded84

File tree

8 files changed

+679
-189
lines changed

8 files changed

+679
-189
lines changed

iroh-blobs/src/store/fs.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,6 +1486,8 @@ impl Actor {
14861486
let mut msgs = PeekableFlumeReceiver::new(self.state.msgs.clone());
14871487
while let Some(msg) = msgs.recv() {
14881488
if let ActorMessage::Shutdown { tx } = msg {
1489+
// Make sure the database is dropped before we send the reply.
1490+
drop(self);
14891491
if let Some(tx) = tx {
14901492
tx.send(()).ok();
14911493
}

iroh-blobs/src/store/traits.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ pub trait ReadableStore: Map {
295295
}
296296

297297
/// The mutable part of a Bao store.
298-
pub trait Store: ReadableStore + MapMut {
298+
pub trait Store: ReadableStore + MapMut + std::fmt::Debug {
299299
/// This trait method imports a file from a local path.
300300
///
301301
/// `data` is the path to the file.

iroh-net/src/endpoint.rs

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -125,15 +125,12 @@ impl Builder {
125125
}
126126
};
127127
let secret_key = self.secret_key.unwrap_or_else(SecretKey::generate);
128-
let mut server_config = make_server_config(
129-
&secret_key,
130-
self.alpn_protocols,
131-
self.transport_config,
132-
self.keylog,
133-
)?;
134-
if let Some(c) = self.concurrent_connections {
135-
server_config.concurrent_connections(c);
136-
}
128+
let static_config = StaticConfig {
129+
transport_config: Arc::new(self.transport_config.unwrap_or_default()),
130+
keylog: self.keylog,
131+
concurrent_connections: self.concurrent_connections,
132+
secret_key: secret_key.clone(),
133+
};
137134
let dns_resolver = self
138135
.dns_resolver
139136
.unwrap_or_else(|| default_resolver().clone());
@@ -149,7 +146,7 @@ impl Builder {
149146
#[cfg(any(test, feature = "test-utils"))]
150147
insecure_skip_relay_cert_verify: self.insecure_skip_relay_cert_verify,
151148
};
152-
Endpoint::bind(Some(server_config), msock_opts, self.keylog).await
149+
Endpoint::bind(static_config, msock_opts, self.alpn_protocols).await
153150
}
154151

155152
// # The very common methods everyone basically needs.
@@ -296,17 +293,41 @@ impl Builder {
296293
}
297294
}
298295

296+
/// Configuration for a [`quinn::Endpoint`] that cannot be changed at runtime.
297+
#[derive(Debug)]
298+
struct StaticConfig {
299+
secret_key: SecretKey,
300+
transport_config: Arc<quinn::TransportConfig>,
301+
keylog: bool,
302+
concurrent_connections: Option<u32>,
303+
}
304+
305+
impl StaticConfig {
306+
/// Create a [`quinn::ServerConfig`] with the specified ALPN protocols.
307+
fn create_server_config(&self, alpn_protocols: Vec<Vec<u8>>) -> Result<quinn::ServerConfig> {
308+
let mut server_config = make_server_config(
309+
&self.secret_key,
310+
alpn_protocols,
311+
self.transport_config.clone(),
312+
self.keylog,
313+
)?;
314+
if let Some(c) = self.concurrent_connections {
315+
server_config.concurrent_connections(c);
316+
}
317+
Ok(server_config)
318+
}
319+
}
320+
299321
/// Creates a [`quinn::ServerConfig`] with the given secret key and limits.
300322
pub fn make_server_config(
301323
secret_key: &SecretKey,
302324
alpn_protocols: Vec<Vec<u8>>,
303-
transport_config: Option<quinn::TransportConfig>,
325+
transport_config: Arc<quinn::TransportConfig>,
304326
keylog: bool,
305327
) -> Result<quinn::ServerConfig> {
306328
let tls_server_config = tls::make_server_config(secret_key, alpn_protocols, keylog)?;
307329
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(tls_server_config));
308-
server_config.transport_config(Arc::new(transport_config.unwrap_or_default()));
309-
330+
server_config.transport_config(transport_config);
310331
Ok(server_config)
311332
}
312333

@@ -334,12 +355,11 @@ pub fn make_server_config(
334355
/// [QUIC]: https://quicwg.org
335356
#[derive(Clone, Debug)]
336357
pub struct Endpoint {
337-
secret_key: Arc<SecretKey>,
338358
msock: Handle,
339359
endpoint: quinn::Endpoint,
340360
rtt_actor: Arc<rtt_actor::RttHandle>,
341-
keylog: bool,
342361
cancel_token: CancellationToken,
362+
static_config: Arc<StaticConfig>,
343363
}
344364

345365
impl Endpoint {
@@ -359,16 +379,17 @@ impl Endpoint {
359379
/// This is for internal use, the public interface is the [`Builder`] obtained from
360380
/// [Self::builder]. See the methods on the builder for documentation of the parameters.
361381
async fn bind(
362-
server_config: Option<quinn::ServerConfig>,
382+
static_config: StaticConfig,
363383
msock_opts: magicsock::Options,
364-
keylog: bool,
384+
initial_alpns: Vec<Vec<u8>>,
365385
) -> Result<Self> {
366-
let secret_key = msock_opts.secret_key.clone();
367-
let span = info_span!("magic_ep", me = %secret_key.public().fmt_short());
386+
let span = info_span!("magic_ep", me = %static_config.secret_key.public().fmt_short());
368387
let _guard = span.enter();
369388
let msock = magicsock::MagicSock::spawn(msock_opts).await?;
370389
trace!("created magicsock");
371390

391+
let server_config = static_config.create_server_config(initial_alpns)?;
392+
372393
let mut endpoint_config = quinn::EndpointConfig::default();
373394
// Setting this to false means that quinn will ignore packets that have the QUIC fixed bit
374395
// set to 0. The fixed bit is the 3rd bit of the first byte of a packet.
@@ -379,22 +400,31 @@ impl Endpoint {
379400

380401
let endpoint = quinn::Endpoint::new_with_abstract_socket(
381402
endpoint_config,
382-
server_config,
403+
Some(server_config),
383404
msock.clone(),
384405
Arc::new(quinn::TokioRuntime),
385406
)?;
386407
trace!("created quinn endpoint");
387408

388409
Ok(Self {
389-
secret_key: Arc::new(secret_key),
390410
msock,
391411
endpoint,
392412
rtt_actor: Arc::new(rtt_actor::RttHandle::new()),
393-
keylog,
394413
cancel_token: CancellationToken::new(),
414+
static_config: Arc::new(static_config),
395415
})
396416
}
397417

418+
/// Set the list of accepted ALPN protocols.
419+
///
420+
/// This will only affect new incoming connections.
421+
/// Note that this *overrides* the current list of ALPNs.
422+
pub fn set_alpns(&self, alpns: Vec<Vec<u8>>) -> Result<()> {
423+
let server_config = self.static_config.create_server_config(alpns)?;
424+
self.endpoint.set_server_config(Some(server_config));
425+
Ok(())
426+
}
427+
398428
// # Methods for establishing connectivity.
399429

400430
/// Connects to a remote [`Endpoint`].
@@ -480,10 +510,10 @@ impl Endpoint {
480510
let client_config = {
481511
let alpn_protocols = vec![alpn.to_vec()];
482512
let tls_client_config = tls::make_client_config(
483-
&self.secret_key,
513+
&self.static_config.secret_key,
484514
Some(*node_id),
485515
alpn_protocols,
486-
self.keylog,
516+
self.static_config.keylog,
487517
)?;
488518
let mut client_config = quinn::ClientConfig::new(Arc::new(tls_client_config));
489519
let mut transport_config = quinn::TransportConfig::default();
@@ -579,15 +609,15 @@ impl Endpoint {
579609

580610
/// Returns the secret_key of this endpoint.
581611
pub fn secret_key(&self) -> &SecretKey {
582-
&self.secret_key
612+
&self.static_config.secret_key
583613
}
584614

585615
/// Returns the node id of this endpoint.
586616
///
587617
/// This ID is the unique addressing information of this node and other peers must know
588618
/// it to be able to connect to this node.
589619
pub fn node_id(&self) -> NodeId {
590-
self.secret_key.public()
620+
self.static_config.secret_key.public()
591621
}
592622

593623
/// Returns the current [`NodeAddr`] for this endpoint.

iroh/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,7 @@ required-features = ["examples"]
102102
[[example]]
103103
name = "client"
104104
required-features = ["examples"]
105+
106+
[[example]]
107+
name = "custom-protocol"
108+
required-features = ["examples"]

iroh/examples/custom-protocol.rs

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
use std::sync::Arc;
2+
3+
use anyhow::Result;
4+
use clap::Parser;
5+
use futures_lite::future::Boxed as BoxedFuture;
6+
use iroh::{
7+
client::MemIroh,
8+
net::{
9+
endpoint::{get_remote_node_id, Connecting},
10+
Endpoint, NodeId,
11+
},
12+
node::ProtocolHandler,
13+
};
14+
use tracing_subscriber::{prelude::*, EnvFilter};
15+
16+
#[derive(Debug, Parser)]
17+
pub struct Cli {
18+
#[clap(subcommand)]
19+
command: Command,
20+
}
21+
22+
#[derive(Debug, Parser)]
23+
pub enum Command {
24+
Accept,
25+
Connect { node: NodeId },
26+
}
27+
28+
#[tokio::main]
29+
async fn main() -> Result<()> {
30+
setup_logging();
31+
let args = Cli::parse();
32+
// create a new node
33+
let builder = iroh::node::Node::memory().build().await?;
34+
let proto = ExampleProto::new(builder.client().clone(), builder.endpoint().clone());
35+
let node = builder
36+
.accept(EXAMPLE_ALPN, Arc::new(proto.clone()))
37+
.spawn()
38+
.await?;
39+
40+
// print the ticket if this is the accepting side
41+
match args.command {
42+
Command::Accept => {
43+
let node_id = node.node_id();
44+
println!("node id: {node_id}");
45+
// wait until ctrl-c
46+
tokio::signal::ctrl_c().await?;
47+
}
48+
Command::Connect { node: node_id } => {
49+
proto.connect(node_id).await?;
50+
}
51+
}
52+
53+
node.shutdown().await?;
54+
55+
Ok(())
56+
}
57+
58+
const EXAMPLE_ALPN: &[u8] = b"example-proto/0";
59+
60+
#[derive(Debug, Clone)]
61+
struct ExampleProto {
62+
client: MemIroh,
63+
endpoint: Endpoint,
64+
}
65+
66+
impl ProtocolHandler for ExampleProto {
67+
fn accept(self: Arc<Self>, connecting: Connecting) -> BoxedFuture<Result<()>> {
68+
Box::pin(async move {
69+
let connection = connecting.await?;
70+
let peer = get_remote_node_id(&connection)?;
71+
println!("accepted connection from {peer}");
72+
let mut send_stream = connection.open_uni().await?;
73+
// Let's create a new blob for each incoming connection.
74+
// This functions as an example of using existing iroh functionality within a protocol
75+
// (you likely don't want to create a new blob for each connection for real)
76+
let content = format!("this blob is created for my beloved peer {peer} ♥");
77+
let hash = self
78+
.client
79+
.blobs()
80+
.add_bytes(content.as_bytes().to_vec())
81+
.await?;
82+
// Send the hash over our custom protocol.
83+
send_stream.write_all(hash.hash.as_bytes()).await?;
84+
send_stream.finish().await?;
85+
println!("closing connection from {peer}");
86+
Ok(())
87+
})
88+
}
89+
}
90+
91+
impl ExampleProto {
92+
pub fn new(client: MemIroh, endpoint: Endpoint) -> Self {
93+
Self { client, endpoint }
94+
}
95+
96+
pub async fn connect(&self, remote_node_id: NodeId) -> Result<()> {
97+
println!("our node id: {}", self.endpoint.node_id());
98+
println!("connecting to {remote_node_id}");
99+
let conn = self
100+
.endpoint
101+
.connect_by_node_id(&remote_node_id, EXAMPLE_ALPN)
102+
.await?;
103+
let mut recv_stream = conn.accept_uni().await?;
104+
let hash_bytes = recv_stream.read_to_end(32).await?;
105+
let hash = iroh::blobs::Hash::from_bytes(hash_bytes.try_into().unwrap());
106+
println!("received hash: {hash}");
107+
self.client
108+
.blobs()
109+
.download(hash, remote_node_id.into())
110+
.await?
111+
.await?;
112+
println!("blob downloaded");
113+
let content = self.client.blobs().read_to_bytes(hash).await?;
114+
let message = String::from_utf8(content.to_vec())?;
115+
println!("blob content: {message}");
116+
Ok(())
117+
}
118+
}
119+
120+
/// Set the RUST_LOG env var to one of {debug,info,warn} to see logging.
121+
fn setup_logging() {
122+
tracing_subscriber::registry()
123+
.with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr))
124+
.with(EnvFilter::from_default_env())
125+
.try_init()
126+
.ok();
127+
}

0 commit comments

Comments
 (0)