Skip to content

Commit 460baf2

Browse files
authored
Merge pull request from GHSA-3999-5ffv-wp2r
feat: switch pending_frames VecDequeue for an Option to bound it
2 parents cf6456f + af8f693 commit 460baf2

File tree

3 files changed

+86
-81
lines changed

3 files changed

+86
-81
lines changed

test-harness/tests/poll_api.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ fn concurrent_streams() {
5757
const PAYLOAD_SIZE: usize = 128 * 1024;
5858

5959
let data = Msg(vec![0x42; PAYLOAD_SIZE]);
60-
let n_streams = 1000;
60+
let n_streams = 512;
6161

6262
let mut cfg = Config::default();
6363
cfg.set_split_send_size(PAYLOAD_SIZE); // Use a large frame size to speed up the test.

yamux/src/connection.rs

Lines changed: 73 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ struct Active<T> {
286286
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
287287
no_streams_waker: Option<Waker>,
288288

289-
pending_frames: VecDeque<Frame<()>>,
289+
pending_read_frame: Option<Frame<()>>,
290+
pending_write_frame: Option<Frame<()>>,
290291
new_outbound_stream_waker: Option<Waker>,
291292

292293
rtt: rtt::Rtt,
@@ -360,7 +361,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
360361
Mode::Client => 1,
361362
Mode::Server => 2,
362363
},
363-
pending_frames: VecDeque::default(),
364+
pending_read_frame: None,
365+
pending_write_frame: None,
364366
new_outbound_stream_waker: None,
365367
rtt: rtt::Rtt::new(),
366368
accumulated_max_stream_windows: Default::default(),
@@ -369,7 +371,12 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
369371

370372
/// Gracefully close the connection to the remote.
371373
fn close(self) -> Closing<T> {
372-
Closing::new(self.stream_receivers, self.pending_frames, self.socket)
374+
let pending_frames = self
375+
.pending_read_frame
376+
.into_iter()
377+
.chain(self.pending_write_frame)
378+
.collect::<VecDeque<Frame<()>>>();
379+
Closing::new(self.stream_receivers, pending_frames, self.socket)
373380
}
374381

375382
/// Cleanup all our resources.
@@ -392,7 +399,13 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
392399
continue;
393400
}
394401

395-
if let Some(frame) = self.pending_frames.pop_front() {
402+
// Privilege pending `Pong` and `GoAway` `Frame`s
403+
// over `Frame`s from the receivers.
404+
if let Some(frame) = self
405+
.pending_read_frame
406+
.take()
407+
.or_else(|| self.pending_write_frame.take())
408+
{
396409
self.socket.start_send_unpin(frame)?;
397410
continue;
398411
}
@@ -403,36 +416,63 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
403416
Poll::Pending => {}
404417
}
405418

406-
match self.stream_receivers.poll_next_unpin(cx) {
407-
Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
408-
self.on_send_frame(frame);
409-
continue;
410-
}
411-
Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => {
412-
self.on_close_stream(id, ack);
413-
continue;
414-
}
415-
Poll::Ready(Some((id, None))) => {
416-
self.on_drop_stream(id);
417-
continue;
418-
}
419-
Poll::Ready(None) => {
420-
self.no_streams_waker = Some(cx.waker().clone());
419+
if self.pending_write_frame.is_none() {
420+
match self.stream_receivers.poll_next_unpin(cx) {
421+
Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
422+
log::trace!(
423+
"{}/{}: sending: {}",
424+
self.id,
425+
frame.header().stream_id(),
426+
frame.header()
427+
);
428+
self.pending_write_frame.replace(frame.into());
429+
continue;
430+
}
431+
Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => {
432+
log::trace!("{}/{}: sending close", self.id, id);
433+
self.pending_write_frame
434+
.replace(Frame::close_stream(id, ack).into());
435+
continue;
436+
}
437+
Poll::Ready(Some((id, None))) => {
438+
if let Some(frame) = self.on_drop_stream(id) {
439+
log::trace!("{}/{}: sending: {}", self.id, id, frame.header());
440+
self.pending_write_frame.replace(frame);
441+
};
442+
continue;
443+
}
444+
Poll::Ready(None) => {
445+
self.no_streams_waker = Some(cx.waker().clone());
446+
}
447+
Poll::Pending => {}
421448
}
422-
Poll::Pending => {}
423449
}
424450

