Skip to content

Commit c15e651

Browse files
authored
refactor(tcp): use SelectAll for driving listener streams (#3361)
The PR optimizes polling of the listeners in the TCP transport by using `futures::SelectAll` instead of storing them in a queue and polling manually. Resolves #2781.
1 parent 47c1d5a commit c15e651

File tree

1 file changed

+118
-140
lines changed

1 file changed

+118
-140
lines changed

transports/tcp/src/lib.rs

Lines changed: 118 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ pub use provider::tokio;
3939
use futures::{
4040
future::{self, Ready},
4141
prelude::*,
42+
stream::SelectAll,
4243
};
4344
use futures_timer::Delay;
4445
use if_watch::IfEvent;
@@ -55,7 +56,7 @@ use std::{
5556
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener},
5657
pin::Pin,
5758
sync::{Arc, RwLock},
58-
task::{Context, Poll},
59+
task::{Context, Poll, Waker},
5960
time::Duration,
6061
};
6162

@@ -312,7 +313,7 @@ where
312313
/// All the active listeners.
313314
/// The [`ListenStream`] struct contains a stream that we want to be pinned. Since the `VecDeque`
314315
/// can be resized, the only way is to use a `Pin<Box<>>`.
315-
listeners: VecDeque<Pin<Box<ListenStream<T>>>>,
316+
listeners: SelectAll<ListenStream<T>>,
316317
/// Pending transport events to return from [`libp2p_core::Transport::poll`].
317318
pending_events:
318319
VecDeque<TransportEvent<<Self as libp2p_core::Transport>::ListenerUpgrade, io::Error>>,
@@ -419,7 +420,7 @@ where
419420
Transport {
420421
port_reuse,
421422
config,
422-
listeners: VecDeque::new(),
423+
listeners: SelectAll::new(),
423424
pending_events: VecDeque::new(),
424425
}
425426
}
@@ -447,18 +448,13 @@ where
447448
let listener = self
448449
.do_listen(id, socket_addr)
449450
.map_err(TransportError::Other)?;
450-
self.listeners.push_back(Box::pin(listener));
451+
self.listeners.push(listener);
451452
Ok(id)
452453
}
453454

