Skip to content

Commit dcfa7ec

Browse files
authored
feat(quic): Wake transport when adding a new dialer or listener (#3342)
Wake `quic::GenTransport` if a new dialer or listener is added.
1 parent 778f7a2 commit dcfa7ec

File tree

3 files changed

+134
-3
lines changed

3 files changed

+134
-3
lines changed

transports/quic/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
- Add opt-in support for the `/quic` codepoint, interpreted as QUIC version draft-29.
88
See [PR 3151].
99

10+
- Wake the transport's task when a new dialer or listener is added. See [3342].
11+
1012
[PR 3151]: https://github.com/libp2p/rust-libp2p/pull/3151
13+
[PR 3342]: https://github.com/libp2p/rust-libp2p/pull/3342
1114

1215
# 0.7.0-alpha
1316

transports/quic/src/transport.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ pub struct GenTransport<P: Provider> {
7171
listeners: SelectAll<Listener<P>>,
7272
/// Dialer for each socket family if no matching listener exists.
7373
dialer: HashMap<SocketFamily, Dialer>,
74+
/// Waker to poll the transport again when a new dialer or listener is added.
75+
waker: Option<Waker>,
7476
}
7577

7678
impl<P: Provider> GenTransport<P> {
@@ -84,6 +86,7 @@ impl<P: Provider> GenTransport<P> {
8486
quinn_config,
8587
handshake_timeout,
8688
dialer: HashMap::new(),
89+
waker: None,
8790
support_draft_29,
8891
}
8992
}
@@ -108,6 +111,10 @@ impl<P: Provider> Transport for GenTransport<P> {
108111
)?;
109112
self.listeners.push(listener);
110113

114+
if let Some(waker) = self.waker.take() {
115+
waker.wake();
116+
}
117+
111118
// Remove dialer endpoint so that the endpoint is dropped once the last
112119
// connection that uses it is closed.
113120
// New outbound connections will use the bidirectional (listener) endpoint.
@@ -163,6 +170,9 @@ impl<P: Provider> Transport for GenTransport<P> {
163170
let dialer = match self.dialer.entry(socket_family) {
164171
Entry::Occupied(occupied) => occupied.into_mut(),
165172
Entry::Vacant(vacant) => {
173+
if let Some(waker) = self.waker.take() {
174+
waker.wake();
175+
}
166176
vacant.insert(Dialer::new::<P>(self.quinn_config.clone(), socket_family)?)
167177
}
168178
};
@@ -202,15 +212,19 @@ impl<P: Provider> Transport for GenTransport<P> {
202212
errored.push(*key);
203213
}
204214
}
215+
205216
for key in errored {
206217
// Endpoint driver of dialer crashed.
207218
// Drop dialer and all pending dials so that the connection receiver is notified.
208219
self.dialer.remove(&key);
209220
}
210-
match self.listeners.poll_next_unpin(cx) {
211-
Poll::Ready(Some(ev)) => Poll::Ready(ev),
212-
_ => Poll::Pending,
221+
222+
if let Poll::Ready(Some(ev)) = self.listeners.poll_next_unpin(cx) {
223+
return Poll::Ready(ev);
213224
}
225+
226+
self.waker = Some(cx.waker().clone());
227+
Poll::Pending
214228
}
215229
}
216230

transports/quic/tests/smoke.rs

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
#![cfg(any(feature = "async-std", feature = "tokio"))]
22

33
use futures::channel::{mpsc, oneshot};
4+
use futures::future::BoxFuture;
45
use futures::future::{poll_fn, Either};
56
use futures::stream::StreamExt;
67
use futures::{future, AsyncReadExt, AsyncWriteExt, FutureExt, SinkExt};
8+
use futures_timer::Delay;
79
use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerExt, SubstreamBox};
810
use libp2p_core::transport::{Boxed, OrTransport, TransportEvent};
11+
use libp2p_core::transport::{ListenerId, TransportError};
912
use libp2p_core::{multiaddr::Protocol, upgrade, Multiaddr, PeerId, Transport};
1013
use libp2p_noise as noise;
1114
use libp2p_quic as quic;
@@ -18,6 +21,10 @@ use std::io;
1821
use std::num::NonZeroU8;
1922
use std::task::Poll;
2023
use std::time::Duration;
24+
use std::{
25+
pin::Pin,
26+
sync::{Arc, Mutex},
27+
};
2128

2229
#[cfg(feature = "tokio")]
2330
#[tokio::test]
@@ -89,6 +96,113 @@ async fn ipv4_dial_ipv6() {
8996
assert_eq!(b_connected, a_peer_id);
9097
}
9198

