Skip to content

Commit 0b5966c

Browse files
actually shut down sockets
1 parent 283f785 commit 0b5966c

File tree

3 files changed

+102
-38
lines changed

3 files changed

+102
-38
lines changed

iroh-net/src/magicsock.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1513,6 +1513,11 @@ impl Handle {
15131513
}
15141514
self.msock.closing.store(true, Ordering::Relaxed);
15151515
self.msock.actor_sender.send(ActorMessage::Shutdown).await?;
1516+
self.msock.pconn4.close().await;
1517+
if let Some(ref conn) = self.msock.pconn6 {
1518+
conn.close().await;
1519+
}
1520+
15161521
self.msock.closed.store(true, Ordering::SeqCst);
15171522
self.msock.direct_addrs.addrs.shutdown();
15181523

iroh-net/src/magicsock/udp_conn.rs

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ impl UdpConn {
3030
let sock = bind(addr)?;
3131
let state = sock.with_socket(|socket| {
3232
quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(socket))
33-
})?;
33+
})??;
3434

3535
Ok(Self {
3636
io: Arc::new(sock),
@@ -45,7 +45,7 @@ impl UdpConn {
4545
// update socket state
4646
let new_state = self.io.with_socket(|socket| {
4747
quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(socket))
48-
})?;
48+
})??;
4949
*self.inner.write().unwrap() = new_state;
5050
Ok(())
5151
}
@@ -59,6 +59,11 @@ impl UdpConn {
5959
io: self.io.clone(),
6060
})
6161
}
62+
63+
/// Closes the socket for good
64+
pub async fn close(&self) {
65+
self.io.close().await;
66+
}
6267
}
6368

