Skip to content

Commit a6f8f9a

Browse files
joshtriplettsdroege
authored andcommitted
ByteWriter: Support tokio::io::AsyncWrite if tokio-runtime is enabled
1 parent 11d4d33 commit a6f8f9a

File tree

1 file changed

+47
-14
lines changed

1 file changed

+47
-14
lines changed

src/bytes.rs

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,28 +31,39 @@ impl<S> ByteWriter<S> {
3131
}
3232
}
3333

34+
fn poll_write_helper<S>(
35+
mut s: Pin<&mut ByteWriter<S>>,
36+
cx: &mut Context<'_>,
37+
buf: &[u8],
38+
) -> Poll<io::Result<usize>>
39+
where
40+
S: Sink<Message, Error = WsError> + Unpin,
41+
{
42+
match Pin::new(&mut s.0).poll_ready(cx).map_err(convert_err) {
43+
Poll::Ready(Ok(())) => {}
44+
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
45+
Poll::Pending => return Poll::Pending,
46+
}
47+
let len = buf.len();
48+
let msg = Message::binary(buf.to_owned());
49+
Poll::Ready(
50+
Pin::new(&mut s.0)
51+
.start_send(msg)
52+
.map_err(convert_err)
53+
.map(|()| len),
54+
)
55+
}
56+
3457
impl<S> futures_io::AsyncWrite for ByteWriter<S>
3558
where
3659
S: Sink<Message, Error = WsError> + Unpin,
3760
{
3861
fn poll_write(
39-
mut self: Pin<&mut Self>,
62+
self: Pin<&mut Self>,
4063
cx: &mut Context<'_>,
4164
buf: &[u8],
4265
) -> Poll<io::Result<usize>> {
43-
match Pin::new(&mut self.0).poll_ready(cx).map_err(convert_err) {
44-
Poll::Ready(Ok(())) => {}
45-
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
46-
Poll::Pending => return Poll::Pending,
47-
}
48-
let len = buf.len();
49-
let msg = Message::binary(buf.to_owned());
50-
Poll::Ready(
51-
Pin::new(&mut self.0)
52-
.start_send(msg)
53-
.map_err(convert_err)
54-
.map(|()| len),
55-
)
66+
poll_write_helper(self, cx, buf)
5667
}
5768

5869
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
@@ -64,6 +75,28 @@ where
6475
}
6576
}
6677

78+
#[cfg(feature = "tokio-runtime")]
79+
impl<S> tokio::io::AsyncWrite for ByteWriter<S>
80+
where
81+
S: Sink<Message, Error = WsError> + Unpin,
82+
{
83+
fn poll_write(
84+
self: Pin<&mut Self>,
85+
cx: &mut Context<'_>,
86+
buf: &[u8],
87+
) -> Poll<io::Result<usize>> {
88+
poll_write_helper(self, cx, buf)
89+
}
90+
91+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
92+
Pin::new(&mut self.0).poll_flush(cx).map_err(convert_err)
93+
}
94+
95+
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
96+
Pin::new(&mut self.0).poll_close(cx).map_err(convert_err)
97+
}
98+
}
99+
67100
fn convert_err(e: WsError) -> io::Error {
68101
match e {
69102
WsError::Io(io) => io,

0 commit comments

Comments
 (0)