454455
fn remove_listener(&mut self, id: ListenerId) -> bool {
455-
if let Some(index) = self.listeners.iter().position(|l| l.listener_id == id) {
456-
self.listeners.remove(index);
457-
self.pending_events
458-
.push_back(TransportEvent::ListenerClosed {
459-
listener_id: id,
460-
reason: Ok(()),
461-
});
456+
if let Some(listener) = self.listeners.iter_mut().find(|l| l.listener_id == id) {
457+
listener.close(Ok(()));
462458
true
463459
} else {
464460
false
@@ -548,96 +544,14 @@ where
548544
if let Some(event) = self.pending_events.pop_front() {
549545
return Poll::Ready(event);
550546
}
551-
// We remove each element from `listeners` one by one and add them back.
552-
let mut remaining = self.listeners.len();
553-
while let Some(mut listener) = self.listeners.pop_back() {
554-
match TryStream::try_poll_next(listener.as_mut(), cx) {
555-
Poll::Pending => {
556-
self.listeners.push_front(listener);
557-
remaining -= 1;
558-
if remaining == 0 {
559-
break;
560-
}
561-
}
562-
Poll::Ready(Some(Ok(TcpListenerEvent::Upgrade {
563-
upgrade,
564-
local_addr,
565-
remote_addr,
566-
}))) => {
567-
let id = listener.listener_id;
568-
self.listeners.push_front(listener);
569-
return Poll::Ready(TransportEvent::Incoming {
570-
listener_id: id,
571-
upgrade,
572-
local_addr,
573-
send_back_addr: remote_addr,
574-
});
575-
}
576-
Poll::Ready(Some(Ok(TcpListenerEvent::NewAddress(a)))) => {
577-
let id = listener.listener_id;
578-
self.listeners.push_front(listener);
579-
return Poll::Ready(TransportEvent::NewAddress {
580-
listener_id: id,
581-
listen_addr: a,
582-
});
583-
}
584-
Poll::Ready(Some(Ok(TcpListenerEvent::AddressExpired(a)))) => {
585-
let id = listener.listener_id;
586-
self.listeners.push_front(listener);
587-
return Poll::Ready(TransportEvent::AddressExpired {
588-
listener_id: id,
589-
listen_addr: a,
590-
});
591-
}
592-
Poll::Ready(Some(Ok(TcpListenerEvent::Error(error)))) => {
593-
let id = listener.listener_id;
594-
self.listeners.push_front(listener);
595-
return Poll::Ready(TransportEvent::ListenerError {
596-
listener_id: id,
597-
error,
598-
});
599-
}
600-
Poll::Ready(None) => {
601-
return Poll::Ready(TransportEvent::ListenerClosed {
602-
listener_id: listener.listener_id,
603-
reason: Ok(()),
604-
});
605-
}
606-
Poll::Ready(Some(Err(err))) => {
607-
return Poll::Ready(TransportEvent::ListenerClosed {
608-
listener_id: listener.listener_id,
609-
reason: Err(err),
610-
});
611-
}
612-
}
547+
548+
match self.listeners.poll_next_unpin(cx) {
549+
Poll::Ready(Some(transport_event)) => Poll::Ready(transport_event),
550+
_ => Poll::Pending,
613551
}
614-
Poll::Pending
615552
}
616553
}
617554

618-
/// Event produced by a [`ListenStream`].
619-
#[derive(Debug)]
620-
enum TcpListenerEvent<S> {
621-
/// The listener is listening on a new additional [`Multiaddr`].
622-
NewAddress(Multiaddr),
623-
/// An upgrade, consisting of the upgrade future, the listener address and the remote address.
624-
Upgrade {
625-
/// The upgrade.
626-
upgrade: Ready<Result<S, io::Error>>,
627-
/// The local address which produced this upgrade.
628-
local_addr: Multiaddr,
629-
/// The remote address which produced this upgrade.
630-
remote_addr: Multiaddr,
631-
},
632-
/// A [`Multiaddr`] is no longer used for listening.
633-
AddressExpired(Multiaddr),
634-
/// A non-fatal error has happened on the listener.
635-
///
636-
/// This event should be generated in order to notify the user that something wrong has
637-
/// happened. The listener, however, continues to run.
638-
Error(io::Error),
639-
}
640-
641555
/// A stream of incoming connections on one or more interfaces.
642556
struct ListenStream<T>
643557
where
@@ -669,6 +583,12 @@ where
669583
sleep_on_error: Duration,
670584
/// The current pause, if any.
671585
pause: Option<Delay>,
586+
/// Pending event to reported.
587+
pending_event: Option<<Self as Stream>::Item>,
588+
/// The listener can be manually closed with [`Transport::remove_listener`](libp2p_core::Transport::remove_listener).
589+
is_closed: bool,
590+
/// The stream must be awaken after it has been closed to deliver the last event.
591+
close_listener_waker: Option<Waker>,
672592
}
673593

674594
impl<T> ListenStream<T>
@@ -694,6 +614,9 @@ where
694614
if_watcher,
695615
pause: None,
696616
sleep_on_error: Duration::from_millis(100),
617+
pending_event: None,
618+
is_closed: false,
619+
close_listener_waker: None,
697620
})
698621
}
699622