6469
impl AsyncUdpSocket for UdpConn {
@@ -90,8 +95,9 @@ impl AsyncUdpSocket for UdpConn {
9095
})
9196
});
9297
match res {
93-
Ok(()) => return Ok(()),
94-
Err(err) => {
98+
Ok(Ok(())) => return Ok(()),
99+
Err(err) => return Err(err), // closed error
100+
Ok(Err(err)) => {
95101
if err.kind() == std::io::ErrorKind::WouldBlock {
96102
continue;
97103
}
@@ -129,22 +135,23 @@ impl AsyncUdpSocket for UdpConn {
129135
}
130136

131137
match self.io.with_socket(|io| io.poll_recv_ready(cx)) {
132-
Poll::Pending => return Poll::Pending,
133-
Poll::Ready(Ok(())) => {}
134-
Poll::Ready(Err(err)) => match self.io.handle_read_error(err) {
138+
Ok(Poll::Pending) => return Poll::Pending,
139+
Ok(Poll::Ready(Ok(()))) => {}
140+
Ok(Poll::Ready(Err(err))) => match self.io.handle_read_error(err) {
135141
Some(err) => return Poll::Ready(Err(err)),
136142
None => {
137143
continue;
138144
}
139145
},
146+
Err(err) => return Poll::Ready(Err(err)),
140147
}
141148

142149
let res = self.io.try_io(Interest::READABLE, || {
143150
self.io
144151
.with_socket(|io| self.inner.read().unwrap().recv(io.into(), bufs, meta))
145152
});
146153
match res {
147-
Ok(count) => {
154+
Ok(Ok(count)) => {
148155
for meta in meta.iter().take(count) {
149156
trace!(
150157
src = %meta.addr,
@@ -156,7 +163,7 @@ impl AsyncUdpSocket for UdpConn {
156163
}
157164
return Poll::Ready(Ok(count));
158165
}
159-
Err(err) => {
166+
Ok(Err(err)) => {
160167
// ignore spurious wakeups
161168
if err.kind() == std::io::ErrorKind::WouldBlock {
162169
continue;
@@ -168,6 +175,7 @@ impl AsyncUdpSocket for UdpConn {
168175
}
169176
}
170177
}
178+
Err(err) => return Poll::Ready(Err(err)),
171179
}
172180
}
173181
}

net-tools/netwatch/src/udp.rs

Lines changed: 80 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::{
77
task::Poll,
88
};
99

10-
use anyhow::{ensure, Context, Result};
10+
use anyhow::{bail, ensure, Context, Result};
1111
use tracing::{debug, warn};
1212

1313
use super::IpFamily;
@@ -82,7 +82,9 @@ impl UdpSocket {
8282
// Remove old socket
8383
let mut guard = self.socket.write().unwrap();
8484
{
85-
let socket = guard.take().expect("not yet dropped");
85+
let Some(socket) = guard.take() else {
86+
bail!("cannot rebind closed socket");
87+
};
8688
drop(socket);
8789
}
8890

@@ -113,13 +115,18 @@ impl UdpSocket {
113115
}
114116

115117
/// Use the socket
116-
pub fn with_socket<F, T>(&self, f: F) -> T
118+
pub fn with_socket<F, T>(&self, f: F) -> std::io::Result<T>
117119
where
118120
F: FnOnce(&tokio::net::UdpSocket) -> T,
119121
{
120122
let guard = self.socket.read().unwrap();
121-
let socket = guard.as_ref().expect("missing socket");
122-
f(socket)
123+
let Some(socket) = guard.as_ref() else {
124+
return Err(std::io::Error::new(
125+
std::io::ErrorKind::BrokenPipe,
126+
"socket closed",
127+
));
128+
};
129+
Ok(f(socket))
123130
}
124131

125132
pub fn try_io<R>(
@@ -128,7 +135,12 @@ impl UdpSocket {
128135
f: impl FnOnce() -> std::io::Result<R>,
129136
) -> std::io::Result<R> {
130137
let guard = self.socket.read().unwrap();
131-
let socket = guard.as_ref().expect("missing socket");
138+
let Some(socket) = guard.as_ref() else {
139+
return Err(std::io::Error::new(
140+
std::io::ErrorKind::BrokenPipe,
141+
"socket closed",
142+
));
143+
};
132144
socket.try_io(interest, f)
133145
}
134146

@@ -173,7 +185,13 @@ impl UdpSocket {
173185
pub fn connect(&self, addr: SocketAddr) -> std::io::Result<()> {
174186
let mut guard = self.socket.write().unwrap();
175187
// dance around to make non async connect work
176-
let socket_tokio = guard.take().expect("missing socket");
188+
let Some(socket_tokio) = guard.take() else {
189+
return Err(std::io::Error::new(
190+
std::io::ErrorKind::BrokenPipe,
191+
"socket closed",
192+
));
193+
};
194+
177195
let socket_std = socket_tokio.into_std()?;
178196
socket_std.connect(addr)?;
179197
let socket_tokio = tokio::net::UdpSocket::from_std(socket_std)?;
@@ -184,30 +202,38 @@ impl UdpSocket {
184202
/// Returns the local address of this socket.
185203
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
186204
let guard = self.socket.read().unwrap();
187-
let socket = guard.as_ref().expect("missing socket");
205+
let Some(socket) = guard.as_ref() else {
206+
return Err(std::io::Error::new(
207+
std::io::ErrorKind::BrokenPipe,
208+
"socket closed",
209+
));
210+
};
211+
188212
socket.local_addr()
189213
}
190214

191215
/// Closes the socket, and waits for the underlying `libc::close` call to be finished.
192-
pub async fn close(self) {
193-
let std_sock = self
194-
.socket
195-
.write()
196-
.unwrap()
197-
.take()
198-
.expect("not yet dropped")
199-
.into_std();
200-
let res = tokio::runtime::Handle::current()
201-
.spawn_blocking(move || {
202-
// Calls libc::close, which can block
203-
drop(std_sock);
204-
})
205-
.await;
206-
if let Err(err) = res {
207-
warn!("failed to close socket: {:?}", err);
216+
pub async fn close(&self) {
217+
let socket = self.socket.write().unwrap().take();
218+
if let Some(sock) = socket {
219+
let std_sock = sock.into_std();
220+
let res = tokio::runtime::Handle::current()
221+
.spawn_blocking(move || {
222+
// Calls libc::close, which can block
223+
drop(std_sock);
224+
})
225+
.await;
226+
if let Err(err) = res {
227+
warn!("failed to close socket: {:?}", err);
228+
}
208229
}
209230
}
210231

232+
/// Check if this socket is closed.
233+
pub fn is_closed(&self) -> bool {
234+
self.socket.read().unwrap().is_none()
235+
}
236+
211237
/// Handle potential read errors, updating internal state.
212238
///
213239
/// Returns `Some(error)` if the error is fatal otherwise `None.
@@ -255,7 +281,12 @@ impl UdpSocket {
255281
}
256282
}
257283
let guard = self.socket.read().unwrap();
258-
let socket = guard.as_ref().expect("missing socket");
284+
let Some(socket) = guard.as_ref() else {
285+
return Poll::Ready(Err(std::io::Error::new(
286+
std::io::ErrorKind::BrokenPipe,
287+
"socket closed",
288+
)));
289+
};
259290

260291
match socket.poll_send_ready(cx) {
261292
Poll::Pending => return Poll::Pending,
@@ -302,7 +333,12 @@ impl Future for RecvFut<'_, '_> {
302333
}
303334

304335
let guard = socket.socket.read().unwrap();
305-
let inner_socket = guard.as_ref().expect("missing socket");
336+
let Some(inner_socket) = guard.as_ref() else {
337+
return Poll::Ready(Err(std::io::Error::new(
338+
std::io::ErrorKind::BrokenPipe,
339+
"socket closed",
340+
)));
341+
};
306342

307343
match inner_socket.poll_recv_ready(cx) {
308344
Poll::Pending => return Poll::Pending,
@@ -360,7 +396,12 @@ impl Future for RecvFromFut<'_, '_> {
360396
}
361397
}
362398
let guard = socket.socket.read().unwrap();
363-
let inner_socket = guard.as_ref().expect("missing socket");
399+
let Some(inner_socket) = guard.as_ref() else {
400+
return Poll::Ready(Err(std::io::Error::new(
401+
std::io::ErrorKind::BrokenPipe,
402+
"socket closed",
403+
)));
404+
};
364405

365406
match inner_socket.poll_recv_ready(cx) {
366407
Poll::Pending => return Poll::Pending,
@@ -430,7 +471,12 @@ impl Future for SendFut<'_, '_> {
430471
}
431472
}
432473
let guard = self.socket.socket.read().unwrap();
433-
let socket = guard.as_ref().expect("missing socket");
474+
let Some(socket) = guard.as_ref() else {
475+
return Poll::Ready(Err(std::io::Error::new(
476+
std::io::ErrorKind::BrokenPipe,
477+
"socket closed",
478+
)));
479+
};
434480

435481
match socket.poll_send_ready(c) {
436482
Poll::Pending => return Poll::Pending,
@@ -488,7 +534,12 @@ impl Future for SendToFut<'_, '_> {
488534
}
489535

490536
let guard = self.socket.socket.read().unwrap();
491-
let socket = guard.as_ref().expect("missing socket");
537+
let Some(socket) = guard.as_ref() else {
538+
return Poll::Ready(Err(std::io::Error::new(
539+
std::io::ErrorKind::BrokenPipe,
540+
"socket closed",
541+
)));
542+
};
492543

493544
match socket.poll_send_ready(cx) {
494545
Poll::Pending => return Poll::Pending,

0 commit comments

Comments
 (0)