1
- //! Provides an abstraction to use `AsyncWrite` to write bytes to a `WebSocketStream`.
1
+ //! Provides abstractions to use `AsyncRead` and `AsyncWrite` with a `WebSocketStream`.
2
2
3
3
use std:: {
4
4
io,
5
5
pin:: Pin ,
6
6
task:: { Context , Poll } ,
7
7
} ;
8
8
9
- use futures_util :: Sink ;
9
+ use futures_core :: stream :: Stream ;
10
10
11
- use crate :: { Message , WsError } ;
11
+ use crate :: { tungstenite :: Bytes , Message , WsError } ;
12
12
13
13
/// Treat a `WebSocketStream` as an `AsyncWrite` implementation.
14
14
///
15
15
/// Every write sends a binary message. If you want to group writes together, consider wrapping
16
16
/// this with a `BufWriter`.
17
+ #[ cfg( feature = "futures-03-sink" ) ]
17
18
#[ derive( Debug ) ]
18
19
pub struct ByteWriter < S > ( S ) ;
19
20
21
+ #[ cfg( feature = "futures-03-sink" ) ]
20
22
impl < S > ByteWriter < S > {
21
23
/// Create a new `ByteWriter` from a `Sink` that accepts a WebSocket `Message`
22
24
#[ inline( always) ]
@@ -31,13 +33,14 @@ impl<S> ByteWriter<S> {
31
33
}
32
34
}
33
35
36
+ #[ cfg( feature = "futures-03-sink" ) ]
34
37
fn poll_write_helper < S > (
35
38
mut s : Pin < & mut ByteWriter < S > > ,
36
39
cx : & mut Context < ' _ > ,
37
40
buf : & [ u8 ] ,
38
41
) -> Poll < io:: Result < usize > >
39
42
where
40
- S : Sink < Message , Error = WsError > + Unpin ,
43
+ S : futures_util :: Sink < Message , Error = WsError > + Unpin ,
41
44
{
42
45
match Pin :: new ( & mut s. 0 ) . poll_ready ( cx) . map_err ( convert_err) {
43
46
Poll :: Ready ( Ok ( ( ) ) ) => { }
54
57
)
55
58
}
56
59
60
+ #[ cfg( feature = "futures-03-sink" ) ]
57
61
impl < S > futures_io:: AsyncWrite for ByteWriter < S >
58
62
where
59
- S : Sink < Message , Error = WsError > + Unpin ,
63
+ S : futures_util :: Sink < Message , Error = WsError > + Unpin ,
60
64
{
61
65
fn poll_write (
62
66
self : Pin < & mut Self > ,
@@ -75,10 +79,11 @@ where
75
79
}
76
80
}
77
81
82
+ #[ cfg( feature = "futures-03-sink" ) ]
78
83
#[ cfg( feature = "tokio-runtime" ) ]
79
84
impl < S > tokio:: io:: AsyncWrite for ByteWriter < S >
80
85
where
81
- S : Sink < Message , Error = WsError > + Unpin ,
86
+ S : futures_util :: Sink < Message , Error = WsError > + Unpin ,
82
87
{
83
88
fn poll_write (
84
89
self : Pin < & mut Self > ,
@@ -97,6 +102,65 @@ where
97
102
}
98
103
}
99
104
105
+ /// Treat a `WebSocketStream` as an `AsyncRead` implementation.
106
+ ///
107
+ /// This also works with any other `Stream` of `Message`, such as a `SplitStream`.
108
+ ///
109
+ /// Each read will only return data from one message. If you want to combine data from multiple
110
+ /// messages into one read, consider wrapping this in a `BufReader`.
111
+ #[ derive( Debug ) ]
112
+ pub struct ByteReader < S > {
113
+ stream : S ,
114
+ bytes : Option < Bytes > ,
115
+ }
116
+
117
+ impl < S > ByteReader < S > {
118
+ /// Create a new `ByteReader` from a `Stream` that returns a WebSocket `Message`
119
+ #[ inline( always) ]
120
+ pub fn new ( stream : S ) -> Self {
121
+ Self {
122
+ stream,
123
+ bytes : None ,
124
+ }
125
+ }
126
+ }
127
+
128
+ impl < S > futures_io:: AsyncRead for ByteReader < S >
129
+ where
130
+ S : Stream < Item = Result < Message , WsError > > + Unpin ,
131
+ {
132
+ fn poll_read (
133
+ mut self : Pin < & mut Self > ,
134
+ cx : & mut Context < ' _ > ,
135
+ buf : & mut [ u8 ] ,
136
+ ) -> Poll < io:: Result < usize > > {
137
+ let buf_len = buf. len ( ) ;
138
+ let bytes_to_read = match self . bytes {
139
+ None => match Pin :: new ( & mut self . stream ) . poll_next ( cx) {
140
+ Poll :: Pending => return Poll :: Pending ,
141
+ Poll :: Ready ( None ) => return Poll :: Ready ( Ok ( 0 ) ) ,
142
+ Poll :: Ready ( Some ( Err ( e) ) ) => return Poll :: Ready ( Err ( convert_err ( e) ) ) ,
143
+ Poll :: Ready ( Some ( Ok ( msg) ) ) => {
144
+ let bytes = msg. into_data ( ) ;
145
+ if bytes. len ( ) > buf_len {
146
+ self . bytes . insert ( bytes) . split_to ( buf_len)
147
+ } else {
148
+ bytes
149
+ }
150
+ }
151
+ } ,
152
+ Some ( ref mut bytes) if bytes. len ( ) > buf_len => bytes. split_to ( buf_len) ,
153
+ Some ( ref mut bytes) => {
154
+ let bytes = bytes. clone ( ) ;
155
+ self . bytes = None ;
156
+ bytes
157
+ }
158
+ } ;
159
+ buf. copy_from_slice ( & bytes_to_read) ;
160
+ Poll :: Ready ( Ok ( bytes_to_read. len ( ) ) )
161
+ }
162
+ }
163
+
100
164
fn convert_err ( e : WsError ) -> io:: Error {
101
165
match e {
102
166
WsError :: Io ( io) => io,
0 commit comments