Skip to content

feat(mdns): only send listening addresses that match interface #6003

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 87 additions & 56 deletions protocols/mdns/src/behaviour.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@
convert::Infallible,
fmt,
future::Future,
io,
io, mem,
net::IpAddr,
pin::Pin,
sync::{Arc, RwLock},
task::{Context, Poll},
task::{Context, Poll, Waker},
time::Instant,
};

use futures::{channel::mpsc, Stream, StreamExt};
use if_watch::IfEvent;
use libp2p_core::{transport::PortUse, Endpoint, Multiaddr};
use iface::ListenAddressUpdate;
use libp2p_core::{multiaddr::Protocol, transport::PortUse, Endpoint, Multiaddr};
use libp2p_identity::PeerId;
use libp2p_swarm::{
behaviour::FromSwarm, dummy, ConnectionDenied, ConnectionId, ListenAddresses, NetworkBehaviour,
Expand All @@ -64,30 +64,22 @@
/// The IfWatcher type.
type Watcher: Stream<Item = std::io::Result<IfEvent>> + fmt::Debug + Unpin;

type TaskHandle: Abort;

/// Create a new instance of the `IfWatcher` type.
fn new_watcher() -> Result<Self::Watcher, std::io::Error>;

#[track_caller]
fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle;
}

#[allow(unreachable_pub)] // Not re-exported.
pub trait Abort {
fn abort(self);
fn spawn(task: impl Future<Output = ()> + Send + 'static);
}

/// The type of a [`Behaviour`] using the `async-io` implementation.
#[cfg(feature = "async-io")]
pub mod async_io {
use std::future::Future;

use async_std::task::JoinHandle;
use if_watch::smol::IfWatcher;

use super::Provider;
use crate::behaviour::{socket::asio::AsyncUdpSocket, timer::asio::AsyncTimer, Abort};
use crate::behaviour::{socket::asio::AsyncUdpSocket, timer::asio::AsyncTimer};

#[doc(hidden)]
pub enum AsyncIo {}
Expand All @@ -96,20 +88,13 @@
type Socket = AsyncUdpSocket;
type Timer = AsyncTimer;
type Watcher = IfWatcher;
type TaskHandle = JoinHandle<()>;

fn new_watcher() -> Result<Self::Watcher, std::io::Error> {
IfWatcher::new()
}

fn spawn(task: impl Future<Output = ()> + Send + 'static) -> JoinHandle<()> {
async_std::task::spawn(task)
}
}

impl Abort for JoinHandle<()> {
fn abort(self) {
async_std::task::spawn(self.cancel());
fn spawn(task: impl Future<Output = ()> + Send + 'static) {
async_std::task::spawn(task);
}
}

Expand All @@ -122,10 +107,9 @@
use std::future::Future;

use if_watch::tokio::IfWatcher;
use tokio::task::JoinHandle;

use super::Provider;
use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer, Abort};
use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer};

#[doc(hidden)]
pub enum Tokio {}
Expand All @@ -134,20 +118,13 @@
type Socket = TokioUdpSocket;
type Timer = TokioTimer;
type Watcher = IfWatcher;
type TaskHandle = JoinHandle<()>;

fn new_watcher() -> Result<Self::Watcher, std::io::Error> {
IfWatcher::new()
}

fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle {
tokio::spawn(task)
}
}

impl Abort for JoinHandle<()> {
fn abort(self) {
JoinHandle::abort(&self)
fn spawn(task: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(task);
}
}

Expand All @@ -167,8 +144,8 @@
/// Iface watcher.
if_watch: P::Watcher,

/// Handles to tasks running the mDNS queries.
if_tasks: HashMap<IpAddr, P::TaskHandle>,
/// Channel for sending address updates to interface tasks.
if_tasks: HashMap<IpAddr, mpsc::Sender<ListenAddressUpdate>>,

query_response_receiver: mpsc::Receiver<(PeerId, Multiaddr, Instant)>,
query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>,
Expand All @@ -185,16 +162,17 @@
closest_expiration: Option<P::Timer>,

/// The current set of listen addresses.
///
/// This is shared across all interface tasks using an [`RwLock`].
/// The [`Behaviour`] updates this upon new [`FromSwarm`]
/// events where as [`InterfaceState`]s read from it to answer inbound mDNS queries.
listen_addresses: Arc<RwLock<ListenAddresses>>,
listen_addresses: ListenAddresses,

local_peer_id: PeerId,

/// Pending behaviour events to be emitted.
pending_events: VecDeque<ToSwarm<Event, Infallible>>,

/// Pending address updates to send to interfaces.
pending_address_updates: Vec<ListenAddressUpdate>,

waker: Waker,
}

impl<P> Behaviour<P>
Expand All @@ -216,6 +194,8 @@
listen_addresses: Default::default(),
local_peer_id,
pending_events: Default::default(),
pending_address_updates: Default::default(),
waker: Waker::noop().clone(),

Check failure on line 198 in protocols/mdns/src/behaviour.rs

View workflow job for this annotation

GitHub Actions / Compile with MSRV

use of unstable library feature 'noop_waker'

