Skip to content

Commit 2ac7c70

Browse files
joshtriplettsdroege
authored andcommitted
Add a ByteReader wrapper that implements AsyncRead for a Stream of messages
This is useful for programs that want to treat a WebSocket as a stream of bytes.
1 parent a6f8f9a commit 2ac7c70

File tree

3 files changed

+77
-18
lines changed

3 files changed

+77
-18
lines changed

examples/client-bytes.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@ use std::env;
1515
use futures::StreamExt;
1616

1717
use async_std::io;
18-
use async_std::prelude::*;
1918
use async_std::task;
2019
use async_tungstenite::async_std::connect_async;
21-
use async_tungstenite::ByteWriter;
20+
use async_tungstenite::{ByteReader, ByteWriter};
2221

2322
async fn run() {
2423
let connect_addr = env::args()
@@ -32,15 +31,11 @@ async fn run() {
3231

3332
let (write, read) = ws_stream.split();
3433
let byte_writer = ByteWriter::new(write);
35-
let stdin_to_ws = task::spawn(async {
36-
io::copy(io::stdin(), byte_writer).await.unwrap();
37-
});
38-
let ws_to_stdout = task::spawn(read.for_each(|message| async {
39-
let data = message.unwrap().into_data();
40-
async_std::io::stdout().write_all(&data).await.unwrap();
41-
}));
42-
stdin_to_ws.await;
43-
ws_to_stdout.await;
34+
let byte_reader = ByteReader::new(read);
35+
let stdin_to_ws = task::spawn(io::copy(io::stdin(), byte_writer));
36+
let ws_to_stdout = task::spawn(io::copy(byte_reader, io::stdout()));
37+
stdin_to_ws.await.unwrap();
38+
ws_to_stdout.await.unwrap();
4439
}
4540

4641
fn main() {

src/bytes.rs

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
1-
//! Provides an abstraction to use `AsyncWrite` to write bytes to a `WebSocketStream`.
1+
//! Provides abstractions to use `AsyncRead` and `AsyncWrite` with a `WebSocketStream`.
22
33
use std::{
44
io,
55
pin::Pin,
66
task::{Context, Poll},
77
};
88

9-
use futures_util::Sink;
9+
use futures_core::stream::Stream;
1010

11-
use crate::{Message, WsError};
11+
use crate::{tungstenite::Bytes, Message, WsError};
1212

1313
/// Treat a `WebSocketStream` as an `AsyncWrite` implementation.
1414
///
1515
/// Every write sends a binary message. If you want to group writes together, consider wrapping
1616
/// this with a `BufWriter`.
17+
#[cfg(feature = "futures-03-sink")]
1718
#[derive(Debug)]
1819
pub struct ByteWriter<S>(S);
1920

21+
#[cfg(feature = "futures-03-sink")]
2022
impl<S> ByteWriter<S> {
2123
/// Create a new `ByteWriter` from a `Sink` that accepts a WebSocket `Message`
2224
#[inline(always)]
@@ -31,13 +33,14 @@ impl<S> ByteWriter<S> {
3133
}
3234
}
3335

36+
#[cfg(feature = "futures-03-sink")]
3437
fn poll_write_helper<S>(
3538
mut s: Pin<&mut ByteWriter<S>>,
3639
cx: &mut Context<'_>,
3740
buf: &[u8],
3841
) -> Poll<io::Result<usize>>
3942
where
40-
S: Sink<Message, Error = WsError> + Unpin,
43+
S: futures_util::Sink<Message, Error = WsError> + Unpin,
4144
{
4245
match Pin::new(&mut s.0).poll_ready(cx).map_err(convert_err) {
4346
Poll::Ready(Ok(())) => {}
@@ -54,9 +57,10 @@ where
5457
)
5558
}
5659

60+
#[cfg(feature = "futures-03-sink")]
5761
impl<S> futures_io::AsyncWrite for ByteWriter<S>
5862
where
59-
S: Sink<Message, Error = WsError> + Unpin,
63+
S: futures_util::Sink<Message, Error = WsError> + Unpin,
6064
{
6165
fn poll_write(
6266
self: Pin<&mut Self>,
@@ -75,10 +79,11 @@ where
7579
}
7680
}
7781

82+
#[cfg(feature = "futures-03-sink")]
7883
#[cfg(feature = "tokio-runtime")]
7984
impl<S> tokio::io::AsyncWrite for ByteWriter<S>
8085
where
81-
S: Sink<Message, Error = WsError> + Unpin,
86+
S: futures_util::Sink<Message, Error = WsError> + Unpin,
8287
{
8388
fn poll_write(
8489
self: Pin<&mut Self>,
@@ -97,6 +102,65 @@ where
97102
}
98103
}
99104

105+
/// Treat a `WebSocketStream` as an `AsyncRead` implementation.
106+
///
107+
/// This also works with any other `Stream` of `Message`, such as a `SplitStream`.
108+
///
109+
/// Each read will only return data from one message. If you want to combine data from multiple
110+
/// messages into one read, consider wrapping this in a `BufReader`.
111+
#[derive(Debug)]
112+
pub struct ByteReader<S> {
113+
stream: S,
114+
bytes: Option<Bytes>,
115+
}
116+
117+
impl<S> ByteReader<S> {
118+
/// Create a new `ByteReader` from a `Stream` that returns a WebSocket `Message`
119+
#[inline(always)]
120+
pub fn new(stream: S) -> Self {
121+
Self {
122+
stream,
123+
bytes: None,
124+
}
125+
}
126+
}
127+
128+
impl<S> futures_io::AsyncRead for ByteReader<S>
129+
where
130+
S: Stream<Item = Result<Message, WsError>> + Unpin,
131+
{
132+
fn poll_read(
133+
mut self: Pin<&mut Self>,
134+
cx: &mut Context<'_>,
135+
buf: &mut [u8],
136+
) -> Poll<io::Result<usize>> {
137+
let buf_len = buf.len();
138+
let bytes_to_read = match self.bytes {
139+
None => match Pin::new(&mut self.stream).poll_next(cx) {
140+
Poll::Pending => return Poll::Pending,
141+
Poll::Ready(None) => return Poll::Ready(Ok(0)),
142+
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(convert_err(e))),
143+
Poll::Ready(Some(Ok(msg))) => {
144+
let bytes = msg.into_data();
145+
if bytes.len() > buf_len {
146+
self.bytes.insert(bytes).split_to(buf_len)
147+
} else {
148+
bytes
149+
}
150+
}
151+
},
152+
Some(ref mut bytes) if bytes.len() > buf_len => bytes.split_to(buf_len),
153+
Some(ref mut bytes) => {
154+
let bytes = bytes.clone();
155+
self.bytes = None;
156+
bytes
157+
}
158+
};
159+
buf.copy_from_slice(&bytes_to_read);
160+
Poll::Ready(Ok(bytes_to_read.len()))
161+
}
162+
}
163+
100164
fn convert_err(e: WsError) -> io::Error {
101165
match e {
102166
WsError::Io(io) => io,

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ pub mod gio;
9292
#[cfg(feature = "tokio-runtime")]
9393
pub mod tokio;
9494

95-
#[cfg(feature = "futures-03-sink")]
9695
pub mod bytes;
96+
pub use bytes::ByteReader;
9797
#[cfg(feature = "futures-03-sink")]
9898
pub use bytes::ByteWriter;
9999

0 commit comments

Comments
 (0)