Skip to content

Commit 0b48988

Browse files
authored
Merge pull request #563 from elmarco/fixes
Fixes from #562
2 parents 343e1e6 + aca142c commit 0b48988

File tree

6 files changed

+98
-82
lines changed

6 files changed

+98
-82
lines changed

zbus/src/address/transport/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ impl Transport {
131131

132132
#[cfg(not(unix))]
133133
{
134-
let _ = path;
134+
let _ = stream;
135135
Err(Error::Unsupported)
136136
}
137137
}

zbus/src/connection/builder.rs

+30-30
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use crate::{
2828
async_lock::RwLock,
2929
names::{InterfaceName, UniqueName, WellKnownName},
3030
object_server::Interface,
31-
Connection, Error, Executor, Guid, Result,
31+
Connection, Error, Executor, Guid, OwnedGuid, Result,
3232
};
3333

3434
use super::{
@@ -161,7 +161,7 @@ impl<'a> Builder<'a> {
161161

162162
/// Create a builder for connection that will use the given socket.
163163
pub fn socket<S: Socket + 'static>(socket: S) -> Self {
164-
Self::new(Target::Socket(Split::new_boxed(socket)))
164+
Self::new(Target::Socket(socket.into()))
165165
}
166166

167167
/// Specify the mechanisms to use during authentication.
@@ -343,17 +343,11 @@ impl<'a> Builder<'a> {
343343
}
344344

345345
async fn build_(mut self, executor: Executor<'static>) -> Result<Connection> {
346-
let mut stream = self.stream_for_target().await?;
346+
let (mut stream, server_guid) = self.target_connect().await?;
347347
let mut auth = match self.guid {
348348
None => {
349-
let guid = match self.target {
350-
Some(Target::Address(ref addr)) => {
351-
addr.guid().map(|guid| guid.to_owned().into())
352-
}
353-
_ => None,
354-
};
355349
// SASL Handshake
356-
Authenticated::client(stream, guid, self.auth_mechanisms).await?
350+
Authenticated::client(stream, server_guid, self.auth_mechanisms).await?
357351
}
358352
Some(guid) => {
359353
if !self.p2p {
@@ -456,36 +450,42 @@ impl<'a> Builder<'a> {
456450
}
457451
}
458452

459-
async fn stream_for_target(&mut self) -> Result<BoxedSplit> {
460-
// SAFETY: `self.target` is always `Some` from the beginning and this methos is only called
453+
async fn target_connect(&mut self) -> Result<(BoxedSplit, Option<OwnedGuid>)> {
454+
// SAFETY: `self.target` is always `Some` from the beginning and this method is only called
461455
// once.
462-
Ok(match self.target.take().unwrap() {
456+
let split = match self.target.take().unwrap() {
463457
#[cfg(not(feature = "tokio"))]
464-
Target::UnixStream(stream) => Split::new_boxed(Async::new(stream)?),
458+
Target::UnixStream(stream) => Async::new(stream)?.into(),
465459
#[cfg(all(unix, feature = "tokio"))]
466-
Target::UnixStream(stream) => Split::new_boxed(stream),
460+
Target::UnixStream(stream) => stream.into(),
467461
#[cfg(all(not(unix), feature = "tokio"))]
468462
Target::UnixStream(_) => return Err(Error::Unsupported),
469463
#[cfg(not(feature = "tokio"))]
470-
Target::TcpStream(stream) => Split::new_boxed(Async::new(stream)?),
464+
Target::TcpStream(stream) => Async::new(stream)?.into(),
471465
#[cfg(feature = "tokio")]
472-
Target::TcpStream(stream) => Split::new_boxed(stream),
466+
Target::TcpStream(stream) => stream.into(),
473467
#[cfg(all(feature = "vsock", not(feature = "tokio")))]
474-
Target::VsockStream(stream) => Split::new_boxed(Async::new(stream)?),
468+
Target::VsockStream(stream) => Async::new(stream)?.into(),
475469
#[cfg(feature = "tokio-vsock")]
476-
Target::VsockStream(stream) => Split::new_boxed(stream),
477-
Target::Address(address) => match address.connect().await? {
478-
#[cfg(any(unix, not(feature = "tokio")))]
479-
address::transport::Stream::Unix(stream) => Split::new_boxed(stream),
480-
address::transport::Stream::Tcp(stream) => Split::new_boxed(stream),
481-
#[cfg(any(
482-
all(feature = "vsock", not(feature = "tokio")),
483-
feature = "tokio-vsock"
484-
))]
485-
address::transport::Stream::Vsock(stream) => Split::new_boxed(stream),
486-
},
470+
Target::VsockStream(stream) => stream.into(),
471+
Target::Address(address) => {
472+
let guid = address.guid().map(|g| g.to_owned().into());
473+
let split = match address.connect().await? {
474+
#[cfg(any(unix, not(feature = "tokio")))]
475+
address::transport::Stream::Unix(stream) => stream.into(),
476+
address::transport::Stream::Tcp(stream) => stream.into(),
477+
#[cfg(any(
478+
all(feature = "vsock", not(feature = "tokio")),
479+
feature = "tokio-vsock"
480+
))]
481+
address::transport::Stream::Vsock(stream) => stream.into(),
482+
};
483+
return Ok((split, guid));
484+
}
487485
Target::Socket(stream) => stream,
488-
})
486+
};
487+
488+
Ok((split, None))
489489
}
490490
}
491491