425-
match self.socket.poll_next_unpin(cx) {
426-
Poll::Ready(Some(frame)) => {
427-
if let Some(stream) = self.on_frame(frame?)? {
428-
return Poll::Ready(Ok(stream));
451+
if self.pending_read_frame.is_none() {
452+
match self.socket.poll_next_unpin(cx) {
453+
Poll::Ready(Some(frame)) => {
454+
match self.on_frame(frame?)? {
455+
Action::None => {}
456+
Action::New(stream) => {
457+
log::trace!("{}: new inbound {} of {}", self.id, stream, self);
458+
return Poll::Ready(Ok(stream));
459+
}
460+
Action::Ping(f) => {
461+
log::trace!("{}/{}: pong", self.id, f.header().stream_id());
462+
self.pending_read_frame.replace(f.into());
463+
}
464+
Action::Terminate(f) => {
465+
log::trace!("{}: sending term", self.id);
466+
self.pending_read_frame.replace(f.into());
467+
}
468+
}
469+
continue;
429470
}
430-
continue;
431-
}
432-
Poll::Ready(None) => {
433-
return Poll::Ready(Err(ConnectionError::Closed));
471+
Poll::Ready(None) => {
472+
return Poll::Ready(Err(ConnectionError::Closed));
473+
}
474+
Poll::Pending => {}
434475
}
435-
Poll::Pending => {}
436476
}
437477

438478
// If we make it this far, at least one of the above must have registered a waker.
@@ -463,23 +503,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
463503
Poll::Ready(Ok(stream))
464504
}
465505

466-
fn on_send_frame(&mut self, frame: Frame<Either<Data, WindowUpdate>>) {
467-
log::trace!(
468-
"{}/{}: sending: {}",
469-
self.id,
470-
frame.header().stream_id(),
471-
frame.header()
472-
);
473-
self.pending_frames.push_back(frame.into());
474-
}
475-
476-
fn on_close_stream(&mut self, id: StreamId, ack: bool) {
477-
log::trace!("{}/{}: sending close", self.id, id);
478-
self.pending_frames
479-
.push_back(Frame::close_stream(id, ack).into());
480-
}
481-
482-
fn on_drop_stream(&mut self, stream_id: StreamId) {
506+
fn on_drop_stream(&mut self, stream_id: StreamId) -> Option<Frame<()>> {
483507
let s = self.streams.remove(&stream_id).expect("stream not found");
484508

485509
log::trace!("{}: removing dropped stream {}", self.id, stream_id);
@@ -525,10 +549,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
525549
}
526550
frame
527551
};
528-
if let Some(f) = frame {
529-
log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
530-
self.pending_frames.push_back(f.into());
531-
}
552+
frame.map(Into::into)
532553
}
533554

534555
/// Process the result of reading from the socket.
@@ -537,7 +558,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
537558
/// and return a corresponding error, which terminates the connection.
538559
/// Otherwise we process the frame and potentially return a new `Stream`
539560
/// if one was opened by the remote.
540-
fn on_frame(&mut self, frame: Frame<()>) -> Result<Option<Stream>> {
561+
fn on_frame(&mut self, frame: Frame<()>) -> Result<Action> {
541562
log::trace!("{}: received: {}", self.id, frame.header());
542563

543564
if frame.header().flags().contains(header::ACK)
@@ -560,23 +581,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
560581
Tag::Ping => self.on_ping(&frame.into_ping()),
561582
Tag::GoAway => return Err(ConnectionError::Closed),
562583
};
563-
match action {
564-
Action::None => {}
565-
Action::New(stream) => {
566-
log::trace!("{}: new inbound {} of {}", self.id, stream, self);
567-
return Ok(Some(stream));
568-
}
569-
Action::Ping(f) => {
570-
log::trace!("{}/{}: pong", self.id, f.header().stream_id());
571-
self.pending_frames.push_back(f.into());
572-
}
573-
Action::Terminate(f) => {
574-
log::trace!("{}: sending term", self.id);
575-
self.pending_frames.push_back(f.into());
576-
}
577-
}
578-
579-
Ok(None)
584+
Ok(action)
580585
}
581586

582587
fn on_data(&mut self, frame: Frame<Data>) -> Action {

yamux/src/connection/closing.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ where
3030
socket: Fuse<frame::Io<T>>,
3131
) -> Self {
3232
Self {
33-
state: State::ClosingStreamReceiver,
33+
state: State::FlushingPendingFrames,
3434
stream_receivers,
3535
pending_frames,
3636
socket,
@@ -49,6 +49,14 @@ where
4949

5050
loop {
5151
match this.state {
52+
State::FlushingPendingFrames => {
53+
ready!(this.socket.poll_ready_unpin(cx))?;
54+
55+
match this.pending_frames.pop_front() {
56+
Some(frame) => this.socket.start_send_unpin(frame)?,
57+
None => this.state = State::ClosingStreamReceiver,
58+
}
59+
}
5260
State::ClosingStreamReceiver => {
5361
for stream in this.stream_receivers.iter_mut() {
5462
stream.inner_mut().close();
@@ -59,7 +67,7 @@ where
5967
State::DrainingStreamReceiver => {
6068
match this.stream_receivers.poll_next_unpin(cx) {
6169
Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
62-
this.pending_frames.push_back(frame.into())
70+
this.pending_frames.push_back(frame.into());
6371
}
6472
Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => {
6573
this.pending_frames
@@ -69,19 +77,11 @@ where
6977
Poll::Pending | Poll::Ready(None) => {
7078
// No more frames from streams, append `Term` frame and flush them all.
7179
this.pending_frames.push_back(Frame::term().into());
72-
this.state = State::FlushingPendingFrames;
80+
this.state = State::ClosingSocket;
7381
continue;
7482
}
7583
}
7684
}
77-
State::FlushingPendingFrames => {
78-
ready!(this.socket.poll_ready_unpin(cx))?;
79-
80-
match this.pending_frames.pop_front() {
81-
Some(frame) => this.socket.start_send_unpin(frame)?,
82-
None => this.state = State::ClosingSocket,
83-
}
84-
}
8585
State::ClosingSocket => {
8686
ready!(this.socket.poll_close_unpin(cx))?;
8787

@@ -93,8 +93,8 @@ where
9393
}
9494

9595
enum State {
96+
FlushingPendingFrames,
9697
ClosingStreamReceiver,
9798
DrainingStreamReceiver,
98-
FlushingPendingFrames,
9999
ClosingSocket,
100100
}

0 commit comments

Comments
 (0)