Skip to content

Commit 638da90

Browse files
committed
refactor: merge build_tls() function into wrap_tls()
1 parent fe0c995 commit 638da90

File tree

1 file changed

+23
-22
lines changed

1 file changed

+23
-22
lines changed

src/net/tls.rs

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,41 +14,42 @@ static LETSENCRYPT_ROOT: Lazy<Certificate> = Lazy::new(|| {
1414
.unwrap()
1515
});
1616

17-
pub fn build_tls(strict_tls: bool, alpns: &[&str]) -> TlsConnector {
17+
pub async fn wrap_tls<T: AsyncRead + AsyncWrite + Unpin>(
18+
strict_tls: bool,
19+
hostname: &str,
20+
alpn: &[&str],
21+
stream: T,
22+
) -> Result<TlsStream<T>> {
1823
let tls_builder = TlsConnector::new()
1924
.min_protocol_version(Some(Protocol::Tlsv12))
20-
.request_alpns(alpns)
25+
.request_alpns(alpn)
2126
.add_root_certificate(LETSENCRYPT_ROOT.clone());
22-
23-
if strict_tls {
27+
let tls = if strict_tls {
2428
tls_builder
2529
} else {
2630
tls_builder
2731
.danger_accept_invalid_hostnames(true)
2832
.danger_accept_invalid_certs(true)
29-
}
33+
};
34+
let tls_stream = tls.connect(hostname, stream).await?;
35+
Ok(tls_stream)
3036
}
3137

32-
pub async fn wrap_tls<T: AsyncRead + AsyncWrite + Unpin>(
33-
strict_tls: bool,
38+
pub async fn wrap_rustls<T: AsyncRead + AsyncWrite + Unpin>(
3439
hostname: &str,
3540
alpn: &[&str],
3641
stream: T,
37-
) -> Result<TlsStream<T>> {
38-
let tls = build_tls(strict_tls, alpn);
39-
let tls_stream = tls.connect(hostname, stream).await?;
40-
Ok(tls_stream)
41-
}
42+
) -> Result<tokio_rustls::client::TlsStream<T>> {
43+
let mut root_cert_store = rustls::RootCertStore::empty();
44+
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
4245

43-
#[cfg(test)]
44-
mod tests {
45-
use super::*;
46+
let mut config = rustls::ClientConfig::builder()
47+
.with_root_certificates(root_cert_store)
48+
.with_no_client_auth();
49+
config.alpn_protocols = alpn.into_iter().map(|s| s.as_bytes().to_vec()).collect();
4650

47-
#[test]
48-
fn test_build_tls() {
49-
// we are using some additional root certificates.
50-
// make sure, they do not break construction of TlsConnector
51-
let _ = build_tls(true, &[]);
52-
let _ = build_tls(false, &[]);
53-
}
51+
let tls = tokio_rustls::TlsConnector::from(Arc::new(config));
52+
let name = rustls_pki_types::ServerName::try_from(hostname)?.to_owned();
53+
let tls_stream = tls.connect(name, stream).await?;
54+
Ok(tls_stream)
5455
}

0 commit comments

Comments
 (0)