zbus/src/connection/handshake.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -1045,7 +1045,7 @@ mod tests {
10451045

10461046
use super::*;
10471047

1048-
use crate::{connection::socket::Split, Guid, Socket};
1048+
use crate::{Guid, Socket};
10491049

10501050
fn create_async_socket_pair() -> (impl AsyncWrite + Socket, impl AsyncWrite + Socket) {
10511051
// Tokio needs us to call the sync function from async context. :shrug:
@@ -1071,9 +1071,9 @@ mod tests {
10711071
let (p0, p1) = create_async_socket_pair();
10721072

10731073
let guid = OwnedGuid::from(Guid::generate());
1074-
let client = ClientHandshake::new(Split::new_boxed(p0), None, Some(guid.clone()));
1074+
let client = ClientHandshake::new(p0.into(), None, Some(guid.clone()));
10751075
let server = ServerHandshake::new(
1076-
Split::new_boxed(p1),
1076+
p1.into(),
10771077
guid,
10781078
Some(Uid::effective().into()),
10791079
None,
@@ -1097,7 +1097,7 @@ mod tests {
10971097
fn pipelined_handshake() {
10981098
let (mut p0, p1) = create_async_socket_pair();
10991099
let server = ServerHandshake::new(
1100-
Split::new_boxed(p1),
1100+
p1.into(),
11011101
Guid::generate().into(),
11021102
Some(Uid::effective().into()),
11031103
None,
@@ -1126,7 +1126,7 @@ mod tests {
11261126
fn separate_external_data() {
11271127
let (mut p0, p1) = create_async_socket_pair();
11281128
let server = ServerHandshake::new(
1129-
Split::new_boxed(p1),
1129+
p1.into(),
11301130
Guid::generate().into(),
11311131
Some(Uid::effective().into()),
11321132
None,
@@ -1153,7 +1153,7 @@ mod tests {
11531153
fn missing_external_data() {
11541154
let (mut p0, p1) = create_async_socket_pair();
11551155
let server = ServerHandshake::new(
1156-
Split::new_boxed(p1),
1156+
p1.into(),
11571157
Guid::generate().into(),
11581158
Some(Uid::effective().into()),
11591159
None,
@@ -1171,7 +1171,7 @@ mod tests {
11711171
fn anonymous_handshake() {
11721172
let (mut p0, p1) = create_async_socket_pair();
11731173
let server = ServerHandshake::new(
1174-
Split::new_boxed(p1),
1174+
p1.into(),
11751175
Guid::generate().into(),
11761176
Some(Uid::effective().into()),
11771177
Some(vec![AuthMechanism::Anonymous].into()),
@@ -1189,7 +1189,7 @@ mod tests {
11891189
fn separate_anonymous_data() {
11901190
let (mut p0, p1) = create_async_socket_pair();
11911191
let server = ServerHandshake::new(
1192-
Split::new_boxed(p1),
1192+
p1.into(),
11931193
Guid::generate().into(),
11941194
Some(Uid::effective().into()),
11951195
Some(vec![AuthMechanism::Anonymous].into()),

zbus/src/connection/socket/split.rs

+11-10
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,6 @@ pub struct Split<R: ReadHalf, W: WriteHalf> {
88
}
99

1010
impl<R: ReadHalf, W: WriteHalf> Split<R, W> {
11-
/// Create a new boxed `Split` from `socket`.
12-
pub fn new_boxed<S: Socket<ReadHalf = R, WriteHalf = W>>(socket: S) -> BoxedSplit {
13-
let split = socket.split();
14-
15-
Split {
16-
read: Box::new(split.read),
17-
write: Box::new(split.write),
18-
}
19-
}
20-
2111
/// Reference to the read half.
2212
pub fn read(&self) -> &R {
2313
&self.read
@@ -46,3 +36,14 @@ impl<R: ReadHalf, W: WriteHalf> Split<R, W> {
4636

4737
/// A boxed `Split`.
4838
pub type BoxedSplit = Split<Box<dyn ReadHalf>, Box<dyn WriteHalf>>;
39+
40+
impl<S: Socket> From<S> for BoxedSplit {
41+
fn from(socket: S) -> Self {
42+
let split = socket.split();
43+
44+
Split {
45+
read: Box::new(split.read),
46+
write: Box::new(split.write),
47+
}
48+
}
49+
}

zbus/src/connection/socket/tcp.rs

+36-21
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
#[cfg(not(feature = "tokio"))]
2-
use crate::fdo::ConnectionCredentials;
3-
#[cfg(not(feature = "tokio"))]
42
use async_io::Async;
53
use std::io;
64
#[cfg(unix)]
@@ -28,7 +26,7 @@ impl ReadHalf for Arc<Async<TcpStream>> {
2826
}
2927
}
3028

31-
async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
29+
async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
3230
#[cfg(windows)]
3331
let creds = {
3432
let stream = self.clone();
@@ -40,7 +38,7 @@ impl ReadHalf for Arc<Async<TcpStream>> {
4038
let sid = ProcessToken::open(if pid != 0 { Some(pid as _) } else { None })
4139
.and_then(|process_token| process_token.sid())?;
4240
io::Result::Ok(
43-
ConnectionCredentials::default()
41+
crate::fdo::ConnectionCredentials::default()
4442
.set_process_id(pid)
4543
.set_windows_sid(sid),
4644
)
@@ -51,7 +49,7 @@ impl ReadHalf for Arc<Async<TcpStream>> {
5149
}?;
5250

5351
#[cfg(not(windows))]
54-
let creds = ConnectionCredentials::default();
52+
let creds = crate::fdo::ConnectionCredentials::default();
5553

5654
Ok(creds)
5755
}
@@ -85,7 +83,7 @@ impl WriteHalf for Arc<Async<TcpStream>> {
8583
.await
8684
}
8785

88-
async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
86+
async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
8987
ReadHalf::peer_credentials(self).await
9088
}
9189
}
@@ -119,21 +117,13 @@ impl ReadHalf for tokio::net::tcp::OwnedReadHalf {
119117
}
120118

121119
#[cfg(windows)]
122-
fn peer_sid(&self) -> Option<String> {
123-
use crate::win32::{socket_addr_get_pid, ProcessToken};
124-
125-
let peer_addr = match self.peer_addr() {
126-
Ok(addr) => addr,
127-
Err(_) => return None,
128-
};
129-
130-
if let Ok(pid) = socket_addr_get_pid(&peer_addr) {
131-
if let Ok(process_token) = ProcessToken::open(if pid != 0 { Some(pid) } else { None }) {
132-
return process_token.sid().ok();
133-
}
134-
}
135-
136-
None
120+
async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
121+
let peer_addr = self.peer_addr()?.clone();
122+
crate::Task::spawn_blocking(
123+
move || win32_credentials_from_addr(&peer_addr),
124+
"peer credentials",
125+
)
126+
.await
137127
}
138128
}
139129

@@ -161,4 +151,29 @@ impl WriteHalf for tokio::net::tcp::OwnedWriteHalf {
161151
async fn close(&mut self) -> io::Result<()> {
162152
tokio::io::AsyncWriteExt::shutdown(self).await
163153
}
154+
155+
#[cfg(windows)]
156+
async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
157+
let peer_addr = self.peer_addr()?.clone();
158+
crate::Task::spawn_blocking(
159+
move || win32_credentials_from_addr(&peer_addr),
160+
"peer credentials",
161+
)
162+
.await
163+
}
164+
}
165+
166+
#[cfg(feature = "tokio")]
167+
#[cfg(windows)]
168+
fn win32_credentials_from_addr(
169+
addr: &std::net::SocketAddr,
170+
) -> io::Result<crate::fdo::ConnectionCredentials> {
171+
use crate::win32::{socket_addr_get_pid, ProcessToken};
172+
173+
let pid = socket_addr_get_pid(addr)? as _;
174+
let sid = ProcessToken::open(if pid != 0 { Some(pid as _) } else { None })
175+
.and_then(|process_token| process_token.sid())?;
176+
Ok(crate::fdo::ConnectionCredentials::default()
177+
.set_process_id(pid)
178+
.set_windows_sid(sid))
164179
}

0 commit comments

Comments
 (0)