Skip to content

Commit 4b43dfc

Browse files
committed
fix(#368): Update the restrictions before each request
instead of only once per conections, to avoid using a stale restriction config when multiple request arrive on the same tcp stream.
1 parent 6ae1eae commit 4b43dfc

File tree

4 files changed

+82
-84
lines changed

4 files changed

+82
-84
lines changed

Cargo.lock

+7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ futures-util = { version = "0.3.30" }
2121
hickory-resolver = { version = "0.24.1", features = ["tokio", "dns-over-https-rustls", "dns-over-rustls", "native-certs"] }
2222
ppp = { version = "2.2.0", features = [] }
2323
async-channel = { version = "2.3.1", features = [] }
24+
arc-swap = { version = "1.7.1", features = [] }
2425

2526
# For config file parsing
2627
regex = { version = "1.11.0", default-features = false, features = ["std", "perf"] }

src/restrictions/config_reloader.rs

+38-45
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,40 @@
11
use super::types::RestrictionsRules;
22
use crate::restrictions::config_reloader::RestrictionsRulesReloaderState::{Config, Static};
33
use anyhow::Context;
4+
use arc_swap::ArcSwap;
45
use log::trace;
56
use notify::{EventKind, RecommendedWatcher, Watcher};
67
use parking_lot::Mutex;
78
use std::path::PathBuf;
89
use std::sync::Arc;
910
use std::thread;
1011
use std::time::Duration;
11-
use tokio::sync::futures::Notified;
12-
use tokio::sync::Notify;
1312
use tracing::{error, info, warn};
1413

1514
struct ConfigReloaderState {
1615
fs_watcher: Mutex<RecommendedWatcher>,
1716
config_path: PathBuf,
18-
should_reload_config: Notify,
1917
}
2018

19+
#[derive(Clone)]
2120
enum RestrictionsRulesReloaderState {
22-
Static(Notify),
21+
Static,
2322
Config(Arc<ConfigReloaderState>),
2423
}
2524

2625
impl RestrictionsRulesReloaderState {
2726
fn fs_watcher(&self) -> &Mutex<RecommendedWatcher> {
2827
match self {
29-
Static(_) => unreachable!(),
28+
Static => unreachable!(),
3029
Config(this) => &this.fs_watcher,
3130
}
3231
}
3332
}
3433

34+
#[derive(Clone)]
3535
pub struct RestrictionsRulesReloader {
3636
state: RestrictionsRulesReloaderState,
37-
restrictions: Arc<RestrictionsRules>,
37+
restrictions: Arc<ArcSwap<RestrictionsRules>>,
3838
}
3939

4040
impl RestrictionsRulesReloader {
@@ -44,37 +44,40 @@ impl RestrictionsRulesReloader {
4444
config_path
4545
} else {
4646
return Ok(Self {
47-
state: Static(Notify::new()),
48-
restrictions: Arc::new(restrictions_rules),
47+
state: Static,
48+
restrictions: Arc::new(ArcSwap::from_pointee(restrictions_rules)),
4949
});
5050
};
51-
52-
let this = Arc::new(ConfigReloaderState {
53-
fs_watcher: Mutex::new(notify::recommended_watcher(|_| {})?),
54-
should_reload_config: Notify::new(),
55-
config_path,
56-
});
51+
let reloader = Self {
52+
state: Config(Arc::new(ConfigReloaderState {
53+
fs_watcher: Mutex::new(notify::recommended_watcher(|_| {})?),
54+
config_path,
55+
})),
56+
restrictions: Arc::new(ArcSwap::from_pointee(restrictions_rules)),
57+
};
5758

5859
info!("Starting to watch restriction config file for changes to reload them");
5960
let mut watcher = notify::recommended_watcher({
60-
let this = Config(this.clone());
61+
let reloader = reloader.clone();
6162

62-
move |event: notify::Result<notify::Event>| Self::handle_config_fs_event(&this, event)
63+
move |event: notify::Result<notify::Event>| Self::handle_config_fs_event(&reloader, event)
6364
})
6465
.with_context(|| "Cannot create restriction config watcher")?;
6566

66-
watcher.watch(&this.config_path, notify::RecursiveMode::NonRecursive)?;
67-
*this.fs_watcher.lock() = watcher;
67+
match &reloader.state {
68+
Static => {}
69+
Config(cfg) => {
70+
watcher.watch(&cfg.config_path, notify::RecursiveMode::NonRecursive)?;
71+
*cfg.fs_watcher.lock() = watcher
72+
}
73+
}
6874

69-
Ok(Self {
70-
state: Config(this),
71-
restrictions: Arc::new(restrictions_rules),
72-
})
75+
Ok(reloader)
7376
}
7477

75-
pub fn reload_restrictions_config(&mut self) {
78+
pub fn reload_restrictions_config(&self) {
7679
let restrictions = match &self.state {
77-
Static(_) => return,
80+
Static => return,
7881
Config(st) => match RestrictionsRules::from_config_file(&st.config_path) {
7982
Ok(restrictions) => {
8083
info!("Restrictions config file has been reloaded");
@@ -87,21 +90,14 @@ impl RestrictionsRulesReloader {
8790
},
8891
};
8992

90-
self.restrictions = Arc::new(restrictions);
93+
self.restrictions.store(Arc::new(restrictions));
9194
}
9295

93-
pub const fn restrictions_rules(&self) -> &Arc<RestrictionsRules> {
96+
pub const fn restrictions_rules(&self) -> &Arc<ArcSwap<RestrictionsRules>> {
9497
&self.restrictions
9598
}
9699

97-
pub fn reload_notifier(&self) -> Notified {
98-
match &self.state {
99-
Static(st) => st.notified(),
100-
Config(st) => st.should_reload_config.notified(),
101-
}
102-
}
103-
104-
fn try_rewatch_config(this: RestrictionsRulesReloaderState, path: PathBuf) {
100+
fn try_rewatch_config(this: RestrictionsRulesReloader, path: PathBuf) {
105101
thread::spawn(move || {
106102
while !path.exists() {
107103
warn!(
@@ -110,7 +106,7 @@ impl RestrictionsRulesReloader {
110106
);
111107
thread::sleep(Duration::from_secs(10));
112108
}
113-
let mut watcher = this.fs_watcher().lock();
109+
let mut watcher = this.state.fs_watcher().lock();
114110
let _ = watcher.unwatch(&path);
115111
let Ok(_) = watcher
116112
.watch(&path, notify::RecursiveMode::NonRecursive)
@@ -123,23 +119,20 @@ impl RestrictionsRulesReloader {
123119
};
124120
drop(watcher);
125121

126-
// Generate a fake event to force-reload the certificate
122+
// Generate a fake event to force-reload the config
127123
let event = notify::Event {
128124
kind: EventKind::Create(notify::event::CreateKind::Any),
129125
paths: vec![path],
130126
attrs: Default::default(),
131127
};
132128

133-
match &this {
134-
Static(_) => Self::handle_config_fs_event(&this, Ok(event)),
135-
Config(_) => Self::handle_config_fs_event(&this, Ok(event)),
136-
}
129+
Self::handle_config_fs_event(&this, Ok(event))
137130
});
138131
}
139132

140-
fn handle_config_fs_event(this: &RestrictionsRulesReloaderState, event: notify::Result<notify::Event>) {
141-
let this = match this {
142-
Static(_) => return,
133+
fn handle_config_fs_event(reloader: &RestrictionsRulesReloader, event: notify::Result<notify::Event>) {
134+
let this = match &reloader.state {
135+
Static => return,
143136
Config(st) => st,
144137
};
145138

@@ -159,11 +152,11 @@ impl RestrictionsRulesReloader {
159152
if let Some(path) = event.paths.iter().find(|p| p.ends_with(&this.config_path)) {
160153
match event.kind {
161154
EventKind::Create(_) | EventKind::Modify(_) => {
162-
this.should_reload_config.notify_one();
155+
reloader.reload_restrictions_config();
163156
}
164157
EventKind::Remove(_) => {
165158
warn!("Restriction config file has been removed, trying to re-set a watch for it");
166-
Self::try_rewatch_config(Config(this.clone()), path.to_path_buf());
159+
Self::try_rewatch_config(reloader.clone(), path.to_path_buf());
167160
}
168161
EventKind::Access(_) | EventKind::Other | EventKind::Any => {
169162
trace!("Ignoring event {:?}", event);

src/tunnel/server/server.rs

+36-39
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,23 @@ use http_body_util::Either;
44
use std::fmt;
55
use std::fmt::{Debug, Formatter};
66

7-
use bytes::Bytes;
8-
use http_body_util::combinators::BoxBody;
9-
use std::net::SocketAddr;
10-
use std::path::PathBuf;
11-
use std::pin::Pin;
12-
use std::sync::{Arc, LazyLock};
13-
use std::time::Duration;
14-
157
use crate::protocols;
168
use crate::tunnel::{try_to_sock_addr, LocalProtocol, RemoteAddr};
9+
use arc_swap::ArcSwap;
10+
use bytes::Bytes;
11+
use http_body_util::combinators::BoxBody;
1712
use hyper::body::Incoming;
1813
use hyper::server::conn::{http1, http2};
1914
use hyper::service::service_fn;
2015
use hyper::{http, Request, Response, StatusCode, Version};
2116
use hyper_util::rt::{TokioExecutor, TokioTimer};
2217
use parking_lot::Mutex;
2318
use socket2::SockRef;
19+
use std::net::SocketAddr;
20+
use std::path::PathBuf;
21+
use std::pin::Pin;
22+
use std::sync::{Arc, LazyLock};
23+
use std::time::Duration;
2424

2525
use crate::protocols::dns::DnsResolver;
2626
use crate::protocols::tls;
@@ -37,7 +37,6 @@ use crate::tunnel::server::utils::{
3737
use crate::tunnel::tls_reloader::TlsReloader;
3838
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
3939
use tokio::net::TcpListener;
40-
use tokio::select;
4140
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
4241
use tokio_rustls::TlsAcceptor;
4342
use tracing::{error, info, span, warn, Instrument, Level, Span};
@@ -285,29 +284,41 @@ impl WsServer {
285284

286285
// setup upgrade request handler
287286
let mk_websocket_upgrade_fn = |server: WsServer,
288-
restrictions: Arc<RestrictionsRules>,
287+
restrictions: Arc<ArcSwap<RestrictionsRules>>,
289288
restrict_path: Option<String>,
290289
client_addr: SocketAddr| {
291290
move |req: Request<Incoming>| {
292-
ws_server_upgrade(server.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req)
293-
.map::<anyhow::Result<_>, _>(Ok)
294-
.instrument(mk_span())
291+
ws_server_upgrade(
292+
server.clone(),
293+
restrictions.load().clone(),
294+
restrict_path.clone(),
295+
client_addr,
296+
req,
297+
)
298+
.map::<anyhow::Result<_>, _>(Ok)
299+
.instrument(mk_span())
295300
}
296301
};
297302

298303
let mk_http_upgrade_fn = |server: WsServer,
299-
restrictions: Arc<RestrictionsRules>,
304+
restrictions: Arc<ArcSwap<RestrictionsRules>>,
300305
restrict_path: Option<String>,
301306
client_addr: SocketAddr| {
302307
move |req: Request<Incoming>| {
303-
http_server_upgrade(server.clone(), restrictions.clone(), restrict_path.clone(), client_addr, req)
304-
.map::<anyhow::Result<_>, _>(Ok)
305-
.instrument(mk_span())
308+
http_server_upgrade(
309+
server.clone(),
310+
restrictions.load().clone(),
311+
restrict_path.clone(),
312+
client_addr,
313+
req,
314+
)
315+
.map::<anyhow::Result<_>, _>(Ok)
316+
.instrument(mk_span())
306317
}
307318
};
308319

309320
let mk_auto_upgrade_fn = |server: WsServer,
310-
restrictions: Arc<RestrictionsRules>,
321+
restrictions: Arc<ArcSwap<RestrictionsRules>>,
311322
restrict_path: Option<String>,
312323
client_addr: SocketAddr| {
313324
move |req: Request<Incoming>| {
@@ -316,13 +327,13 @@ impl WsServer {
316327
let restrict_path = restrict_path.clone();
317328
async move {
318329
if fastwebsockets::upgrade::is_upgrade_request(&req) {
319-
ws_server_upgrade(server.clone(), restrictions.clone(), restrict_path, client_addr, req)
330+
ws_server_upgrade(server.clone(), restrictions.load().clone(), restrict_path, client_addr, req)
320331
.map::<anyhow::Result<_>, _>(Ok)
321332
.await
322333
} else if req.version() == Version::HTTP_2 {
323334
http_server_upgrade(
324335
server.clone(),
325-
restrictions.clone(),
336+
restrictions.load().clone(),
326337
restrict_path.clone(),
327338
client_addr,
328339
req,
@@ -357,25 +368,11 @@ impl WsServer {
357368
};
358369

359370
// Bind server and run forever to serve incoming connections.
360-
let mut restrictions = RestrictionsRulesReloader::new(restrictions, self.config.restriction_config.clone())?;
361-
let mut await_config_reload = Box::pin(restrictions.reload_notifier());
371+
let restrictions = RestrictionsRulesReloader::new(restrictions, self.config.restriction_config.clone())?;
362372
let listener = TcpListener::bind(&self.config.bind).await?;
363373

364374
loop {
365-
let cnx = select! {
366-
biased;
367-
368-
_ = &mut await_config_reload => {
369-
drop(await_config_reload);
370-
restrictions.reload_restrictions_config();
371-
await_config_reload = Box::pin(restrictions.reload_notifier());
372-
continue;
373-
},
374-
375-
cnx = listener.accept() => { cnx }
376-
};
377-
378-
let (stream, peer_addr) = match cnx {
375+
let (stream, peer_addr) = match listener.accept().await {
379376
Ok(ret) => ret,
380377
Err(err) => {
381378
warn!("Error while accepting connection {:?}", err);
@@ -423,7 +420,7 @@ impl WsServer {
423420
}
424421

425422
let http_upgrade_fn =
426-
mk_http_upgrade_fn(server, restrictions.clone(), restrict_path, peer_addr);
423+
mk_http_upgrade_fn(server, restrictions, restrict_path, peer_addr);
427424
let con_fut = conn_builder.serve_connection(tls_stream, service_fn(http_upgrade_fn));
428425
if let Err(e) = con_fut.await {
429426
error!("Error while upgrading cnx to http: {:?}", e);
@@ -432,7 +429,7 @@ impl WsServer {
432429
// websocket
433430
_ => {
434431
let websocket_upgrade_fn =
435-
mk_websocket_upgrade_fn(server, restrictions.clone(), restrict_path, peer_addr);
432+
mk_websocket_upgrade_fn(server, restrictions, restrict_path, peer_addr);
436433
let conn_fut = http1::Builder::new()
437434
.timer(TokioTimer::new())
438435
// https://github.com/erebe/wstunnel/issues/358
@@ -460,7 +457,7 @@ impl WsServer {
460457
conn_fut.http2().keep_alive_interval(ping);
461458
}
462459

463-
let websocket_upgrade_fn = mk_auto_upgrade_fn(server, restrictions.clone(), None, peer_addr);
460+
let websocket_upgrade_fn = mk_auto_upgrade_fn(server, restrictions, None, peer_addr);
464461
let upgradable =
465462
conn_fut.serve_connection_with_upgrades(stream, service_fn(websocket_upgrade_fn));
466463

0 commit comments

Comments
 (0)