Skip to content

Commit 2bc7c94

Browse files
committed
fix(#360): data not flushed immediatly on reverse tunnel
1 parent e2d1413 commit 2bc7c94

File tree

4 files changed

+186
-33
lines changed

4 files changed

+186
-33
lines changed

src/tunnel/client/cnx_pool.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use crate::tunnel::client::l4_transport_stream::TransportStream;
44
use crate::tunnel::client::WsClientConfig;
55
use async_trait::async_trait;
66
use bb8::ManageConnection;
7+
use bytes::Bytes;
78
use std::ops::Deref;
89
use std::sync::Arc;
910
use tracing::instrument;
@@ -58,9 +59,9 @@ impl ManageConnection for WsConnection {
5859

5960
if self.remote_addr.tls().is_some() {
6061
let tls_stream = tls::connect(self, tcp_stream).await?;
61-
Ok(Some(TransportStream::Tls(tls_stream)))
62+
Ok(Some(TransportStream::from_client_tls(tls_stream, Bytes::default())))
6263
} else {
63-
Ok(Some(TransportStream::Plain(tcp_stream)))
64+
Ok(Some(TransportStream::from_tcp(tcp_stream, Bytes::default())))
6465
}
6566
}
6667

+126-9
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,160 @@
1+
use bytes::{Buf, Bytes};
2+
use std::cmp;
13
use std::io::{Error, IoSlice};
24
use std::pin::Pin;
35
use std::task::{Context, Poll};
4-
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
6+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
7+
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
58
use tokio::net::TcpStream;
6-
use tokio_rustls::client::TlsStream;
79

8-
pub enum TransportStream {
9-
Plain(TcpStream),
10-
Tls(TlsStream<TcpStream>),
10+
pub struct TransportStream {
11+
read: TransportReadHalf,
12+
write: TransportWriteHalf,
13+
}
14+
15+
impl TransportStream {
16+
pub fn from_tcp(tcp: TcpStream, read_buf: Bytes) -> Self {
17+
let (read, write) = tcp.into_split();
18+
Self {
19+
read: TransportReadHalf::Plain(read, read_buf),
20+
write: TransportWriteHalf::Plain(write),
21+
}
22+
}
23+
24+
pub fn from_client_tls(tls: tokio_rustls::client::TlsStream<TcpStream>, read_buf: Bytes) -> Self {
25+
let (read, write) = tokio::io::split(tls);
26+
Self {
27+
read: TransportReadHalf::Tls(read, read_buf),
28+
write: TransportWriteHalf::Tls(write),
29+
}
30+
}
31+
32+
pub fn from_server_tls(tls: tokio_rustls::server::TlsStream<TcpStream>, read_buf: Bytes) -> Self {
33+
let (read, write) = tokio::io::split(tls);
34+
Self {
35+
read: TransportReadHalf::TlsSrv(read, read_buf),
36+
write: TransportWriteHalf::TlsSrv(write),
37+
}
38+
}
39+
40+
pub fn from(self, read_buf: Bytes) -> Self {
41+
let mut read = self.read;
42+
*read.read_buf_mut() = read_buf;
43+
Self {
44+
read,
45+
write: self.write,
46+
}
47+
}
48+
49+
pub fn into_split(self) -> (TransportReadHalf, TransportWriteHalf) {
50+
(self.read, self.write)
51+
}
52+
}
53+
54+
pub enum TransportReadHalf {
55+
Plain(OwnedReadHalf, Bytes),
56+
Tls(ReadHalf<tokio_rustls::client::TlsStream<TcpStream>>, Bytes),
57+
TlsSrv(ReadHalf<tokio_rustls::server::TlsStream<TcpStream>>, Bytes),
58+
}
59+
60+
impl TransportReadHalf {
61+
fn read_buf_mut(&mut self) -> &mut Bytes {
62+
match self {
63+
Self::Plain(_, buf) => buf,
64+
Self::Tls(_, buf) => buf,
65+
Self::TlsSrv(_, buf) => buf,
66+
}
67+
}
68+
}
69+
70+
pub enum TransportWriteHalf {
71+
Plain(OwnedWriteHalf),
72+
Tls(WriteHalf<tokio_rustls::client::TlsStream<TcpStream>>),
73+
TlsSrv(WriteHalf<tokio_rustls::server::TlsStream<TcpStream>>),
1174
}
1275

1376
impl AsyncRead for TransportStream {
1477
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
15-
match self.get_mut() {
16-
Self::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf),
17-
Self::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf),
18-
}
78+
unsafe { self.map_unchecked_mut(|s| &mut s.read).poll_read(cx, buf) }
1979
}
2080
}
2181

