1
+ use bytes:: { Buf , Bytes } ;
2
+ use std:: cmp;
1
3
use std:: io:: { Error , IoSlice } ;
2
4
use std:: pin:: Pin ;
3
5
use std:: task:: { Context , Poll } ;
4
- use tokio:: io:: { AsyncRead , AsyncWrite , ReadBuf } ;
6
+ use tokio:: io:: { AsyncRead , AsyncWrite , ReadBuf , ReadHalf , WriteHalf } ;
7
+ use tokio:: net:: tcp:: { OwnedReadHalf , OwnedWriteHalf } ;
5
8
use tokio:: net:: TcpStream ;
6
- use tokio_rustls:: client:: TlsStream ;
7
9
8
- pub enum TransportStream {
9
- Plain ( TcpStream ) ,
10
- Tls ( TlsStream < TcpStream > ) ,
10
+ pub struct TransportStream {
11
+ read : TransportReadHalf ,
12
+ write : TransportWriteHalf ,
13
+ }
14
+
15
+ impl TransportStream {
16
+ pub fn from_tcp ( tcp : TcpStream , read_buf : Bytes ) -> Self {
17
+ let ( read, write) = tcp. into_split ( ) ;
18
+ Self {
19
+ read : TransportReadHalf :: Plain ( read, read_buf) ,
20
+ write : TransportWriteHalf :: Plain ( write) ,
21
+ }
22
+ }
23
+
24
+ pub fn from_client_tls ( tls : tokio_rustls:: client:: TlsStream < TcpStream > , read_buf : Bytes ) -> Self {
25
+ let ( read, write) = tokio:: io:: split ( tls) ;
26
+ Self {
27
+ read : TransportReadHalf :: Tls ( read, read_buf) ,
28
+ write : TransportWriteHalf :: Tls ( write) ,
29
+ }
30
+ }
31
+
32
+ pub fn from_server_tls ( tls : tokio_rustls:: server:: TlsStream < TcpStream > , read_buf : Bytes ) -> Self {
33
+ let ( read, write) = tokio:: io:: split ( tls) ;
34
+ Self {
35
+ read : TransportReadHalf :: TlsSrv ( read, read_buf) ,
36
+ write : TransportWriteHalf :: TlsSrv ( write) ,
37
+ }
38
+ }
39
+
40
+ pub fn from ( self , read_buf : Bytes ) -> Self {
41
+ let mut read = self . read ;
42
+ * read. read_buf_mut ( ) = read_buf;
43
+ Self {
44
+ read,
45
+ write : self . write ,
46
+ }
47
+ }
48
+
49
+ pub fn into_split ( self ) -> ( TransportReadHalf , TransportWriteHalf ) {
50
+ ( self . read , self . write )
51
+ }
52
+ }
53
+
54
+ pub enum TransportReadHalf {
55
+ Plain ( OwnedReadHalf , Bytes ) ,
56
+ Tls ( ReadHalf < tokio_rustls:: client:: TlsStream < TcpStream > > , Bytes ) ,
57
+ TlsSrv ( ReadHalf < tokio_rustls:: server:: TlsStream < TcpStream > > , Bytes ) ,
58
+ }
59
+
60
+ impl TransportReadHalf {
61
+ fn read_buf_mut ( & mut self ) -> & mut Bytes {
62
+ match self {
63
+ Self :: Plain ( _, buf) => buf,
64
+ Self :: Tls ( _, buf) => buf,
65
+ Self :: TlsSrv ( _, buf) => buf,
66
+ }
67
+ }
68
+ }
69
+
70
+ pub enum TransportWriteHalf {
71
+ Plain ( OwnedWriteHalf ) ,
72
+ Tls ( WriteHalf < tokio_rustls:: client:: TlsStream < TcpStream > > ) ,
73
+ TlsSrv ( WriteHalf < tokio_rustls:: server:: TlsStream < TcpStream > > ) ,
11
74
}
12
75
13
76
impl AsyncRead for TransportStream {
14
77
fn poll_read ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & mut ReadBuf < ' _ > ) -> Poll < std:: io:: Result < ( ) > > {
15
- match self . get_mut ( ) {
16
- Self :: Plain ( cnx) => Pin :: new ( cnx) . poll_read ( cx, buf) ,
17
- Self :: Tls ( cnx) => Pin :: new ( cnx) . poll_read ( cx, buf) ,
18
- }
78
+ unsafe { self . map_unchecked_mut ( |s| & mut s. read ) . poll_read ( cx, buf) }
19
79
}
20
80
}
21
81
22
82
impl AsyncWrite for TransportStream {
83
+ fn poll_write ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & [ u8 ] ) -> Poll < Result < usize , Error > > {
84
+ unsafe { self . map_unchecked_mut ( |s| & mut s. write ) . poll_write ( cx, buf) }
85
+ }
86
+
87
+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , Error > > {
88
+ unsafe { self . map_unchecked_mut ( |s| & mut s. write ) . poll_flush ( cx) }
89
+ }
90
+
91
+ fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , Error > > {
92
+ unsafe { self . map_unchecked_mut ( |s| & mut s. write ) . poll_shutdown ( cx) }
93
+ }
94
+
95
+ fn poll_write_vectored (
96
+ self : Pin < & mut Self > ,
97
+ cx : & mut Context < ' _ > ,
98
+ bufs : & [ IoSlice < ' _ > ] ,
99
+ ) -> Poll < Result < usize , Error > > {
100
+ unsafe { self . map_unchecked_mut ( |s| & mut s. write ) . poll_write_vectored ( cx, bufs) }
101
+ }
102
+
103
+ fn is_write_vectored ( & self ) -> bool {
104
+ self . write . is_write_vectored ( )
105
+ }
106
+ }
107
+
108
+ impl AsyncRead for TransportReadHalf {
109
+ #[ inline]
110
+ fn poll_read ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & mut ReadBuf < ' _ > ) -> Poll < std:: io:: Result < ( ) > > {
111
+ let this = self . get_mut ( ) ;
112
+
113
+ let read_buf = this. read_buf_mut ( ) ;
114
+ if !read_buf. is_empty ( ) {
115
+ let copy_len = cmp:: min ( read_buf. len ( ) , buf. remaining ( ) ) ;
116
+ buf. put_slice ( & read_buf[ ..copy_len] ) ;
117
+ read_buf. advance ( copy_len) ;
118
+ return Poll :: Ready ( Ok ( ( ) ) ) ;
119
+ }
120
+
121
+ match this {
122
+ Self :: Plain ( cnx, _) => Pin :: new ( cnx) . poll_read ( cx, buf) ,
123
+ Self :: Tls ( cnx, _) => Pin :: new ( cnx) . poll_read ( cx, buf) ,
124
+ Self :: TlsSrv ( cnx, _) => Pin :: new ( cnx) . poll_read ( cx, buf) ,
125
+ }
126
+ }
127
+ }
128
+
129
+ impl AsyncWrite for TransportWriteHalf {
130
+ #[ inline]
23
131
fn poll_write ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & [ u8 ] ) -> Poll < Result < usize , Error > > {
24
132
match self . get_mut ( ) {
25
133
Self :: Plain ( cnx) => Pin :: new ( cnx) . poll_write ( cx, buf) ,
26
134
Self :: Tls ( cnx) => Pin :: new ( cnx) . poll_write ( cx, buf) ,
135
+ Self :: TlsSrv ( cnx) => Pin :: new ( cnx) . poll_write ( cx, buf) ,
27
136
}
28
137
}
29
138
139
+ #[ inline]
30
140
fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , Error > > {
31
141
match self . get_mut ( ) {
32
142
Self :: Plain ( cnx) => Pin :: new ( cnx) . poll_flush ( cx) ,
33
143
Self :: Tls ( cnx) => Pin :: new ( cnx) . poll_flush ( cx) ,
144
+ Self :: TlsSrv ( cnx) => Pin :: new ( cnx) . poll_flush ( cx) ,
34
145
}
35
146
}
36
147
148
+ #[ inline]
37
149
fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , Error > > {
38
150
match self . get_mut ( ) {
39
151
Self :: Plain ( cnx) => Pin :: new ( cnx) . poll_shutdown ( cx) ,
40
152
Self :: Tls ( cnx) => Pin :: new ( cnx) . poll_shutdown ( cx) ,
153
+ Self :: TlsSrv ( cnx) => Pin :: new ( cnx) . poll_shutdown ( cx) ,
41
154
}
42
155
}
43
156
157
+ #[ inline]
44
158
fn poll_write_vectored (
45
159
self : Pin < & mut Self > ,
46
160
cx : & mut Context < ' _ > ,
@@ -49,13 +163,16 @@ impl AsyncWrite for TransportStream {
49
163
match self . get_mut ( ) {
50
164
Self :: Plain ( cnx) => Pin :: new ( cnx) . poll_write_vectored ( cx, bufs) ,
51
165
Self :: Tls ( cnx) => Pin :: new ( cnx) . poll_write_vectored ( cx, bufs) ,
166
+ Self :: TlsSrv ( cnx) => Pin :: new ( cnx) . poll_write_vectored ( cx, bufs) ,
52
167
}
53
168
}
54
169
170
+ #[ inline]
55
171
fn is_write_vectored ( & self ) -> bool {
56
172
match & self {
57
173
Self :: Plain ( cnx) => cnx. is_write_vectored ( ) ,
58
174
Self :: Tls ( cnx) => cnx. is_write_vectored ( ) ,
175
+ Self :: TlsSrv ( cnx) => cnx. is_write_vectored ( ) ,
59
176
}
60
177
}
61
178
}
0 commit comments