@@ -28,7 +28,7 @@ mod userfs;
28
28
use std:: convert:: TryFrom ;
29
29
use std:: io;
30
30
use std:: net:: { SocketAddr , ToSocketAddrs } ;
31
- #[ cfg( feature = "tls" ) ]
31
+ #[ cfg( all ( not ( windows ) , feature = "tls" ) ) ]
32
32
use std:: os:: unix:: io:: { FromRawFd , AsRawFd } ;
33
33
use std:: process:: exit;
34
34
use std:: sync:: Arc ;
@@ -41,9 +41,9 @@ use hyper::{
41
41
server:: conn:: { AddrIncoming , AddrStream } ,
42
42
service:: { make_service_fn, service_fn} ,
43
43
} ;
44
- #[ cfg( feature = "tls" ) ]
44
+ #[ cfg( all ( not ( windows ) , feature = "tls" ) ) ]
45
45
use tls_listener:: TlsListener ;
46
- #[ cfg( feature = "tls" ) ]
46
+ #[ cfg( all ( not ( windows ) , feature = "tls" ) ) ]
47
47
use tokio_rustls:: server:: TlsStream ;
48
48
use webdav_handler:: { davpath:: DavPath , DavConfig , DavHandler , DavMethod , DavMethodSet } ;
49
49
use webdav_handler:: { fakels:: FakeLs , fs:: DavFileSystem , ls:: DavLockSystem } ;
@@ -52,8 +52,8 @@ use crate::config::{AcctType, Auth, CaseInsensitive, Handler, Location, OnNotfou
52
52
use crate :: rootfs:: RootFs ;
53
53
use crate :: router:: MatchedRoute ;
54
54
use crate :: suid:: proc_switch_ugid;
55
- #[ cfg( feature = "tls" ) ]
56
- use crate :: tls:: tls_config ;
55
+ #[ cfg( all ( not ( windows ) , feature = "tls" ) ) ]
56
+ use crate :: tls:: tls_acceptor ;
57
57
use crate :: userfs:: UserFs ;
58
58
59
59
static PROGNAME : & ' static str = "webdav-server" ;
@@ -105,48 +105,41 @@ impl Server {
105
105
None => return Ok ( None ) ,
106
106
} ;
107
107
108
- #[ cfg( target_os = "windows" ) ]
108
+ #[ cfg( target_os = "windows" ) ]
109
109
panic ! ( ) ;
110
- #[ cfg( not( target_os = "windows" ) ) ]
110
+ #[ cfg( not( target_os = "windows" ) ) ]
111
111
{
112
- // check if user exists.
113
- let pwd = match cache:: cached:: unixuser ( user, self . config . unix . aux_groups ) . await {
114
- Ok ( pwd) => pwd,
115
- Err ( _) => {
116
- debug ! ( "acct: unix: user {} not found" , user) ;
117
- return Err ( StatusCode :: UNAUTHORIZED ) ;
118
- } ,
119
- } ;
112
+
113
+ // check if user exists.
114
+ let pwd = match cache:: cached:: unixuser ( user, self . config . unix . aux_groups ) . await {
115
+ Ok ( pwd) => pwd,
116
+ Err ( _) => {
117
+ debug ! ( "acct: unix: user {} not found" , user) ;
118
+ return Err ( StatusCode :: UNAUTHORIZED ) ;
119
+ } ,
120
+ } ;
120
121
121
- // check minimum uid
122
- if let Some ( min_uid) = self . config . unix . min_uid {
123
- if pwd. uid < min_uid {
124
- debug ! ( "acct: {}: uid {} too low (<{})" , pwd. name, pwd. uid, min_uid) ;
125
- return Err ( StatusCode :: FORBIDDEN ) ;
126
- }
122
+ // check minimum uid
123
+ if let Some ( min_uid) = self . config . unix . min_uid {
124
+ if pwd. uid < min_uid {
125
+ debug ! ( "acct: {}: uid {} too low (<{})" , pwd. name, pwd. uid, min_uid) ;
126
+ return Err ( StatusCode :: FORBIDDEN ) ;
127
127
}
128
- Ok ( Some ( pwd) )
128
+ }
129
+ Ok ( Some ( pwd) )
130
+
129
131
}
130
132
}
131
133
132
- // return a new response::Builder with the Server: header set.
134
+ // return a new response::Builder with the Server and CORS header set.
133
135
fn response_builder ( & self ) -> http:: response:: Builder {
134
136
let mut builder = hyper:: Response :: builder ( ) ;
135
- let id = self
136
- . config
137
- . server
138
- . identification
139
- . as_ref ( )
140
- . map ( |s| s. as_str ( ) )
141
- . unwrap_or ( "webdav-server-rs" ) ;
142
- if id != "" {
143
- builder = builder. header ( "Server" , id) ;
144
- }
137
+ self . set_headers ( builder. headers_mut ( ) . unwrap ( ) ) ;
145
138
builder
146
139
}
147
140
148
- // Set Server: webdav-server-rs header.
149
- fn set_server_header ( & self , headers : & mut http:: HeaderMap < http:: header:: HeaderValue > ) {
141
+ // Set Server: webdav-server-rs header, and CORS .
142
+ fn set_headers ( & self , headers : & mut http:: HeaderMap < http:: header:: HeaderValue > ) {
150
143
let id = self
151
144
. config
152
145
. server
@@ -157,6 +150,11 @@ impl Server {
157
150
if id != "" {
158
151
headers. insert ( "server" , id. parse ( ) . unwrap ( ) ) ;
159
152
}
153
+ if self . config . server . cors {
154
+ headers. insert ( "Access-Control-Allow-Origin" , "*" . parse ( ) . unwrap ( ) ) ;
155
+ headers. insert ( "Access-Control-Allow-Methods" , "GET,HEAD,OPTIONS,PROPFIND" . parse ( ) . unwrap ( ) ) ;
156
+ headers. insert ( "Access-Control-Allow-Headers" , "DNT,Depth,Range" . parse ( ) . unwrap ( ) ) ;
157
+ }
160
158
}
161
159
162
160
// handle a request.
@@ -382,12 +380,11 @@ impl Server {
382
380
async fn run_davhandler ( & self , config : DavConfig , req : HttpRequest ) -> HttpResult {
383
381
let resp = self . dh . handle_with ( config, req) . await ;
384
382
let ( mut parts, body) = resp. into_parts ( ) ;
385
- self . set_server_header ( & mut parts. headers ) ;
383
+ self . set_headers ( & mut parts. headers ) ;
386
384
Ok ( http:: Response :: from_parts ( parts, body) )
387
385
}
388
386
}
389
387
390
-
391
388
fn main ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
392
389
// command line option processing.
393
390
let matches = clap_app ! ( webdav_server =>
@@ -466,7 +463,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
466
463
// build servers (one for each listen address).
467
464
let dav_server = Server :: new ( config. clone ( ) , auth) ;
468
465
let mut servers = Vec :: new ( ) ;
469
- #[ cfg( feature= "tls" ) ]
466
+ #[ cfg( all ( not ( windows ) , feature = "tls" ) ) ]
470
467
let mut tls_servers = Vec :: new ( ) ;
471
468
472
469
// Plaintext servers.
@@ -502,68 +499,72 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
502
499
} ) ;
503
500
}
504
501
505
- #[ cfg( feature= "tls" ) ]
502
+ #[ cfg( all ( not ( windows ) , feature = "tls" ) ) ]
506
503
// TLS servers.
507
- for sockaddr in tls_addrs {
508
- let listener = make_listener ( sockaddr) . unwrap_or_else ( |e| {
509
- eprintln ! ( "{}: listener on {:?}: {}" , PROGNAME , & sockaddr, e) ;
510
- exit ( 1 ) ;
511
- } ) ;
512
- let dav_server = dav_server. clone ( ) ;
513
- let tls_config = tls_config ( & config. server ) ?;
514
- let make_service = make_service_fn ( move |stream : & TlsStream < AddrStream > | {
515
- let dav_server = dav_server. clone ( ) ;
516
- let remote_addr = stream. get_ref ( ) . 0 . remote_addr ( ) ;
517
- async move {
518
- let func = move |req| {
519
- let dav_server = dav_server. clone ( ) ;
520
- async move { dav_server. route ( req, remote_addr) . await }
521
- } ;
522
- Ok :: < _ , hyper:: Error > ( service_fn ( func) )
523
- }
524
- } ) ;
525
-
526
- // Since the server can exit when there's an error on the TlsStream,
527
- // we run it in a loop. Every time the loop is entered we dup() the
528
- // listening fd and create a new TcpListener. This way, we should
529
- // not lose any pending connections during a restart.
530
- let master_listen_fd = listener. as_raw_fd ( ) ;
531
- std:: mem:: forget ( listener) ;
504
+ if tls_addrs. len ( ) > 0 {
505
+ let tls_acceptor = tls_acceptor ( & config. server ) ?;
532
506
533
- println ! ( "Listening on http://{:?}" , sockaddr) ;
534
- tls_servers. push ( async move {
535
- loop {
536
- // reuse the incoming socket after the server exits.
537
- let listen_fd = match nix:: unistd:: dup ( master_listen_fd) {
538
- Ok ( fd) => fd,
539
- Err ( e) => {
540
- eprintln ! ( "{}: server error: dup: {}" , PROGNAME , e) ;
541
- break ;
542
- }
543
- } ;
544
- // SAFETY: listen_fd is unique (we just dup'ed it).
545
- let std_listen = unsafe { std:: net:: TcpListener :: from_raw_fd ( listen_fd) } ;
546
- let listener = match tokio:: net:: TcpListener :: from_std ( std_listen) {
547
- Ok ( l) => l,
548
- Err ( e) => {
549
- eprintln ! ( "{}: server error: new TcpListener: {}" , PROGNAME , e) ;
550
- break ;
551
- }
552
- } ;
553
- let a_incoming = match AddrIncoming :: from_listener ( listener) {
554
- Ok ( a) => a,
555
- Err ( e) => {
556
- eprintln ! ( "{}: server error: new AddrIncoming: {}" , PROGNAME , e) ;
557
- break ;
507
+ for sockaddr in tls_addrs {
508
+ let tls_acceptor = tls_acceptor. clone ( ) ;
509
+ let listener = make_listener ( sockaddr) . unwrap_or_else ( |e| {
510
+ eprintln ! ( "{}: listener on {:?}: {}" , PROGNAME , & sockaddr, e) ;
511
+ exit ( 1 ) ;
512
+ } ) ;
513
+ let dav_server = dav_server. clone ( ) ;
514
+ let make_service = make_service_fn ( move |stream : & TlsStream < AddrStream > | {
515
+ let dav_server = dav_server. clone ( ) ;
516
+ let remote_addr = stream. get_ref ( ) . 0 . remote_addr ( ) ;
517
+ async move {
518
+ let func = move |req| {
519
+ let dav_server = dav_server. clone ( ) ;
520
+ async move { dav_server. route ( req, remote_addr) . await }
521
+ } ;
522
+ Ok :: < _ , hyper:: Error > ( service_fn ( func) )
523
+ }
524
+ } ) ;
525
+
526
+ // Since the server can exit when there's an error on the TlsStream,
527
+ // we run it in a loop. Every time the loop is entered we dup() the
528
+ // listening fd and create a new TcpListener. This way, we should
529
+ // not lose any pending connections during a restart.
530
+ let master_listen_fd = listener. as_raw_fd ( ) ;
531
+ std:: mem:: forget ( listener) ;
532
+
533
+ println ! ( "Listening on https://{:?}" , sockaddr) ;
534
+ tls_servers. push ( async move {
535
+ loop {
536
+ // reuse the incoming socket after the server exits.
537
+ let listen_fd = match nix:: unistd:: dup ( master_listen_fd) {
538
+ Ok ( fd) => fd,
539
+ Err ( e) => {
540
+ eprintln ! ( "{}: server error: dup: {}" , PROGNAME , e) ;
541
+ break ;
542
+ }
543
+ } ;
544
+ // SAFETY: listen_fd is unique (we just dup'ed it).
545
+ let std_listen = unsafe { std:: net:: TcpListener :: from_raw_fd ( listen_fd) } ;
546
+ let listener = match tokio:: net:: TcpListener :: from_std ( std_listen) {
547
+ Ok ( l) => l,
548
+ Err ( e) => {
549
+ eprintln ! ( "{}: server error: new TcpListener: {}" , PROGNAME , e) ;
550
+ break ;
551
+ }
552
+ } ;
553
+ let a_incoming = match AddrIncoming :: from_listener ( listener) {
554
+ Ok ( a) => a,
555
+ Err ( e) => {
556
+ eprintln ! ( "{}: server error: new AddrIncoming: {}" , PROGNAME , e) ;
557
+ break ;
558
+ }
559
+ } ;
560
+ let incoming = TlsListener :: new ( tls_acceptor. clone ( ) , a_incoming) ;
561
+ let server = hyper:: Server :: builder ( incoming) ;
562
+ if let Err ( e) = server. serve ( make_service. clone ( ) ) . await {
563
+ eprintln ! ( "{}: server error: {} (retrying)" , PROGNAME , e) ;
558
564
}
559
- } ;
560
- let incoming = TlsListener :: new ( tls_config. clone ( ) , a_incoming) ;
561
- let server = hyper:: Server :: builder ( incoming) ;
562
- if let Err ( e) = server. serve ( make_service. clone ( ) ) . await {
563
- eprintln ! ( "{}: server error: {} (retrying)" , PROGNAME , e) ;
564
565
}
565
- }
566
- } ) ;
566
+ } ) ;
567
+ }
567
568
}
568
569
569
570
// drop privs.
@@ -587,7 +588,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
587
588
for server in servers. drain ( ..) {
588
589
tasks. push ( tokio:: spawn ( server) ) ;
589
590
}
590
- #[ cfg( feature= "tls" ) ]
591
+ #[ cfg( all ( not ( windows ) , feature = "tls" ) ) ]
591
592
for server in tls_servers. drain ( ..) {
592
593
tasks. push ( tokio:: spawn ( server) ) ;
593
594
}
0 commit comments