2282
impl AsyncWrite for TransportStream {
83+
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
84+
unsafe { self.map_unchecked_mut(|s| &mut s.write).poll_write(cx, buf) }
85+
}
86+
87+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
88+
unsafe { self.map_unchecked_mut(|s| &mut s.write).poll_flush(cx) }
89+
}
90+
91+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
92+
unsafe { self.map_unchecked_mut(|s| &mut s.write).poll_shutdown(cx) }
93+
}
94+
95+
fn poll_write_vectored(
96+
self: Pin<&mut Self>,
97+
cx: &mut Context<'_>,
98+
bufs: &[IoSlice<'_>],
99+
) -> Poll<Result<usize, Error>> {
100+
unsafe { self.map_unchecked_mut(|s| &mut s.write).poll_write_vectored(cx, bufs) }
101+
}
102+
103+
fn is_write_vectored(&self) -> bool {
104+
self.write.is_write_vectored()
105+
}
106+
}
107+
108+
impl AsyncRead for TransportReadHalf {
109+
#[inline]
110+
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
111+
let this = self.get_mut();
112+
113+
let read_buf = this.read_buf_mut();
114+
if !read_buf.is_empty() {
115+
let copy_len = cmp::min(read_buf.len(), buf.remaining());
116+
buf.put_slice(&read_buf[..copy_len]);
117+
read_buf.advance(copy_len);
118+
return Poll::Ready(Ok(()));
119+
}
120+
121+
match this {
122+
Self::Plain(cnx, _) => Pin::new(cnx).poll_read(cx, buf),
123+
Self::Tls(cnx, _) => Pin::new(cnx).poll_read(cx, buf),
124+
Self::TlsSrv(cnx, _) => Pin::new(cnx).poll_read(cx, buf),
125+
}
126+
}
127+
}
128+
129+
impl AsyncWrite for TransportWriteHalf {
130+
#[inline]
23131
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
24132
match self.get_mut() {
25133
Self::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf),
26134
Self::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf),
135+
Self::TlsSrv(cnx) => Pin::new(cnx).poll_write(cx, buf),
27136
}
28137
}
29138

139+
#[inline]
30140
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
31141
match self.get_mut() {
32142
Self::Plain(cnx) => Pin::new(cnx).poll_flush(cx),
33143
Self::Tls(cnx) => Pin::new(cnx).poll_flush(cx),
144+
Self::TlsSrv(cnx) => Pin::new(cnx).poll_flush(cx),
34145
}
35146
}
36147

148+
#[inline]
37149
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
38150
match self.get_mut() {
39151
Self::Plain(cnx) => Pin::new(cnx).poll_shutdown(cx),
40152
Self::Tls(cnx) => Pin::new(cnx).poll_shutdown(cx),
153+
Self::TlsSrv(cnx) => Pin::new(cnx).poll_shutdown(cx),
41154
}
42155
}
43156