99+
/// Tests that a [`Transport::dial`] wakes up the task previously polling [`Transport::poll`].
100+
///
101+
/// See https://github.com/libp2p/rust-libp2p/pull/3306 for context.
102+
#[cfg(feature = "async-std")]
103+
#[async_std::test]
104+
async fn wrapped_with_delay() {
105+
let _ = env_logger::try_init();
106+
107+
struct DialDelay(Arc<Mutex<Boxed<(PeerId, StreamMuxerBox)>>>);
108+
109+
impl Transport for DialDelay {
110+
type Output = (PeerId, StreamMuxerBox);
111+
type Error = std::io::Error;
112+
type ListenerUpgrade = Pin<Box<dyn Future<Output = io::Result<Self::Output>> + Send>>;
113+
type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;
114+
115+
fn listen_on(
116+
&mut self,
117+
addr: Multiaddr,
118+
) -> Result<ListenerId, TransportError<Self::Error>> {
119+
self.0.lock().unwrap().listen_on(addr)
120+
}
121+
122+
fn remove_listener(&mut self, id: ListenerId) -> bool {
123+
self.0.lock().unwrap().remove_listener(id)
124+
}
125+
126+
fn address_translation(
127+
&self,
128+
listen: &Multiaddr,
129+
observed: &Multiaddr,
130+
) -> Option<Multiaddr> {
131+
self.0.lock().unwrap().address_translation(listen, observed)
132+
}
133+
134+
/// Delayed dial, i.e. calling [`Transport::dial`] on the inner [`Transport`] not within the
135+
/// synchronous [`Transport::dial`] method, but within the [`Future`] returned by the outer
136+
/// [`Transport::dial`].
137+
fn dial(&mut self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
138+
let t = self.0.clone();
139+
Ok(async move {
140+
// Simulate DNS lookup. Giving the `Transport::poll` the chance to return
141+
// `Poll::Pending` and thus suspending its task, waiting for a wakeup from the dial
142+
// on the inner transport below.
143+
Delay::new(Duration::from_millis(100)).await;
144+
145+
let dial = t.lock().unwrap().dial(addr).map_err(|e| match e {
146+
TransportError::MultiaddrNotSupported(_) => {
147+
panic!()
148+
}
149+
TransportError::Other(e) => e,
150+
})?;
151+
dial.await
152+
}
153+
.boxed())
154+
}
155+
156+
fn dial_as_listener(
157+
&mut self,
158+
addr: Multiaddr,
159+
) -> Result<Self::Dial, TransportError<Self::Error>> {
160+
self.0.lock().unwrap().dial_as_listener(addr)
161+
}
162+
163+
fn poll(
164+
self: Pin<&mut Self>,
165+
cx: &mut std::task::Context<'_>,
166+
) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
167+
Pin::new(&mut *self.0.lock().unwrap()).poll(cx)
168+
}
169+
}
170+
171+
let (a_peer_id, mut a_transport) = create_default_transport::<quic::async_std::Provider>();
172+
let (b_peer_id, mut b_transport) = {
173+
let (id, transport) = create_default_transport::<quic::async_std::Provider>();
174+
(id, DialDelay(Arc::new(Mutex::new(transport))).boxed())
175+
};
176+
177+
// Spawn A
178+
let a_addr = start_listening(&mut a_transport, "/ip6/::1/udp/0/quic-v1").await;
179+
let listener = async_std::task::spawn(async move {
180+
let (upgrade, _) = a_transport
181+
.select_next_some()
182+
.await
183+
.into_incoming()
184+
.unwrap();
185+
let (peer_id, _) = upgrade.await.unwrap();
186+
187+
peer_id
188+
});
189+
190+
// Spawn B
191+
//
192+
// Note that the dial is spawned on a different task than the transport allowing the transport
193+
// task to poll the transport once and then suspend, waiting for the wakeup from the dial.
194+
let dial = async_std::task::spawn({
195+
let dial = b_transport.dial(a_addr).unwrap();
196+
async { dial.await.unwrap().0 }
197+
});
198+
async_std::task::spawn(async move { b_transport.next().await });
199+
200+
let (a_connected, b_connected) = future::join(listener, dial).await;
201+
202+
assert_eq!(a_connected, b_peer_id);
203+
assert_eq!(b_connected, a_peer_id);
204+
}
205+
92206
#[cfg(feature = "async-std")]
93207
#[async_std::test]
94208
#[ignore] // Transport currently does not validate PeerId. Enable once we make use of PeerId validation in rustls.

0 commit comments

Comments
 (0)