@@ -716,6 +639,74 @@ where
716639
.unregister(self.listen_addr.ip(), self.listen_addr.port()),
717640
}
718641
}
642+
643+
/// Close the listener.
644+
///
645+
/// This will create a [`TransportEvent::ListenerClosed`] and
646+
/// terminate the stream once the event has been reported.
647+
fn close(&mut self, reason: Result<(), io::Error>) {
648+
if self.is_closed {
649+
return;
650+
}
651+
self.pending_event = Some(TransportEvent::ListenerClosed {
652+
listener_id: self.listener_id,
653+
reason,
654+
});
655+
self.is_closed = true;
656+
657+
// Wake the stream to deliver the last event.
658+
if let Some(waker) = self.close_listener_waker.take() {
659+
waker.wake();
660+
}
661+
}
662+
663+
/// Poll for a next If Event.
664+
fn poll_if_addr(&mut self, cx: &mut Context<'_>) -> Poll<<Self as Stream>::Item> {
665+
let if_watcher = match self.if_watcher.as_mut() {
666+
Some(if_watcher) => if_watcher,
667+
None => return Poll::Pending,
668+
};
669+
670+
let my_listen_addr_port = self.listen_addr.port();
671+
672+
while let Poll::Ready(Some(event)) = if_watcher.poll_next_unpin(cx) {
673+
match event {
674+
Ok(IfEvent::Up(inet)) => {
675+
let ip = inet.addr();
676+
if self.listen_addr.is_ipv4() == ip.is_ipv4() {
677+
let ma = ip_to_multiaddr(ip, my_listen_addr_port);
678+
log::debug!("New listen address: {}", ma);
679+
self.port_reuse.register(ip, my_listen_addr_port);
680+
return Poll::Ready(TransportEvent::NewAddress {
681+
listener_id: self.listener_id,
682+
listen_addr: ma,
683+
});
684+
}
685+
}
686+
Ok(IfEvent::Down(inet)) => {
687+
let ip = inet.addr();
688+
if self.listen_addr.is_ipv4() == ip.is_ipv4() {
689+
let ma = ip_to_multiaddr(ip, my_listen_addr_port);
690+
log::debug!("Expired listen address: {}", ma);
691+
self.port_reuse.unregister(ip, my_listen_addr_port);
692+
return Poll::Ready(TransportEvent::AddressExpired {
693+
listener_id: self.listener_id,
694+
listen_addr: ma,
695+
});
696+
}
697+
}
698+
Err(error) => {
699+
self.pause = Some(Delay::new(self.sleep_on_error));
700+
return Poll::Ready(TransportEvent::ListenerError {
701+
listener_id: self.listener_id,
702+
error,
703+
});
704+
}
705+
}
706+
}
707+
708+
Poll::Pending
709+
}
719710
}
720711