157+
#[inline]
44158
fn poll_write_vectored(
45159
self: Pin<&mut Self>,
46160
cx: &mut Context<'_>,
@@ -49,13 +163,16 @@ impl AsyncWrite for TransportStream {
49163
match self.get_mut() {
50164
Self::Plain(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
51165
Self::Tls(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
166+
Self::TlsSrv(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
52167
}
53168
}
54169

170+
#[inline]
55171
fn is_write_vectored(&self) -> bool {
56172
match &self {
57173
Self::Plain(cnx) => cnx.is_write_vectored(),
58174
Self::Tls(cnx) => cnx.is_write_vectored(),
175+
Self::TlsSrv(cnx) => cnx.is_write_vectored(),
59176
}
60177
}
61178
}

src/tunnel/server/handler_websocket.rs

+6-10
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ use crate::restrictions::types::RestrictionsRules;
22
use crate::tunnel::server::utils::{bad_request, inject_cookie};
33
use crate::tunnel::server::WsServer;
44
use crate::tunnel::transport;
5-
use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite};
5+
use crate::tunnel::transport::websocket::mk_websocket_tunnel;
66
use bytes::Bytes;
7+
use fastwebsockets::Role;
78
use http_body_util::combinators::BoxBody;
89
use http_body_util::Either;
910
use hyper::body::Incoming;
@@ -46,31 +47,26 @@ pub(super) async fn ws_server_upgrade(
4647
tokio::spawn(
4748
async move {
4849
let (ws_rx, ws_tx) = match fut.await {
49-
Ok(mut ws) => {
50-
ws.set_auto_pong(false);
51-
ws.set_auto_close(false);
52-
ws.set_auto_apply_mask(mask_frame);
53-
ws.split(tokio::io::split)
54-
}
50+
Ok(ws) => mk_websocket_tunnel(ws, Role::Server, mask_frame)?,
5551
Err(err) => {
5652
error!("Error during http upgrade request: {:?}", err);
57-
return;
53+
return Err(anyhow::Error::from(err));
5854
}
5955
};
6056
let (close_tx, close_rx) = oneshot::channel::<()>();
6157

62-
let (ws_rx, pending_ops) = WebsocketTunnelRead::new(ws_rx);
6358
tokio::task::spawn(
6459
transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).instrument(Span::current()),
6560
);
6661

6762
let _ = transport::io::propagate_local_to_remote(
6863
local_rx,
69-
WebsocketTunnelWrite::new(ws_tx, pending_ops),
64+
ws_tx,
7065
close_tx,
7166
server.config.websocket_ping_frequency,
7267
)
7368
.await;
69+
Ok(())
7470
}
7571
.instrument(Span::current()),
7672
);

src/tunnel/transport/websocket.rs

+51-12
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use super::io::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
2+
use crate::tunnel::client::l4_transport_stream::{TransportReadHalf, TransportStream, TransportWriteHalf};
23
use crate::tunnel::client::WsClient;
34
use crate::tunnel::transport::headers_from_file;
45
use crate::tunnel::transport::jwt::{tunnel_to_jwt_token, JWT_HEADER_PREFIX};
56
use crate::tunnel::RemoteAddr;
67
use anyhow::{anyhow, Context};
78
use bytes::{Bytes, BytesMut};
8-
use fastwebsockets::{CloseCode, Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};
9+
use fastwebsockets::{CloseCode, Frame, OpCode, Payload, Role, WebSocket, WebSocketRead, WebSocketWrite};
910
use http_body_util::Empty;
1011
use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE};
1112
use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY};
@@ -21,14 +22,16 @@ use std::ops::DerefMut;
2122
use std::sync::atomic::AtomicUsize;
2223
use std::sync::atomic::Ordering::Relaxed;
2324
use std::sync::Arc;
24-
use tokio::io::{AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
25+
use tokio::io::{AsyncWrite, AsyncWriteExt};
26+
use tokio::net::TcpStream;
2527
use tokio::sync::mpsc::{Receiver, Sender};
2628
use tokio::sync::Notify;
29+
use tokio_rustls::server::TlsStream;
2730
use tracing::trace;
2831
use uuid::Uuid;
2932

3033
pub struct WebsocketTunnelWrite {
31-
inner: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>,
34+
inner: WebSocketWrite<TransportWriteHalf>,
3235
buf: BytesMut,
3336
pending_operations: Receiver<Frame<'static>>,
3437
pending_ops_notify: Arc<Notify>,
@@ -37,7 +40,7 @@ pub struct WebsocketTunnelWrite {
3740

3841
impl WebsocketTunnelWrite {
3942
pub fn new(
40-
ws: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>,
43+
ws: WebSocketWrite<TransportWriteHalf>,
4144
(pending_operations, notify): (Receiver<Frame<'static>>, Arc<Notify>),
4245
) -> Self {
4346
Self {
@@ -146,13 +149,13 @@ impl TunnelWrite for WebsocketTunnelWrite {
146149
}
147150

148151
pub struct WebsocketTunnelRead {
149-
inner: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>,
152+
inner: WebSocketRead<TransportReadHalf>,
150153
pending_operations: Sender<Frame<'static>>,
151154
notify_pending_ops: Arc<Notify>,
152155
}
153156

154157
impl WebsocketTunnelRead {
155-
pub fn new(ws: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>) -> (Self, (Receiver<Frame<'static>>, Arc<Notify>)) {
158+
pub fn new(ws: WebSocketRead<TransportReadHalf>) -> (Self, (Receiver<Frame<'static>>, Arc<Notify>)) {
156159
let (tx, rx) = tokio::sync::mpsc::channel(10);
157160
let notify = Arc::new(Notify::new());
158161
(
@@ -278,16 +281,52 @@ pub async fn connect(
278281
})?;
279282
debug!("with HTTP upgrade request {:?}", req);
280283
let transport = pooled_cnx.deref_mut().take().unwrap();
281-
let (mut ws, response) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, transport)
284+
let (ws, response) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, transport)
282285
.await
283286
.with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?;
284287

285-
ws.set_auto_apply_mask(client_cfg.websocket_mask_frame);
286-
ws.set_auto_close(false);
287-
ws.set_auto_pong(false);
288+
let (ws_rx, ws_tx) = mk_websocket_tunnel(ws, Role::Client, client_cfg.websocket_mask_frame)?;
289+
Ok((ws_rx, ws_tx, response.into_parts().0))
290+
}
288291

289-
let (ws_rx, ws_tx) = ws.split(tokio::io::split);
292+
pub fn mk_websocket_tunnel(
293+
ws: WebSocket<TokioIo<Upgraded>>,
294+
role: Role,
295+
mask_frame: bool,
296+
) -> anyhow::Result<(WebsocketTunnelRead, WebsocketTunnelWrite)> {
297+
let mut ws = match role {
298+
Role::Client => {
299+
let stream = ws
300+
.into_inner()
301+
.into_inner()
302+
.downcast::<TokioIo<TransportStream>>()
303+
.map_err(|_| anyhow!("cannot downcast websocket client stream"))?;
304+
let transport = TransportStream::from(stream.io.into_inner(), stream.read_buf);
305+
WebSocket::after_handshake(transport, role)
306+
}
307+
Role::Server => {
308+
let upgraded = ws.into_inner().into_inner();
309+
match upgraded.downcast::<TokioIo<TlsStream<TcpStream>>>() {
310+
Ok(stream) => {
311+
let transport = TransportStream::from_server_tls(stream.io.into_inner(), stream.read_buf);
312+
WebSocket::after_handshake(transport, role)
313+
}
314+
Err(upgraded) => {
315+
let stream = upgraded
316+
.downcast::<TokioIo<TcpStream>>()
317+
.map_err(|_| anyhow!("cannot downcast websocket server stream"))?;
318+
let transport = TransportStream::from_tcp(stream.io.into_inner(), stream.read_buf);
319+
WebSocket::after_handshake(transport, role)
320+
}
321+
}
322+
}
323+
};
324+
325+
ws.set_auto_pong(false);
326+
ws.set_auto_close(false);
327+
ws.set_auto_apply_mask(mask_frame);
328+
let (ws_rx, ws_tx) = ws.split(|x| x.into_split());
290329

291330
let (ws_rx, pending_ops) = WebsocketTunnelRead::new(ws_rx);
292-
Ok((ws_rx, WebsocketTunnelWrite::new(ws_tx, pending_ops), response.into_parts().0))
331+
Ok((ws_rx, WebsocketTunnelWrite::new(ws_tx, pending_ops)))
293332
}

0 commit comments

Comments
 (0)