
Commit 67df1c1

refactor(iroh): move protocol relevant impls into node/protocols (#2831)
Some internal refactorings
1 parent: c9d1ba7 · commit: 67df1c1
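The diffs below drop the monolithic implementations from protocol.rs and remove the separate node::docs module, declaring three protocol submodules instead. Assuming the usual file-per-module layout (the new files themselves are not shown in the hunks of this commit), the resulting structure is roughly:

    iroh/src/node/protocol.rs          // ProtocolHandler trait + ProtocolMap (kept)
    iroh/src/node/protocol/blobs.rs    // BlobsProtocol, moved out of protocol.rs
    iroh/src/node/protocol/docs.rs     // DocsProtocol, formerly node::docs::DocsEngine
    iroh/src/node/protocol/gossip.rs   // ProtocolHandler impl for iroh_gossip::net::Gossip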

File tree

8 files changed: +313 −293 lines changed (three of the diffs are shown below)


iroh/src/node.rs

Lines changed: 6 additions & 4 deletions
@@ -59,16 +59,18 @@ use iroh_net::{
     endpoint::{DirectAddrsStream, RemoteInfo},
     AddrInfo, Endpoint, NodeAddr,
 };
-use protocol::BlobsProtocol;
+use protocol::blobs::BlobsProtocol;
 use quic_rpc::{transport::ServerEndpoint as _, RpcServer};
 use tokio::task::{JoinError, JoinSet};
 use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
 use tracing::{debug, error, info, info_span, trace, warn, Instrument};

-use crate::node::{docs::DocsEngine, nodes_storage::store_node_addrs, protocol::ProtocolMap};
+use crate::node::{
+    nodes_storage::store_node_addrs,
+    protocol::{docs::DocsProtocol, ProtocolMap},
+};

 mod builder;
-mod docs;
 mod nodes_storage;
 mod protocol;
 mod rpc;

@@ -296,7 +298,7 @@ impl<D: iroh_blobs::store::Store> NodeInner<D> {
         if let GcPolicy::Interval(gc_period) = gc_policy {
             let protocols = protocols.clone();
             let handle = local_pool.spawn(move || async move {
-                let docs_engine = protocols.get_typed::<DocsEngine>(DOCS_ALPN);
+                let docs_engine = protocols.get_typed::<DocsProtocol>(DOCS_ALPN);
                 let blobs = protocols
                     .get_typed::<BlobsProtocol<D>>(iroh_blobs::protocol::ALPN)
                     .expect("missing blobs");
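The GC task above looks handlers up by ALPN and downcasts them to a concrete type via ProtocolMap::get_typed. The map's definition is not part of this diff; the following is a minimal sketch of how such a typed lookup over type-erased handlers can work, with the trait bounds and the into_arc_any helper assumed rather than taken from the crate (the real trait also carries accept/shutdown, as the removed impls further down show):

use std::{any::Any, collections::BTreeMap, sync::Arc};

// Assumed trait shape: each handler can expose itself as `Arc<dyn Any>` so the
// map can recover the concrete type behind a type-erased handler.
// (accept()/shutdown() omitted here to focus on the lookup.)
pub(crate) trait ProtocolHandler: Send + Sync + Any + 'static {
    fn into_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
}

#[derive(Default)]
pub(crate) struct ProtocolMap(BTreeMap<&'static [u8], Arc<dyn ProtocolHandler>>);

impl ProtocolMap {
    /// Return the handler registered under `alpn`, downcast to `P`.
    /// Yields `None` if nothing is registered or the concrete type differs.
    pub(crate) fn get_typed<P: ProtocolHandler>(&self, alpn: &[u8]) -> Option<Arc<P>> {
        self.0.get(alpn)?.clone().into_arc_any().downcast::<P>().ok()
    }
}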

iroh/src/node/builder.rs

Lines changed: 5 additions & 7 deletions
@@ -32,14 +32,12 @@ use tokio::task::JoinError;
 use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
 use tracing::{debug, error_span, trace, Instrument};

-use super::{
-    docs::DocsEngine, rpc_status::RpcStatus, IrohServerEndpoint, JoinErrToStr, Node, NodeInner,
-};
+use super::{rpc_status::RpcStatus, IrohServerEndpoint, JoinErrToStr, Node, NodeInner};
 use crate::{
     client::RPC_ALPN,
     node::{
         nodes_storage::load_node_addrs,
-        protocol::{BlobsProtocol, ProtocolMap},
+        protocol::{blobs::BlobsProtocol, docs::DocsProtocol, ProtocolMap},
         ProtocolHandler,
     },
     rpc_protocol::RpcService,

@@ -654,8 +652,8 @@ where
         let downloader = Downloader::new(self.blobs_store.clone(), endpoint.clone(), lp.clone());

         // Spawn the docs engine, if enabled.
-        // This returns None for DocsStorage::Disabled, otherwise Some(DocsEngine).
-        let docs = DocsEngine::spawn(
+        // This returns None for DocsStorage::Disabled, otherwise Some(DocsProtocol).
+        let docs = DocsProtocol::spawn(
             self.docs_storage,
             self.blobs_store.clone(),
             self.storage.default_author_storage(),

@@ -813,7 +811,7 @@ impl<D: iroh_blobs::store::Store> ProtocolBuilder<D> {
         store: D,
         gossip: Gossip,
         downloader: Downloader,
-        docs: Option<DocsEngine>,
+        docs: Option<DocsProtocol>,
     ) -> Self {
         // Register blobs.
         let blobs_proto = BlobsProtocol::new_with_events(
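DocsProtocol keeps the behaviour the comment above describes: spawn yields None when docs storage is disabled, so the builder only registers the docs protocol when it actually exists. A stand-alone illustration of that pattern with stand-in types (not the crate's real signatures, which are async and take the docs storage, blob store and default-author storage) could look like:

// Stand-in types for illustration only.
enum DocsStorage {
    Disabled,
    Memory,
}

struct DocsProtocol;

impl DocsProtocol {
    fn spawn(storage: DocsStorage) -> Option<Self> {
        match storage {
            // Disabled storage short-circuits: no engine is started.
            DocsStorage::Disabled => None,
            // Otherwise the docs engine would be started here.
            DocsStorage::Memory => Some(DocsProtocol),
        }
    }
}

fn main() {
    // With docs disabled, nothing gets registered for the docs ALPN later on.
    assert!(DocsProtocol::spawn(DocsStorage::Disabled).is_none());
    assert!(DocsProtocol::spawn(DocsStorage::Memory).is_some());
}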

iroh/src/node/protocol.rs

Lines changed: 5 additions & 269 deletions
@@ -1,29 +1,13 @@
 use std::{any::Any, collections::BTreeMap, fmt, sync::Arc};

-use anyhow::{anyhow, Result};
+use anyhow::Result;
 use futures_lite::future::Boxed as BoxedFuture;
 use futures_util::future::join_all;
-use iroh_blobs::{
-    downloader::{DownloadRequest, Downloader},
-    get::{
-        db::{DownloadProgress, GetState},
-        Stats,
-    },
-    provider::EventSender,
-    util::{
-        local_pool::LocalPoolHandle,
-        progress::{AsyncChannelProgressSender, ProgressSender},
-        SetTagOption,
-    },
-    HashAndFormat, TempTag,
-};
-use iroh_net::{endpoint::Connecting, Endpoint, NodeAddr};
-use tracing::{debug, warn};
+use iroh_net::endpoint::Connecting;

-use crate::{
-    client::blobs::DownloadMode,
-    rpc_protocol::blobs::{BatchId, DownloadRequest as BlobDownloadRequest},
-};
+pub(crate) mod blobs;
+pub(crate) mod docs;
+pub(crate) mod gossip;

 /// Handler for incoming connections.
 ///

@@ -95,251 +79,3 @@ impl ProtocolMap {
         join_all(handlers).await;
     }
 }
-
-#[derive(Debug)]
-pub(crate) struct BlobsProtocol<S> {
-    rt: LocalPoolHandle,
-    store: S,
-    events: EventSender,
-    downloader: Downloader,
-    batches: tokio::sync::Mutex<BlobBatches>,
-}
-
-/// Name used for logging when new node addresses are added from gossip.
-const BLOB_DOWNLOAD_SOURCE_NAME: &str = "blob_download";
-
-/// Keeps track of all the currently active batch operations of the blobs api.
-#[derive(Debug, Default)]
-pub(crate) struct BlobBatches {
-    /// Currently active batches
-    batches: BTreeMap<BatchId, BlobBatch>,
-    /// Used to generate new batch ids.
-    max: u64,
-}
-
-/// A single batch of blob operations
-#[derive(Debug, Default)]
-struct BlobBatch {
-    /// The tags in this batch.
-    tags: BTreeMap<HashAndFormat, Vec<TempTag>>,
-}
-
-impl BlobBatches {
-    /// Create a new unique batch id.
-    pub(crate) fn create(&mut self) -> BatchId {
-        let id = self.max;
-        self.max += 1;
-        BatchId(id)
-    }
-
-    /// Store a temp tag in a batch identified by a batch id.
-    pub(crate) fn store(&mut self, batch: BatchId, tt: TempTag) {
-        let entry = self.batches.entry(batch).or_default();
-        entry.tags.entry(tt.hash_and_format()).or_default().push(tt);
-    }
-
-    /// Remove a tag from a batch.
-    pub(crate) fn remove_one(&mut self, batch: BatchId, content: &HashAndFormat) -> Result<()> {
-        if let Some(batch) = self.batches.get_mut(&batch) {
-            if let Some(tags) = batch.tags.get_mut(content) {
-                tags.pop();
-                if tags.is_empty() {
-                    batch.tags.remove(content);
-                }
-                return Ok(());
-            }
-        }
-        // this can happen if we try to upgrade a tag from an expired batch
-        anyhow::bail!("tag not found in batch");
-    }
-
-    /// Remove an entire batch.
-    pub(crate) fn remove(&mut self, batch: BatchId) {
-        self.batches.remove(&batch);
-    }
-}
-
-impl<S: iroh_blobs::store::Store> BlobsProtocol<S> {
-    pub(crate) fn new_with_events(
-        store: S,
-        rt: LocalPoolHandle,
-        events: EventSender,
-        downloader: Downloader,
-    ) -> Self {
-        Self {
-            rt,
-            store,
-            events,
-            downloader,
-            batches: Default::default(),
-        }
-    }
-
-    pub(crate) fn store(&self) -> &S {
-        &self.store
-    }
-
-    pub(crate) async fn batches(&self) -> tokio::sync::MutexGuard<'_, BlobBatches> {
-        self.batches.lock().await
-    }
-
-    pub(crate) async fn download(
-        &self,
-        endpoint: Endpoint,
-        req: BlobDownloadRequest,
-        progress: AsyncChannelProgressSender<DownloadProgress>,
-    ) -> Result<()> {
-        let BlobDownloadRequest {
-            hash,
-            format,
-            nodes,
-            tag,
-            mode,
-        } = req;
-        let hash_and_format = HashAndFormat { hash, format };
-        let temp_tag = self.store.temp_tag(hash_and_format);
-        let stats = match mode {
-            DownloadMode::Queued => {
-                self.download_queued(endpoint, hash_and_format, nodes, progress.clone())
-                    .await?
-            }
-            DownloadMode::Direct => {
-                self.download_direct_from_nodes(endpoint, hash_and_format, nodes, progress.clone())
-                    .await?
-            }
-        };
-
-        progress.send(DownloadProgress::AllDone(stats)).await.ok();
-        match tag {
-            SetTagOption::Named(tag) => {
-                self.store.set_tag(tag, Some(hash_and_format)).await?;
-            }
-            SetTagOption::Auto => {
-                self.store.create_tag(hash_and_format).await?;
-            }
-        }
-        drop(temp_tag);
-
-        Ok(())
-    }
-
-    async fn download_queued(
-        &self,
-        endpoint: Endpoint,
-        hash_and_format: HashAndFormat,
-        nodes: Vec<NodeAddr>,
-        progress: AsyncChannelProgressSender<DownloadProgress>,
-    ) -> Result<Stats> {
-        let mut node_ids = Vec::with_capacity(nodes.len());
-        let mut any_added = false;
-        for node in nodes {
-            node_ids.push(node.node_id);
-            if !node.info.is_empty() {
-                endpoint.add_node_addr_with_source(node, BLOB_DOWNLOAD_SOURCE_NAME)?;
-                any_added = true;
-            }
-        }
-        let can_download = !node_ids.is_empty() && (any_added || endpoint.discovery().is_some());
-        anyhow::ensure!(can_download, "no way to reach a node for download");
-        let req = DownloadRequest::new(hash_and_format, node_ids).progress_sender(progress);
-        let handle = self.downloader.queue(req).await;
-        let stats = handle.await?;
-        Ok(stats)
-    }
-
-    #[tracing::instrument("download_direct", skip_all, fields(hash=%hash_and_format.hash.fmt_short()))]
-    async fn download_direct_from_nodes(
-        &self,
-        endpoint: Endpoint,
-        hash_and_format: HashAndFormat,
-        nodes: Vec<NodeAddr>,
-        progress: AsyncChannelProgressSender<DownloadProgress>,
-    ) -> Result<Stats> {
-        let mut last_err = None;
-        let mut remaining_nodes = nodes.len();
-        let mut nodes_iter = nodes.into_iter();
-        'outer: loop {
-            match iroh_blobs::get::db::get_to_db_in_steps(
-                self.store.clone(),
-                hash_and_format,
-                progress.clone(),
-            )
-            .await?
-            {
-                GetState::Complete(stats) => return Ok(stats),
-                GetState::NeedsConn(needs_conn) => {
-                    let (conn, node_id) = 'inner: loop {
-                        match nodes_iter.next() {
-                            None => break 'outer,
-                            Some(node) => {
-                                remaining_nodes -= 1;
-                                let node_id = node.node_id;
-                                if node_id == endpoint.node_id() {
-                                    debug!(
-                                        ?remaining_nodes,
-                                        "skip node {} (it is the node id of ourselves)",
-                                        node_id.fmt_short()
-                                    );
-                                    continue 'inner;
-                                }
-                                match endpoint.connect(node, iroh_blobs::protocol::ALPN).await {
-                                    Ok(conn) => break 'inner (conn, node_id),
-                                    Err(err) => {
-                                        debug!(
-                                            ?remaining_nodes,
-                                            "failed to connect to {}: {err}",
-                                            node_id.fmt_short()
-                                        );
-                                        continue 'inner;
-                                    }
-                                }
-                            }
-                        }
-                    };
-                    match needs_conn.proceed(conn).await {
-                        Ok(stats) => return Ok(stats),
-                        Err(err) => {
-                            warn!(
-                                ?remaining_nodes,
-                                "failed to download from {}: {err}",
-                                node_id.fmt_short()
-                            );
-                            last_err = Some(err);
-                        }
-                    }
-                }
-            }
-        }
-        match last_err {
-            Some(err) => Err(err.into()),
-            None => Err(anyhow!("No nodes to download from provided")),
-        }
-    }
-}
-
-impl<S: iroh_blobs::store::Store> ProtocolHandler for BlobsProtocol<S> {
-    fn accept(self: Arc<Self>, conn: Connecting) -> BoxedFuture<Result<()>> {
-        Box::pin(async move {
-            iroh_blobs::provider::handle_connection(
-                conn.await?,
-                self.store.clone(),
-                self.events.clone(),
-                self.rt.clone(),
-            )
-            .await;
-            Ok(())
-        })
-    }
-
-    fn shutdown(self: Arc<Self>) -> BoxedFuture<()> {
-        Box::pin(async move {
-            self.store.shutdown().await;
-        })
-    }
-}
-
-impl ProtocolHandler for iroh_gossip::net::Gossip {
-    fn accept(self: Arc<Self>, conn: Connecting) -> BoxedFuture<Result<()>> {
-        Box::pin(async move { self.handle_connection(conn.await?).await })
-    }
-}
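For reference, the gossip handler removed above is the smallest of the three moved pieces. A plausible shape for the new iroh/src/node/protocol/gossip.rs, assuming it simply re-homes the deleted impl and imports ProtocolHandler from its parent module, is:

use std::sync::Arc;

use anyhow::Result;
use futures_lite::future::Boxed as BoxedFuture;
use iroh_net::endpoint::Connecting;

use super::ProtocolHandler;

impl ProtocolHandler for iroh_gossip::net::Gossip {
    /// Accept an incoming connection and hand it to the gossip engine.
    fn accept(self: Arc<Self>, conn: Connecting) -> BoxedFuture<Result<()>> {
        Box::pin(async move { self.handle_connection(conn.await?).await })
    }
}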