721712
impl<T> Drop for ListenStream<T>
@@ -733,52 +724,34 @@ where
733724
T::Listener: Unpin,
734725
T::Stream: Unpin,
735726
{
736-
type Item = Result<TcpListenerEvent<T::Stream>, io::Error>;
737-
738-
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
739-
let me = Pin::into_inner(self);
727+
type Item = TransportEvent<Ready<Result<T::Stream, io::Error>>, io::Error>;
740728

741-
if let Some(mut pause) = me.pause.take() {
729+
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
730+
if let Some(mut pause) = self.pause.take() {
742731
match pause.poll_unpin(cx) {
743732
Poll::Ready(_) => {}
744733
Poll::Pending => {
745-
me.pause = Some(pause);
734+
self.pause = Some(pause);
746735
return Poll::Pending;
747736
}
748737
}
749738
}
750739

751-
if let Some(if_watcher) = me.if_watcher.as_mut() {
752-
while let Poll::Ready(Some(event)) = if_watcher.poll_next_unpin(cx) {
753-
match event {
754-
Ok(IfEvent::Up(inet)) => {
755-
let ip = inet.addr();
756-
if me.listen_addr.is_ipv4() == ip.is_ipv4() {
757-
let ma = ip_to_multiaddr(ip, me.listen_addr.port());
758-
log::debug!("New listen address: {}", ma);
759-
me.port_reuse.register(ip, me.listen_addr.port());
760-
return Poll::Ready(Some(Ok(TcpListenerEvent::NewAddress(ma))));
761-
}
762-
}
763-
Ok(IfEvent::Down(inet)) => {
764-
let ip = inet.addr();
765-
if me.listen_addr.is_ipv4() == ip.is_ipv4() {
766-
let ma = ip_to_multiaddr(ip, me.listen_addr.port());
767-
log::debug!("Expired listen address: {}", ma);
768-
me.port_reuse.unregister(ip, me.listen_addr.port());
769-
return Poll::Ready(Some(Ok(TcpListenerEvent::AddressExpired(ma))));
770-
}
771-
}
772-
Err(err) => {
773-
me.pause = Some(Delay::new(me.sleep_on_error));
774-
return Poll::Ready(Some(Ok(TcpListenerEvent::Error(err))));
775-
}
776-
}
777-
}
740+
if let Some(event) = self.pending_event.take() {
741+
return Poll::Ready(Some(event));
742+
}
743+
744+
if self.is_closed {
745+
// Terminate the stream if the listener closed and all remaining events have been reported.
746+
return Poll::Ready(None);
747+
}
748+
749+
if let Poll::Ready(event) = self.poll_if_addr(cx) {
750+
return Poll::Ready(Some(event));
778751
}
779752

780753
// Take the pending connection from the backlog.
781-
match T::poll_accept(&mut me.listener, cx) {
754+
match T::poll_accept(&mut self.listener, cx) {
782755
Poll::Ready(Ok(Incoming {
783756
local_addr,
784757
remote_addr,
@@ -789,20 +762,25 @@ where
789762

790763
log::debug!("Incoming connection from {} at {}", remote_addr, local_addr);
791764

792-
return Poll::Ready(Some(Ok(TcpListenerEvent::Upgrade {
765+
return Poll::Ready(Some(TransportEvent::Incoming {
766+
listener_id: self.listener_id,
793767
upgrade: future::ok(stream),
794768
local_addr,
795-
remote_addr,
796-
})));
769+
send_back_addr: remote_addr,
770+
}));
797771
}
798-
Poll::Ready(Err(e)) => {
772+
Poll::Ready(Err(error)) => {
799773
// These errors are non-fatal for the listener stream.
800-
me.pause = Some(Delay::new(me.sleep_on_error));
801-
return Poll::Ready(Some(Ok(TcpListenerEvent::Error(e))));
774+
self.pause = Some(Delay::new(self.sleep_on_error));
775+
return Poll::Ready(Some(TransportEvent::ListenerError {
776+
listener_id: self.listener_id,
777+
error,
778+
}));
802779
}
803780
Poll::Pending => {}
804-
};
781+
}
805782

783+
self.close_listener_waker = Some(cx.waker().clone());
806784
Poll::Pending
807785
}
808786
}
@@ -1119,7 +1097,7 @@ mod tests {
11191097
match poll_fn(|cx| Pin::new(&mut tcp).poll(cx)).await {
11201098
TransportEvent::NewAddress { .. } => {
11211099
// Check that tcp and listener share the same port reuse SocketAddr
1122-
let listener = tcp.listeners.front().unwrap();
1100+
let listener = tcp.listeners.iter().next().unwrap();
11231101
let port_reuse_tcp = tcp.port_reuse.local_dial_addr(&listener.listen_addr.ip());
11241102
let port_reuse_listener = listener
11251103
.port_reuse
@@ -1188,7 +1166,7 @@ mod tests {
11881166
TransportEvent::NewAddress {
11891167
listen_addr: addr1, ..
11901168
} => {
1191-
let listener1 = tcp.listeners.front().unwrap();
1169+
let listener1 = tcp.listeners.iter().next().unwrap();
11921170
let port_reuse_tcp =
11931171
tcp.port_reuse.local_dial_addr(&listener1.listen_addr.ip());
11941172
let port_reuse_listener1 = listener1

0 commit comments

Comments
 (0)