1
1
#![ cfg( any( feature = "async-std" , feature = "tokio" ) ) ]
2
2
3
3
use futures:: channel:: { mpsc, oneshot} ;
4
+ use futures:: future:: BoxFuture ;
4
5
use futures:: future:: { poll_fn, Either } ;
5
6
use futures:: stream:: StreamExt ;
6
7
use futures:: { future, AsyncReadExt , AsyncWriteExt , FutureExt , SinkExt } ;
8
+ use futures_timer:: Delay ;
7
9
use libp2p_core:: muxing:: { StreamMuxerBox , StreamMuxerExt , SubstreamBox } ;
8
10
use libp2p_core:: transport:: { Boxed , OrTransport , TransportEvent } ;
11
+ use libp2p_core:: transport:: { ListenerId , TransportError } ;
9
12
use libp2p_core:: { multiaddr:: Protocol , upgrade, Multiaddr , PeerId , Transport } ;
10
13
use libp2p_noise as noise;
11
14
use libp2p_quic as quic;
@@ -18,6 +21,10 @@ use std::io;
18
21
use std:: num:: NonZeroU8 ;
19
22
use std:: task:: Poll ;
20
23
use std:: time:: Duration ;
24
+ use std:: {
25
+ pin:: Pin ,
26
+ sync:: { Arc , Mutex } ,
27
+ } ;
21
28
22
29
#[ cfg( feature = "tokio" ) ]
23
30
#[ tokio:: test]
@@ -89,6 +96,113 @@ async fn ipv4_dial_ipv6() {
89
96
assert_eq ! ( b_connected, a_peer_id) ;
90
97
}
91
98
99
+ /// Tests that a [`Transport::dial`] wakes up the task previously polling [`Transport::poll`].
100
+ ///
101
+ /// See https://github.com/libp2p/rust-libp2p/pull/3306 for context.
102
+ #[ cfg( feature = "async-std" ) ]
103
+ #[ async_std:: test]
104
+ async fn wrapped_with_delay ( ) {
105
+ let _ = env_logger:: try_init ( ) ;
106
+
107
+ struct DialDelay ( Arc < Mutex < Boxed < ( PeerId , StreamMuxerBox ) > > > ) ;
108
+
109
+ impl Transport for DialDelay {
110
+ type Output = ( PeerId , StreamMuxerBox ) ;
111
+ type Error = std:: io:: Error ;
112
+ type ListenerUpgrade = Pin < Box < dyn Future < Output = io:: Result < Self :: Output > > + Send > > ;
113
+ type Dial = BoxFuture < ' static , Result < Self :: Output , Self :: Error > > ;
114
+
115
+ fn listen_on (
116
+ & mut self ,
117
+ addr : Multiaddr ,
118
+ ) -> Result < ListenerId , TransportError < Self :: Error > > {
119
+ self . 0 . lock ( ) . unwrap ( ) . listen_on ( addr)
120
+ }
121
+
122
+ fn remove_listener ( & mut self , id : ListenerId ) -> bool {
123
+ self . 0 . lock ( ) . unwrap ( ) . remove_listener ( id)
124
+ }
125
+
126
+ fn address_translation (
127
+ & self ,
128
+ listen : & Multiaddr ,
129
+ observed : & Multiaddr ,
130
+ ) -> Option < Multiaddr > {
131
+ self . 0 . lock ( ) . unwrap ( ) . address_translation ( listen, observed)
132
+ }
133
+
134
+ /// Delayed dial, i.e. calling [`Transport::dial`] on the inner [`Transport`] not within the
135
+ /// synchronous [`Transport::dial`] method, but within the [`Future`] returned by the outer
136
+ /// [`Transport::dial`].
137
+ fn dial ( & mut self , addr : Multiaddr ) -> Result < Self :: Dial , TransportError < Self :: Error > > {
138
+ let t = self . 0 . clone ( ) ;
139
+ Ok ( async move {
140
+ // Simulate DNS lookup. Giving the `Transport::poll` the chance to return
141
+ // `Poll::Pending` and thus suspending its task, waiting for a wakeup from the dial
142
+ // on the inner transport below.
143
+ Delay :: new ( Duration :: from_millis ( 100 ) ) . await ;
144
+
145
+ let dial = t. lock ( ) . unwrap ( ) . dial ( addr) . map_err ( |e| match e {
146
+ TransportError :: MultiaddrNotSupported ( _) => {
147
+ panic ! ( )
148
+ }
149
+ TransportError :: Other ( e) => e,
150
+ } ) ?;
151
+ dial. await
152
+ }
153
+ . boxed ( ) )
154
+ }
155
+
156
+ fn dial_as_listener (
157
+ & mut self ,
158
+ addr : Multiaddr ,
159
+ ) -> Result < Self :: Dial , TransportError < Self :: Error > > {
160
+ self . 0 . lock ( ) . unwrap ( ) . dial_as_listener ( addr)
161
+ }
162
+
163
+ fn poll (
164
+ self : Pin < & mut Self > ,
165
+ cx : & mut std:: task:: Context < ' _ > ,
166
+ ) -> Poll < TransportEvent < Self :: ListenerUpgrade , Self :: Error > > {
167
+ Pin :: new ( & mut * self . 0 . lock ( ) . unwrap ( ) ) . poll ( cx)
168
+ }
169
+ }
170
+
171
+ let ( a_peer_id, mut a_transport) = create_default_transport :: < quic:: async_std:: Provider > ( ) ;
172
+ let ( b_peer_id, mut b_transport) = {
173
+ let ( id, transport) = create_default_transport :: < quic:: async_std:: Provider > ( ) ;
174
+ ( id, DialDelay ( Arc :: new ( Mutex :: new ( transport) ) ) . boxed ( ) )
175
+ } ;
176
+
177
+ // Spawn A
178
+ let a_addr = start_listening ( & mut a_transport, "/ip6/::1/udp/0/quic-v1" ) . await ;
179
+ let listener = async_std:: task:: spawn ( async move {
180
+ let ( upgrade, _) = a_transport
181
+ . select_next_some ( )
182
+ . await
183
+ . into_incoming ( )
184
+ . unwrap ( ) ;
185
+ let ( peer_id, _) = upgrade. await . unwrap ( ) ;
186
+
187
+ peer_id
188
+ } ) ;
189
+
190
+ // Spawn B
191
+ //
192
+ // Note that the dial is spawned on a different task than the transport allowing the transport
193
+ // task to poll the transport once and then suspend, waiting for the wakeup from the dial.
194
+ let dial = async_std:: task:: spawn ( {
195
+ let dial = b_transport. dial ( a_addr) . unwrap ( ) ;
196
+ async { dial. await . unwrap ( ) . 0 }
197
+ } ) ;
198
+ async_std:: task:: spawn ( async move { b_transport. next ( ) . await } ) ;
199
+
200
+ let ( a_connected, b_connected) = future:: join ( listener, dial) . await ;
201
+
202
+ assert_eq ! ( a_connected, b_peer_id) ;
203
+ assert_eq ! ( b_connected, a_peer_id) ;
204
+ }
205
+
92
206
#[ cfg( feature = "async-std" ) ]
93
207
#[ async_std:: test]
94
208
#[ ignore] // Transport currently does not validate PeerId. Enable once we make use of PeerId validation in rustls.
0 commit comments