Skip to content

Commit 55a6a75

Browse files
authored
Merge branch 'master' into gijsk/windows
2 parents 745c2a4 + 74dbd1b commit 55a6a75

File tree

6 files changed

+292
-192
lines changed

6 files changed

+292
-192
lines changed

mbedtls/examples/client_dtls.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use std::sync::Arc;
1515

1616
use mbedtls::rng::CtrDrbg;
1717
use mbedtls::ssl::config::{Endpoint, Preset, Transport};
18-
use mbedtls::ssl::{Config, Context};
18+
use mbedtls::ssl::{Config, Context, Io};
1919
use mbedtls::x509::Certificate;
2020
use mbedtls::Result as TlsResult;
2121

@@ -35,7 +35,7 @@ fn result_main(addr: &str) -> TlsResult<()> {
3535
ctx.set_timer_callback(Box::new(mbedtls::ssl::context::Timer::new()));
3636

3737
let sock = UdpSocket::bind("localhost:12345").unwrap();
38-
let sock = mbedtls::ssl::context::ConnectedUdpSocket::connect(sock, addr).unwrap();
38+
let sock = mbedtls::ssl::io::ConnectedUdpSocket::connect(sock, addr).unwrap();
3939
ctx.establish(sock, None).unwrap();
4040

4141
let mut line = String::new();

mbedtls/src/ssl/config.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ define!(
5252

5353
define!(
5454
#[c_ty(c_int)]
55+
#[derive(PartialEq, Eq)]
5556
enum Transport {
5657
/// TLS
5758
Stream = SSL_TRANSPORT_STREAM,

mbedtls/src/ssl/context.rs

Lines changed: 13 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,9 @@
99
use core::result::Result as StdResult;
1010

1111
#[cfg(feature = "std")]
12-
use {
13-
std::io::{Read, Write, Result as IoResult, Error as IoError},
14-
std::sync::Arc,
15-
};
12+
use std::sync::Arc;
1613

17-
use mbedtls_sys::types::raw_types::{c_int, c_uchar, c_void};
18-
use mbedtls_sys::types::size_t;
14+
use mbedtls_sys::types::raw_types::{c_int, c_void};
1915
use mbedtls_sys::*;
2016

2117
#[cfg(not(feature = "std"))]
@@ -25,94 +21,9 @@ use crate::error::{Error, Result, IntoResult};
2521
use crate::pk::Pk;
2622
use crate::private::UnsafeFrom;
2723
use crate::ssl::config::{Config, Version, AuthMode};
24+
use crate::ssl::io::IoCallbackUnsafe;
2825
use crate::x509::{Certificate, Crl, VerifyError};
2926

30-
pub trait IoCallback {
31-
unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int where Self: Sized;
32-
unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int where Self: Sized;
33-
fn data_ptr(&mut self) -> *mut c_void;
34-
}
35-
36-
#[cfg(feature = "std")]
37-
impl<IO: Read + Write> IoCallback for IO {
38-
unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int {
39-
let len = if len > (c_int::max_value() as size_t) {
40-
c_int::max_value() as size_t
41-
} else {
42-
len
43-
};
44-
match (&mut *(user_data as *mut IO)).read(::core::slice::from_raw_parts_mut(data, len)) {
45-
Ok(i) => i as c_int,
46-
Err(_) => ::mbedtls_sys::ERR_NET_RECV_FAILED,
47-
}
48-
}
49-
50-
unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int {
51-
let len = if len > (c_int::max_value() as size_t) {
52-
c_int::max_value() as size_t
53-
} else {
54-
len
55-
};
56-
match (&mut *(user_data as *mut IO)).write(::core::slice::from_raw_parts(data, len)) {
57-
Ok(i) => i as c_int,
58-
Err(_) => ::mbedtls_sys::ERR_NET_SEND_FAILED,
59-
}
60-
}
61-
62-
fn data_ptr(&mut self) -> *mut c_void {
63-
self as *mut IO as *mut _
64-
}
65-
}
66-
67-
#[cfg(feature = "std")]
68-
pub struct ConnectedUdpSocket {
69-
socket: std::net::UdpSocket,
70-
}
71-
72-
#[cfg(feature = "std")]
73-
impl ConnectedUdpSocket {
74-
pub fn connect<A: std::net::ToSocketAddrs>(socket: std::net::UdpSocket, addr: A) -> StdResult<Self, (IoError, std::net::UdpSocket)> {
75-
match socket.connect(addr) {
76-
Ok(_) => Ok(ConnectedUdpSocket {
77-
socket,
78-
}),
79-
Err(e) => Err((e, socket)),
80-
}
81-
}
82-
}
83-
84-
#[cfg(feature = "std")]
85-
impl IoCallback for ConnectedUdpSocket {
86-
unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int {
87-
let len = if len > (c_int::max_value() as size_t) {
88-
c_int::max_value() as size_t
89-
} else {
90-
len
91-
};
92-
match (&mut *(user_data as *mut ConnectedUdpSocket)).socket.recv(::core::slice::from_raw_parts_mut(data, len)) {
93-
Ok(i) => i as c_int,
94-
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => 0,
95-
Err(_) => ::mbedtls_sys::ERR_NET_RECV_FAILED,
96-
}
97-
}
98-
99-
unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int {
100-
let len = if len > (c_int::max_value() as size_t) {
101-
c_int::max_value() as size_t
102-
} else {
103-
len
104-
};
105-
match (&mut *(user_data as *mut ConnectedUdpSocket)).socket.send(::core::slice::from_raw_parts(data, len)) {
106-
Ok(i) => i as c_int,
107-
Err(_) => ::mbedtls_sys::ERR_NET_SEND_FAILED,
108-
}
109-
}
110-
111-
fn data_ptr(&mut self) -> *mut c_void {
112-
self as *mut ConnectedUdpSocket as *mut c_void
113-
}
114-
}
115-
11627
pub trait TimerCallback: Send + Sync {
11728
unsafe extern "C" fn set_timer(
11829
p_timer: *mut c_void,
@@ -261,8 +172,13 @@ impl<T> Context<T> {
261172
}
262173
}
263174

264-
impl<T: IoCallback> Context<T> {
265-
pub fn establish(&mut self, io: T, hostname: Option<&str>) -> Result<()> {
175+
impl<T> Context<T> {
176+
/// Establish a TLS session on the given `io`.
177+
///
178+
/// Upon succesful return, the context can be communicated with using the
179+
/// `std::io::Read` and `std::io::Write` traits if `io` implements those as
180+
/// well, and using the `mbedtls::ssl::io::Io` trait otherwise.
181+
pub fn establish<IoType>(&mut self, io: T, hostname: Option<&str>) -> Result<()> where T: IoCallbackUnsafe<IoType> {
266182
unsafe {
267183
let mut io = Box::new(io);
268184
ssl_session_reset(self.into()).into_result()?;
@@ -292,7 +208,7 @@ impl<T> Context<T> {
292208
/// Try to complete the handshake procedure to set up a (D)TLS connection
293209
///
294210
/// In general, this should not be called directly. Instead, [`establish`](Context::establish)
295-
/// should be used which properly sets up the [`IoCallback`] and resets any previous sessions.
211+
/// should be used which properly sets up the [`IoCallbackUnsafe`] and resets any previous sessions.
296212
///
297213
/// This should only be used directly if the handshake could not be completed successfully in
298214
/// `establish`, i.e.:
@@ -483,16 +399,14 @@ impl<T> Context<T> {
483399
pub fn set_client_transport_id_once(&mut self, info: &[u8]) {
484400
self.client_transport_id = Some(info.into());
485401
}
486-
}
487402

488-
impl<T: IoCallback> Context<T> {
489-
pub fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
403+
pub(super) fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
490404
unsafe {
491405
ssl_read(self.into(), buf.as_mut_ptr(), buf.len()).into_result().map(|r| r as usize)
492406
}
493407
}
494408

495-
pub fn send(&mut self, buf: &[u8]) -> Result<usize> {
409+
pub(super) fn send(&mut self, buf: &[u8]) -> Result<usize> {
496410
unsafe {
497411
ssl_write(self.into(), buf.as_ptr(), buf.len()).into_result().map(|w| w as usize)
498412
}
@@ -508,40 +422,6 @@ impl<T> Drop for Context<T> {
508422
}
509423
}
510424

511-
#[cfg(feature = "std")]
512-
/// Implements [`std::io::Read`] whenever T implements `Read`, too. This ensures that
513-
/// `Read`, which is designated for byte-oriented sources, is only implemented when the
514-
/// underlying [`IoCallback`] is byte-oriented, too. Specifically, this means that it is implemented
515-
/// for `Context<TcpStream>`, i.e. TLS connections but not for DTLS connections.
516-
impl<T: IoCallback + Read> Read for Context<T> {
517-
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
518-
match self.recv(buf) {
519-
Err(Error::SslPeerCloseNotify) => Ok(0),
520-
Err(e) => Err(crate::private::error_to_io_error(e)),
521-
Ok(i) => Ok(i),
522-
}
523-
}
524-
}
525-
526-
#[cfg(feature = "std")]
527-
/// Implements [`std::io::Write`] whenever T implements `Write`, too. This ensures that
528-
/// `Write`, which is designated for byte-oriented sinks, is only implemented when the
529-
/// underlying [`IoCallback`] is byte-oriented, too. Specifically, this means that it is implemented
530-
/// for `Context<TcpStream>`, i.e. TLS connections but not for DTLS connections.
531-
impl<T: IoCallback + Write> Write for Context<T> {
532-
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
533-
match self.send(buf) {
534-
Err(Error::SslPeerCloseNotify) => Ok(0),
535-
Err(e) => Err(crate::private::error_to_io_error(e)),
536-
Ok(i) => Ok(i),
537-
}
538-
}
539-
540-
fn flush(&mut self) -> IoResult<()> {
541-
Ok(())
542-
}
543-
}
544-
545425
//
546426
// Class exists only during SNI callback that is configured from Config.
547427
// SNI Callback must provide input whose lifetime exceeds the SNI closure to avoid memory corruptions.

0 commit comments

Comments
 (0)