Check failure on line 198 in protocols/mdns/src/behaviour.rs

View workflow job for this annotation

GitHub Actions / clippy (1.83.0)

use of unstable library feature 'noop_waker'

Check failure on line 198 in protocols/mdns/src/behaviour.rs

View workflow job for this annotation

GitHub Actions / clippy (beta)

current MSRV (Minimum Supported Rust Version) is `1.83.0` but this item is stable since `1.85.0`
})
}

Expand All @@ -241,6 +221,30 @@
}
self.closest_expiration = Some(P::Timer::at(now));
}

/// Try to send an address update to the interface task that matches the address' IP.
///
/// Returns the address if the sending failed due to a full channel.
fn try_send_address_update(
&mut self,
cx: &mut Context<'_>,
update: ListenAddressUpdate,
) -> Option<ListenAddressUpdate> {
let ip = update.ip_addr()?;
let tx = self.if_tasks.get_mut(&ip)?;
match tx.poll_ready(cx) {
Poll::Ready(Ok(())) => {
tx.start_send(update).expect("Channel is ready.");
None
}
Poll::Ready(Err(e)) if e.is_disconnected() => {
tracing::error!("`InterfaceState` for ip {ip} dropped");
self.if_tasks.remove(&ip);
None
}
_ => Some(update),
}
}
}

impl<P> NetworkBehaviour for Behaviour<P>
Expand Down Expand Up @@ -301,10 +305,14 @@
}

fn on_swarm_event(&mut self, event: FromSwarm) {
self.listen_addresses
.write()
.unwrap_or_else(|e| e.into_inner())
.on_swarm_event(&event);
if !self.listen_addresses.on_swarm_event(&event) {
return;
}
if let Some(update) = ListenAddressUpdate::from_swarm(event).and_then(|update| {
self.try_send_address_update(&mut Context::from_waker(&self.waker.clone()), update)
}) {
self.pending_address_updates.push(update);
}
}

#[tracing::instrument(level = "trace", name = "NetworkBehaviour::poll", skip(self, cx))]
Expand All @@ -313,6 +321,13 @@
cx: &mut Context<'_>,
) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
loop {
// Send address updates to interface tasks.
for update in mem::take(&mut self.pending_address_updates) {
if let Some(update) = self.try_send_address_update(cx, update) {
self.pending_address_updates.push(update);
}
}

// Check for pending events and emit them.
if let Some(event) = self.pending_events.pop_front() {
return Poll::Ready(event);
Expand All @@ -322,25 +337,34 @@
while let Poll::Ready(Some(event)) = Pin::new(&mut self.if_watch).poll_next(cx) {
match event {
Ok(IfEvent::Up(inet)) => {
let addr = inet.addr();
if addr.is_loopback() {
let ip_addr = inet.addr();
if ip_addr.is_loopback() {
continue;
}
if addr.is_ipv4() && self.config.enable_ipv6
|| addr.is_ipv6() && !self.config.enable_ipv6
if ip_addr.is_ipv4() && self.config.enable_ipv6
|| ip_addr.is_ipv6() && !self.config.enable_ipv6
{
continue;
}
if let Entry::Vacant(e) = self.if_tasks.entry(addr) {
if let Entry::Vacant(e) = self.if_tasks.entry(ip_addr) {
let (addr_tx, addr_rx) = mpsc::channel(10); // Chosen arbitrarily.
let listen_addresses = self
.listen_addresses
.iter()
.filter(|multiaddr| multiaddr_matches_ip(multiaddr, &ip_addr))
.cloned()
.collect();
match InterfaceState::<P::Socket, P::Timer>::new(
addr,
ip_addr,
self.config.clone(),
self.local_peer_id,
self.listen_addresses.clone(),
listen_addresses,
addr_rx,
self.query_response_sender.clone(),
) {
Ok(iface_state) => {
e.insert(P::spawn(iface_state));
P::spawn(iface_state);
e.insert(addr_tx);
}
Err(err) => {
tracing::error!("failed to create `InterfaceState`: {}", err)
Expand All @@ -349,10 +373,8 @@
}
}
Ok(IfEvent::Down(inet)) => {
if let Some(handle) = self.if_tasks.remove(&inet.addr()) {
if self.if_tasks.remove(&inet.addr()).is_some() {
tracing::info!(instance=%inet.addr(), "dropping instance");

handle.abort();
}
}
Err(err) => tracing::error!("if watch returned an error: {}", err),
Expand Down Expand Up @@ -417,11 +439,20 @@
self.closest_expiration = Some(timer);
}

self.waker = cx.waker().clone();
return Poll::Pending;
}
}
}

fn multiaddr_matches_ip(addr: &Multiaddr, ip: &IpAddr) -> bool {
match addr.iter().next() {
Some(Protocol::Ip4(ipv4)) => &IpAddr::V4(ipv4) == ip,
Some(Protocol::Ip6(ipv6)) => &IpAddr::V6(ipv6) == ip,
_ => false,
}
}

/// Event that can be produced by the `Mdns` behaviour.
#[derive(Debug, Clone)]
pub enum Event {
Expand Down
Loading
Loading