Skip to content

Commit 7a0767e

Browse files
committed
Add support for ALPN
1 parent 26f6e19 commit 7a0767e

File tree

6 files changed

+191
-0
lines changed

6 files changed

+191
-0
lines changed

mbedtls/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@ name = "server"
8282
path = "examples/server.rs"
8383
required-features = ["std"]
8484

85+
[[test]]
86+
name = "alpn"
87+
path = "tests/alpn.rs"
88+
required-features = ["std"]
89+
8590
[[test]]
8691
name = "async_session"
8792
path = "tests/async_session.rs"

mbedtls/src/ssl/async_session.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,4 +372,8 @@ impl<'a> AsyncSession<'a> {
372372
pub fn verify_result(&self) -> io::Result<Result<(), VerifyError>> {
373373
self.session().map(|session| session.verify_result())
374374
}
375+
376+
pub fn get_alpn_protocol(&self) -> io::Result<Result<Option<&'a str>, Error>> {
377+
self.session().map(|session| session.get_alpn_protocol())
378+
}
375379
}

mbedtls/src/ssl/config.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
* option. This file may not be copied, modified, or distributed except
77
* according to those terms. */
88

9+
use core::marker::PhantomData;
10+
use core::ptr;
911
use core::slice::from_raw_parts;
1012

1113
use mbedtls_sys::types::raw_types::{c_char, c_int, c_uchar, c_uint, c_void};
@@ -85,6 +87,32 @@ define!(
8587

8688
callback!(DbgCallback:Sync(level: c_int, file: *const c_char, line: c_int, message: *const c_char) -> ());
8789

90+
pub struct NullTerminatedStrList<'s> {
91+
c: Box<[*const u8]>,
92+
r: PhantomData<&'s ()>,
93+
}
94+
95+
unsafe impl<'s> Send for NullTerminatedStrList<'s> {}
96+
unsafe impl<'s> Sync for NullTerminatedStrList<'s> {}
97+
98+
impl<'s> NullTerminatedStrList<'s> {
99+
pub fn new(list: &[&'s str]) -> Self {
100+
let mut c = Vec::with_capacity(list.len() + 1);
101+
for s in list {
102+
c.push(s.as_ptr());
103+
}
104+
c.push(ptr::null());
105+
NullTerminatedStrList {
106+
c: c.into_boxed_slice(),
107+
r: PhantomData,
108+
}
109+
}
110+
111+
pub fn as_ptr(&self) -> *const *const u8 {
112+
self.c.as_ptr()
113+
}
114+
}
115+
88116
define!(
89117
#[c_ty(ssl_config)]
90118
struct Config<'c>;
@@ -329,6 +357,17 @@ impl<'c> Config<'c> {
329357
)
330358
}
331359
}
360+
361+
/// Set the supported Application Layer Protocols.
362+
///
363+
/// Each protocol name in the list must also be terminated with a null character (`\0`).
364+
pub fn set_alpn_protocols(&mut self, protocols: &'c NullTerminatedStrList<'c>) -> Result<()> {
365+
unsafe {
366+
ssl_conf_alpn_protocols(&mut self.inner, protocols.as_ptr() as *mut _)
367+
.into_result()
368+
.map(|_| ())
369+
}
370+
}
332371
}
333372

334373
/// Builds a linked list of x509_crt instances, all of which are owned by mbedtls. That is, the

mbedtls/src/ssl/context.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,19 @@ impl<'a> Session<'a> {
319319
flags => Err(VerifyError::from_bits_truncate(flags)),
320320
}
321321
}
322+
323+
#[cfg(feature = "std")]
324+
pub fn get_alpn_protocol(&self) -> Result<Option<&'a str>> {
325+
unsafe {
326+
let ptr = ssl_get_alpn_protocol(self.inner);
327+
if ptr.is_null() {
328+
Ok(None)
329+
} else {
330+
let s = std::ffi::CStr::from_ptr(ptr).to_str()?;
331+
Ok(Some(s))
332+
}
333+
}
334+
}
322335
}
323336

