Skip to content

Commit 4a2cb34

Browse files
joshtriplettsdroege
authored andcommitted
Add tokio support for ByteReader
Factor out a helper for the common portions. This also fixes a bug in the async-std version when doing a short read.
1 parent 2ac7c70 commit 4a2cb34

File tree

3 files changed

+105
-24
lines changed

3 files changed

+105
-24
lines changed

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,10 @@ required-features = ["gio-runtime"]
180180
name = "gio-echo-server"
181181
required-features = ["gio-runtime"]
182182

183+
[[example]]
184+
name = "tokio-client-bytes"
185+
required-features = ["tokio-runtime"]
186+
183187
[[example]]
184188
name = "tokio-echo"
185189
required-features = ["tokio-runtime"]

examples/tokio-client-bytes.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//! A simple example of hooking up stdin/stdout to a WebSocket stream using ByteStream.
2+
//!
3+
//! This example will connect to a server specified in the argument list and
4+
//! then forward all data read on stdin to the server, printing out all data
5+
//! received on stdout.
6+
//!
7+
//! Note that this is not currently optimized for performance, especially around
8+
//! buffer management. Rather it's intended to show an example of working with a
9+
//! client.
10+
//!
11+
//! You can use this example together with the `server` example.
12+
13+
use std::env;
14+
15+
use futures::StreamExt;
16+
17+
use async_tungstenite::tokio::connect_async;
18+
use async_tungstenite::{ByteReader, ByteWriter};
19+
use tokio::io;
20+
use tokio::task;
21+
22+
async fn run() {
23+
let connect_addr = env::args()
24+
.nth(1)
25+
.unwrap_or_else(|| panic!("this program requires at least one argument"));
26+
27+
let (ws_stream, _) = connect_async(&connect_addr)
28+
.await
29+
.expect("Failed to connect");
30+
println!("WebSocket handshake has been successfully completed");
31+
32+
let (write, read) = ws_stream.split();
33+
let mut byte_writer = ByteWriter::new(write);
34+
let mut byte_reader = ByteReader::new(read);
35+
let stdin_to_ws =
36+
task::spawn(async move { io::copy(&mut io::stdin(), &mut byte_writer).await });
37+
let ws_to_stdout =
38+
task::spawn(async move { io::copy(&mut byte_reader, &mut io::stdout()).await });
39+
stdin_to_ws.await.unwrap().unwrap();
40+
ws_to_stdout.await.unwrap().unwrap();
41+
}
42+
43+
fn main() {
44+
let rt = tokio::runtime::Runtime::new().expect("runtime");
45+
rt.block_on(run())
46+
}

src/bytes.rs

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -125,39 +125,70 @@ impl<S> ByteReader<S> {
125125
}
126126
}
127127

128+
fn poll_read_helper<S>(
129+
mut s: Pin<&mut ByteReader<S>>,
130+
cx: &mut Context<'_>,
131+
buf_len: usize,
132+
) -> Poll<io::Result<Option<Bytes>>>
133+
where
134+
S: Stream<Item = Result<Message, WsError>> + Unpin,
135+
{
136+
Poll::Ready(Ok(Some(match s.bytes {
137+
None => match Pin::new(&mut s.stream).poll_next(cx) {
138+
Poll::Pending => return Poll::Pending,
139+
Poll::Ready(None) => return Poll::Ready(Ok(None)),
140+
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(convert_err(e))),
141+
Poll::Ready(Some(Ok(msg))) => {
142+
let bytes = msg.into_data();
143+
if bytes.len() > buf_len {
144+
s.bytes.insert(bytes).split_to(buf_len)
145+
} else {
146+
bytes
147+
}
148+
}
149+
},
150+
Some(ref mut bytes) if bytes.len() > buf_len => bytes.split_to(buf_len),
151+
Some(ref mut bytes) => {
152+
let bytes = bytes.clone();
153+
s.bytes = None;
154+
bytes
155+
}
156+
})))
157+
}
158+
128159
impl<S> futures_io::AsyncRead for ByteReader<S>
129160
where
130161
S: Stream<Item = Result<Message, WsError>> + Unpin,
131162
{
132163
fn poll_read(
133-
mut self: Pin<&mut Self>,
164+
self: Pin<&mut Self>,
134165
cx: &mut Context<'_>,
135166
buf: &mut [u8],
136167
) -> Poll<io::Result<usize>> {
137-
let buf_len = buf.len();
138-
let bytes_to_read = match self.bytes {
139-
None => match Pin::new(&mut self.stream).poll_next(cx) {
140-
Poll::Pending => return Poll::Pending,
141-
Poll::Ready(None) => return Poll::Ready(Ok(0)),
142-
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(convert_err(e))),
143-
Poll::Ready(Some(Ok(msg))) => {
144-
let bytes = msg.into_data();
145-
if bytes.len() > buf_len {
146-
self.bytes.insert(bytes).split_to(buf_len)
147-
} else {
148-
bytes
149-
}
150-
}
151-
},
152-
Some(ref mut bytes) if bytes.len() > buf_len => bytes.split_to(buf_len),
153-
Some(ref mut bytes) => {
154-
let bytes = bytes.clone();
155-
self.bytes = None;
156-
bytes
168+
poll_read_helper(self, cx, buf.len()).map_ok(|bytes| {
169+
bytes.map_or(0, |bytes| {
170+
buf[..bytes.len()].copy_from_slice(&bytes);
171+
bytes.len()
172+
})
173+
})
174+
}
175+
}
176+
177+
#[cfg(feature = "tokio-runtime")]
178+
impl<S> tokio::io::AsyncRead for ByteReader<S>
179+
where
180+
S: Stream<Item = Result<Message, WsError>> + Unpin,
181+
{
182+
fn poll_read(
183+
self: Pin<&mut Self>,
184+
cx: &mut Context<'_>,
185+
buf: &mut tokio::io::ReadBuf,
186+
) -> Poll<io::Result<()>> {
187+
poll_read_helper(self, cx, buf.remaining()).map_ok(|bytes| {
188+
if let Some(ref bytes) = bytes {
189+
buf.put_slice(bytes);
157190
}
158-
};
159-
buf.copy_from_slice(&bytes_to_read);
160-
Poll::Ready(Ok(bytes_to_read.len()))
191+
})
161192
}
162193
}
163194

0 commit comments

Comments
 (0)