1
- use std:: { future, io, path:: PathBuf , sync:: atomic:: Ordering , time:: Duration } ;
1
+ use std:: {
2
+ collections:: HashMap , future, io, path:: PathBuf , sync:: atomic:: Ordering , time:: Duration ,
3
+ } ;
2
4
3
5
use futures:: { stream:: FuturesOrdered , StreamExt } ;
4
6
use hickory_resolver:: {
@@ -24,16 +26,12 @@ use tokio::{
24
26
mpsc:: { Receiver , Sender } ,
25
27
oneshot,
26
28
} ,
27
- task:: JoinSet ,
29
+ task:: { Id , JoinSet } ,
28
30
} ;
29
31
use tokio_util:: sync:: CancellationToken ;
30
32
use tracing:: { warn, Level } ;
31
33
32
- use crate :: {
33
- error:: { AgentError , AgentResult } ,
34
- metrics:: DNS_REQUEST_COUNT ,
35
- watched_task:: TaskStatus ,
36
- } ;
34
+ use crate :: { error:: AgentResult , metrics:: DNS_REQUEST_COUNT , util:: remote_runtime:: BgTaskStatus } ;
37
35
38
36
#[ derive( Debug ) ]
39
37
pub ( crate ) enum ClientGetAddrInfoRequest {
@@ -54,7 +52,7 @@ impl ClientGetAddrInfoRequest {
54
52
#[ derive( Debug ) ]
55
53
pub ( crate ) struct DnsCommand {
56
54
request : ClientGetAddrInfoRequest ,
57
- response_tx : oneshot:: Sender < Result < DnsLookup , InternalLookupError > > ,
55
+ response_tx : oneshot:: Sender < Result < DnsLookup , ResolveErrorKindInternal > > ,
58
56
}
59
57
60
58
/// Background task for resolving hostnames to IP addresses.
@@ -80,12 +78,11 @@ pub(crate) struct DnsWorker {
80
78
/// Background tasks that handle the DNS requests.
81
79
///
82
80
/// Each of these builds a new [`TokioAsyncResolver`] and performs one lookup.
83
- tasks : JoinSet < ( ) > ,
81
+ tasks : JoinSet < Result < DnsLookup , InternalLookupError > > ,
82
+ response_txs : HashMap < Id , oneshot:: Sender < Result < DnsLookup , ResolveErrorKindInternal > > > ,
84
83
}
85
84
86
85
impl DnsWorker {
87
- pub const TASK_NAME : & ' static str = "DNS worker" ;
88
-
89
86
/// Creates a new instance of this worker.
90
87
/// To run this worker, call [`Self::run`].
91
88
///
@@ -124,6 +121,7 @@ impl DnsWorker {
124
121
attempts,
125
122
support_ipv6,
126
123
tasks : Default :: default ( ) ,
124
+ response_txs : Default :: default ( ) ,
127
125
}
128
126
}
129
127
@@ -203,34 +201,51 @@ impl DnsWorker {
203
201
let attempts = self . attempts ;
204
202
let support_ipv6 = self . support_ipv6 ;
205
203
206
- let lookup_future = async move {
207
- let result = Self :: do_lookup (
208
- etc_path,
209
- message. request . into_v2 ( ) ,
210
- attempts,
211
- timeout,
212
- support_ipv6,
213
- )
214
- . await ;
215
-
216
- let _ = message. response_tx . send ( result) ;
217
- } ;
204
+ let handle = self . tasks . spawn ( Self :: do_lookup (
205
+ etc_path,
206
+ message. request . into_v2 ( ) ,
207
+ attempts,
208
+ timeout,
209
+ support_ipv6,
210
+ ) ) ;
211
+ self . response_txs . insert ( handle. id ( ) , message. response_tx ) ;
218
212
219
213
DNS_REQUEST_COUNT . fetch_add ( 1 , Ordering :: Relaxed ) ;
220
- self . tasks . spawn ( lookup_future) ;
221
214
}
222
215
223
- pub ( crate ) async fn run ( mut self , cancellation_token : CancellationToken ) -> AgentResult < ( ) > {
216
+ pub ( crate ) async fn run ( mut self , cancellation_token : CancellationToken ) {
224
217
loop {
225
218
tokio:: select! {
226
- _ = cancellation_token. cancelled( ) => break Ok ( ( ) ) ,
219
+ _ = cancellation_token. cancelled( ) => break ,
227
220
228
- Some ( .. ) = self . tasks. join_next ( ) => {
221
+ Some ( result ) = self . tasks. join_next_with_id ( ) => {
229
222
DNS_REQUEST_COUNT . fetch_sub( 1 , Ordering :: Relaxed ) ;
223
+ let ( id, result) = match result {
224
+ Ok ( ( id, result) ) => (
225
+ id,
226
+ result. map_err( Into :: into) ,
227
+ ) ,
228
+ Err ( error) => {
229
+ (
230
+ error. id( ) ,
231
+ Err ( ResolveErrorKindInternal :: Message ( "DNS task panicked" . into( ) ) )
232
+ )
233
+ }
234
+ } ;
235
+
236
+ let response_tx = self . response_txs. remove( & id) ;
237
+ match response_tx {
238
+ Some ( response_tx) => {
239
+ let _ = response_tx. send( result) ;
240
+ }
241
+ None => {
242
+ warn!( ?id, "Received a DNS result with no matching response channel" ) ;
243
+ }
244
+ }
230
245
}
231
246
232
247
message = self . request_rx. recv( ) => match message {
233
- None => break Ok ( ( ) ) ,
248
+ None => break ,
234
249
Some ( message) => self . handle_message( message) ,
235
250
} ,
236
251
}
@@ -246,15 +261,15 @@ impl Drop for DnsWorker {
246
261
}
247
262
248
263
pub ( crate ) struct DnsApi {
249
- task_status : TaskStatus ,
264
+ task_status : BgTaskStatus ,
250
265
request_tx : Sender < DnsCommand > ,
251
266
/// [`DnsWorker`] processes all requests concurrently, so we use a combination of [`oneshot`]
252
267
/// channels and [`FuturesOrdered`] to preserve order of responses.
253
- responses : FuturesOrdered < oneshot:: Receiver < Result < DnsLookup , InternalLookupError > > > ,
268
+ responses : FuturesOrdered < oneshot:: Receiver < Result < DnsLookup , ResolveErrorKindInternal > > > ,
254
269
}
255
270
256
271
impl DnsApi {
257
- pub ( crate ) fn new ( task_status : TaskStatus , task_sender : Sender < DnsCommand > ) -> Self {
272
+ pub ( crate ) fn new ( task_status : BgTaskStatus , task_sender : Sender < DnsCommand > ) -> Self {
258
273
Self {
259
274
task_status,
260
275
request_tx : task_sender,
@@ -276,7 +291,7 @@ impl DnsApi {
276
291
response_tx,
277
292
} ;
278
293
if self . request_tx . send ( command) . await . is_err ( ) {
279
- return Err ( self . task_status . unwrap_err ( ) . await ) ;
294
+ return Err ( self . task_status . wait_assert_running ( ) . await ) ;
280
295
}
281
296
282
297
self . responses . push_back ( response_rx) ;
@@ -294,11 +309,14 @@ impl DnsApi {
294
309
return future:: pending ( ) . await ;
295
310
} ;
296
311
297
- let response = response
298
- . map_err ( |_| AgentError :: DnsTaskPanic ) ?
299
- . map_err ( |error| ResponseError :: DnsLookup ( DnsLookupError { kind : error. into ( ) } ) ) ;
300
-
301
- Ok ( GetAddrInfoResponse ( response) )
312
+ match response {
313
+ Ok ( response) => {
314
+ Ok ( GetAddrInfoResponse ( response. map_err ( |kind| {
315
+ ResponseError :: DnsLookup ( DnsLookupError { kind } )
316
+ } ) ) )
317
+ }
318
+ Err ( ..) => Err ( self . task_status . wait_assert_running ( ) . await ) ,
319
+ }
302
320
}
303
321
}
304
322
0 commit comments