@@ -7,7 +7,7 @@ use std::{
7
7
task:: Poll ,
8
8
} ;
9
9
10
- use anyhow:: { ensure, Context , Result } ;
10
+ use anyhow:: { bail , ensure, Context , Result } ;
11
11
use tracing:: { debug, warn} ;
12
12
13
13
use super :: IpFamily ;
@@ -82,7 +82,9 @@ impl UdpSocket {
82
82
// Remove old socket
83
83
let mut guard = self . socket . write ( ) . unwrap ( ) ;
84
84
{
85
- let socket = guard. take ( ) . expect ( "not yet dropped" ) ;
85
+ let Some ( socket) = guard. take ( ) else {
86
+ bail ! ( "cannot rebind closed socket" ) ;
87
+ } ;
86
88
drop ( socket) ;
87
89
}
88
90
@@ -113,13 +115,18 @@ impl UdpSocket {
113
115
}
114
116
115
117
/// Use the socket
116
- pub fn with_socket < F , T > ( & self , f : F ) -> T
118
+ pub fn with_socket < F , T > ( & self , f : F ) -> std :: io :: Result < T >
117
119
where
118
120
F : FnOnce ( & tokio:: net:: UdpSocket ) -> T ,
119
121
{
120
122
let guard = self . socket . read ( ) . unwrap ( ) ;
121
- let socket = guard. as_ref ( ) . expect ( "missing socket" ) ;
122
- f ( socket)
123
+ let Some ( socket) = guard. as_ref ( ) else {
124
+ return Err ( std:: io:: Error :: new (
125
+ std:: io:: ErrorKind :: BrokenPipe ,
126
+ "socket closed" ,
127
+ ) ) ;
128
+ } ;
129
+ Ok ( f ( socket) )
123
130
}
124
131
125
132
pub fn try_io < R > (
@@ -128,7 +135,12 @@ impl UdpSocket {
128
135
f : impl FnOnce ( ) -> std:: io:: Result < R > ,
129
136
) -> std:: io:: Result < R > {
130
137
let guard = self . socket . read ( ) . unwrap ( ) ;
131
- let socket = guard. as_ref ( ) . expect ( "missing socket" ) ;
138
+ let Some ( socket) = guard. as_ref ( ) else {
139
+ return Err ( std:: io:: Error :: new (
140
+ std:: io:: ErrorKind :: BrokenPipe ,
141
+ "socket closed" ,
142
+ ) ) ;
143
+ } ;
132
144
socket. try_io ( interest, f)
133
145
}
134
146
@@ -173,7 +185,13 @@ impl UdpSocket {
173
185
pub fn connect ( & self , addr : SocketAddr ) -> std:: io:: Result < ( ) > {
174
186
let mut guard = self . socket . write ( ) . unwrap ( ) ;
175
187
// dance around to make non async connect work
176
- let socket_tokio = guard. take ( ) . expect ( "missing socket" ) ;
188
+ let Some ( socket_tokio) = guard. take ( ) else {
189
+ return Err ( std:: io:: Error :: new (
190
+ std:: io:: ErrorKind :: BrokenPipe ,
191
+ "socket closed" ,
192
+ ) ) ;
193
+ } ;
194
+
177
195
let socket_std = socket_tokio. into_std ( ) ?;
178
196
socket_std. connect ( addr) ?;
179
197
let socket_tokio = tokio:: net:: UdpSocket :: from_std ( socket_std) ?;
@@ -184,30 +202,38 @@ impl UdpSocket {
184
202
/// Returns the local address of this socket.
185
203
pub fn local_addr ( & self ) -> std:: io:: Result < SocketAddr > {
186
204
let guard = self . socket . read ( ) . unwrap ( ) ;
187
- let socket = guard. as_ref ( ) . expect ( "missing socket" ) ;
205
+ let Some ( socket) = guard. as_ref ( ) else {
206
+ return Err ( std:: io:: Error :: new (
207
+ std:: io:: ErrorKind :: BrokenPipe ,
208
+ "socket closed" ,
209
+ ) ) ;
210
+ } ;
211
+
188
212
socket. local_addr ( )
189
213
}
190
214
191
215
/// Closes the socket, and waits for the underlying `libc::close` call to be finished.
192
- pub async fn close ( self ) {
193
- let std_sock = self
194
- . socket
195
- . write ( )
196
- . unwrap ( )
197
- . take ( )
198
- . expect ( "not yet dropped" )
199
- . into_std ( ) ;
200
- let res = tokio:: runtime:: Handle :: current ( )
201
- . spawn_blocking ( move || {
202
- // Calls libc::close, which can block
203
- drop ( std_sock) ;
204
- } )
205
- . await ;
206
- if let Err ( err) = res {
207
- warn ! ( "failed to close socket: {:?}" , err) ;
216
+ pub async fn close ( & self ) {
217
+ let socket = self . socket . write ( ) . unwrap ( ) . take ( ) ;
218
+ if let Some ( sock) = socket {
219
+ let std_sock = sock. into_std ( ) ;
220
+ let res = tokio:: runtime:: Handle :: current ( )
221
+ . spawn_blocking ( move || {
222
+ // Calls libc::close, which can block
223
+ drop ( std_sock) ;
224
+ } )
225
+ . await ;
226
+ if let Err ( err) = res {
227
+ warn ! ( "failed to close socket: {:?}" , err) ;
228
+ }
208
229
}
209
230
}
210
231
232
+ /// Check if this socket is closed.
233
+ pub fn is_closed ( & self ) -> bool {
234
+ self . socket . read ( ) . unwrap ( ) . is_none ( )
235
+ }
236
+
211
237
/// Handle potential read errors, updating internal state.
212
238
///
213
239
/// Returns `Some(error)` if the error is fatal otherwise `None.
@@ -255,7 +281,12 @@ impl UdpSocket {
255
281
}
256
282
}
257
283
let guard = self . socket . read ( ) . unwrap ( ) ;
258
- let socket = guard. as_ref ( ) . expect ( "missing socket" ) ;
284
+ let Some ( socket) = guard. as_ref ( ) else {
285
+ return Poll :: Ready ( Err ( std:: io:: Error :: new (
286
+ std:: io:: ErrorKind :: BrokenPipe ,
287
+ "socket closed" ,
288
+ ) ) ) ;
289
+ } ;
259
290
260
291
match socket. poll_send_ready ( cx) {
261
292
Poll :: Pending => return Poll :: Pending ,
@@ -302,7 +333,12 @@ impl Future for RecvFut<'_, '_> {
302
333
}
303
334
304
335
let guard = socket. socket . read ( ) . unwrap ( ) ;
305
- let inner_socket = guard. as_ref ( ) . expect ( "missing socket" ) ;
336
+ let Some ( inner_socket) = guard. as_ref ( ) else {
337
+ return Poll :: Ready ( Err ( std:: io:: Error :: new (
338
+ std:: io:: ErrorKind :: BrokenPipe ,
339
+ "socket closed" ,
340
+ ) ) ) ;
341
+ } ;
306
342
307
343
match inner_socket. poll_recv_ready ( cx) {
308
344
Poll :: Pending => return Poll :: Pending ,
@@ -360,7 +396,12 @@ impl Future for RecvFromFut<'_, '_> {
360
396
}
361
397
}
362
398
let guard = socket. socket . read ( ) . unwrap ( ) ;
363
- let inner_socket = guard. as_ref ( ) . expect ( "missing socket" ) ;
399
+ let Some ( inner_socket) = guard. as_ref ( ) else {
400
+ return Poll :: Ready ( Err ( std:: io:: Error :: new (
401
+ std:: io:: ErrorKind :: BrokenPipe ,
402
+ "socket closed" ,
403
+ ) ) ) ;
404
+ } ;
364
405
365
406
match inner_socket. poll_recv_ready ( cx) {
366
407
Poll :: Pending => return Poll :: Pending ,
@@ -430,7 +471,12 @@ impl Future for SendFut<'_, '_> {
430
471
}
431
472
}
432
473
let guard = self . socket . socket . read ( ) . unwrap ( ) ;
433
- let socket = guard. as_ref ( ) . expect ( "missing socket" ) ;
474
+ let Some ( socket) = guard. as_ref ( ) else {
475
+ return Poll :: Ready ( Err ( std:: io:: Error :: new (
476
+ std:: io:: ErrorKind :: BrokenPipe ,
477
+ "socket closed" ,
478
+ ) ) ) ;
479
+ } ;
434
480
435
481
match socket. poll_send_ready ( c) {
436
482
Poll :: Pending => return Poll :: Pending ,
@@ -488,7 +534,12 @@ impl Future for SendToFut<'_, '_> {
488
534
}
489
535
490
536
let guard = self . socket . socket . read ( ) . unwrap ( ) ;
491
- let socket = guard. as_ref ( ) . expect ( "missing socket" ) ;
537
+ let Some ( socket) = guard. as_ref ( ) else {
538
+ return Poll :: Ready ( Err ( std:: io:: Error :: new (
539
+ std:: io:: ErrorKind :: BrokenPipe ,
540
+ "socket closed" ,
541
+ ) ) ) ;
542
+ } ;
492
543
493
544
match socket. poll_send_ready ( cx) {
494
545
Poll :: Pending => return Poll :: Pending ,
0 commit comments