324337
impl<'a> Read for Session<'a> {

mbedtls/tests/alpn.rs

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/* Copyright (c) Fortanix, Inc.
2+
*
3+
* Licensed under the GNU General Public License, version 2 <LICENSE-GPL or
4+
* https://www.gnu.org/licenses/gpl-2.0.html> or the Apache License, Version
5+
* 2.0 <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0>, at your
6+
* option. This file may not be copied, modified, or distributed except
7+
* according to those terms. */
8+
9+
#![allow(dead_code)]
10+
extern crate mbedtls;
11+
12+
use std::net::TcpStream;
13+
use std::thread;
14+
15+
use mbedtls::pk::Pk;
16+
use mbedtls::rng::CtrDrbg;
17+
use mbedtls::ssl::config::{Endpoint, NullTerminatedStrList, Preset, Transport};
18+
use mbedtls::ssl::{Config, Context, Session};
19+
use mbedtls::x509::{Certificate, LinkedCertificate, VerifyError};
20+
use mbedtls::{Error, Result};
21+
22+
mod support;
23+
use support::entropy::entropy_new;
24+
use support::keys;
25+
26+
27+
#[derive(Debug)]
28+
enum Expected<'a> {
29+
FailedHandshake(Error),
30+
SessionEstablished {
31+
alpn: Option<&'a str>,
32+
}
33+
}
34+
35+
impl Expected<'_> {
36+
fn check(self, res: Result<Session<'_>>) {
37+
match (res, self) {
38+
(Ok(session), Expected::SessionEstablished { alpn }) => assert_eq!(session.get_alpn_protocol().unwrap(), alpn),
39+
(Err(e), Expected::FailedHandshake(err)) => assert_eq!(e, err),
40+
(res, expected) => panic!("Unexpected result, expected {:?}, session is_ok: {}", expected, res.is_ok()),
41+
}
42+
}
43+
}
44+
45+
fn client(mut conn: TcpStream, alpn_list: Option<&[&str]>, expected: Expected<'_>) -> Result<()> {
46+
let mut entropy = entropy_new();
47+
let mut rng = CtrDrbg::new(&mut entropy, None)?;
48+
let mut cacert = Certificate::from_pem(keys::ROOT_CA_CERT)?;
49+
let verify_callback = &mut |_crt: &mut LinkedCertificate, _depth, verify_flags: &mut VerifyError| {
50+
verify_flags.remove(VerifyError::CERT_EXPIRED);
51+
Ok(())
52+
};
53+
let alpn_list = alpn_list.map(|list| NullTerminatedStrList::new(list));
54+
let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default);
55+
config.set_rng(Some(&mut rng));
56+
config.set_verify_callback(verify_callback);
57+
config.set_ca_list(Some(&mut *cacert), None);
58+
if let Some(ref alpn_list) = alpn_list {
59+
config.set_alpn_protocols(alpn_list)?;
60+
}
61+
let mut ctx = Context::new(&config)?;
62+
let res = ctx.establish(&mut conn, None);
63+
expected.check(res);
64+
Ok(())
65+
}
66+
67+
fn server(mut conn: TcpStream, alpn_list: Option<&[&str]>, expected: Expected<'_>) -> Result<()> {
68+
let mut entropy = entropy_new();
69+
let mut rng = CtrDrbg::new(&mut entropy, None)?;
70+
let mut cert = Certificate::from_pem(keys::EXPIRED_CERT)?;
71+
let mut key = Pk::from_private_key(keys::EXPIRED_KEY, None)?;
72+
let alpn_list = alpn_list.map(|list| NullTerminatedStrList::new(list));
73+
let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default);
74+
config.set_rng(Some(&mut rng));
75+
config.push_cert(&mut *cert, &mut key)?;
76+
if let Some(ref alpn_list) = alpn_list {
77+
config.set_alpn_protocols(alpn_list)?;
78+
}
79+
let mut ctx = Context::new(&config)?;
80+
81+
let res = ctx.establish(&mut conn, None);
82+
expected.check(res);
83+
Ok(())
84+
}
85+
86+
#[test]
87+
fn alpn() {
88+
#[derive(Clone)]
89+
struct TestConfig {
90+
client_list: Option<&'static [&'static str]>,
91+
server_list: Option<&'static [&'static str]>,
92+
expected: Option<&'static str>,
93+
}
94+
95+
impl TestConfig {
96+
fn new(client_list: Option<&'static [&'static str]>, server_list: Option<&'static [&'static str]>, expected: Option<&'static str>) -> Self {
97+
Self { client_list, server_list, expected }
98+
}
99+
}
100+
101+
let test_configs = vec![
102+
TestConfig::new(Some(&["h2\0", "http/1.1\0"]), Some(&["h2\0", "http/1.1\0"]), Some("h2")),
103+
TestConfig::new(Some(&["http/1.1\0", "h2\0"]), Some(&["h2\0", "http/1.1\0"]), Some("h2")),
104+
TestConfig::new(Some(&["h2\0", "http/1.1\0"]), Some(&["http/1.1\0", "h2\0"]), Some("http/1.1")),
105+
TestConfig::new(None, None, None),
106+
TestConfig::new(None, Some(&["h2\0", "http/1.1\0"]), None),
107+
TestConfig::new(Some(&["h2\0", "http/1.1\0"]), None, None),
108+
];
109+
110+
for config in test_configs {
111+
let client_list = config.client_list;
112+
let server_list = config.server_list;
113+
let alpn = config.expected;
114+
let (c, s) = support::net::create_tcp_pair().unwrap();
115+
let c = thread::spawn(move || client(c, client_list, Expected::SessionEstablished { alpn }).unwrap());
116+
let s = thread::spawn(move || server(s, server_list, Expected::SessionEstablished { alpn }).unwrap());
117+
c.join().unwrap();
118+
s.join().unwrap();
119+
}
120+
}
121+
122+
#[test]
123+
fn nothing_in_common() {
124+
let (c, s) = support::net::create_tcp_pair().unwrap();
125+
let c = thread::spawn(move || client(c, Some(&["a1\0", "a2\0"]), Expected::FailedHandshake(Error::SslFatalAlertMessage)).unwrap());
126+
let s = thread::spawn(move || server(s, Some(&["b1\0", "b2\0"]), Expected::FailedHandshake(Error::SslBadHsClientHello)).unwrap());
127+
c.join().unwrap();
128+
s.join().unwrap();
129+
}

mbedtls/tests/async_session.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ async fn server(
123123
}
124124
};
125125

126+
assert_eq!(session.get_alpn_protocol().unwrap().unwrap(), None);
126127
let ciphersuite = session.ciphersuite().unwrap();
127128
session
128129
.write_all(format!("Server2Client {:4x}", ciphersuite).as_bytes())

0 commit comments

Comments
 (0)