diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 845b4d2..894fbc5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,6 +13,7 @@ jobs: strategy: matrix: rust: [stable, beta, nightly] + tls: [rustls, native-tls] features: - hyper-h1,hyper-h2 continue-on-error: ${{ matrix.rust == 'nightly' }} @@ -25,10 +26,10 @@ jobs: profile: minimal components: rustfmt, clippy - name: Build - run: cargo build --verbose --examples --features ${{ matrix.features }} + run: cargo build --verbose --examples --features ${{ matrix.tls }},${{ matrix.features }} - name: Test - run: cargo test --verbose + run: cargo test --verbose --features ${{ matrix.tls }},${{ matrix.features }} - name: Lint - run: cargo clippy --examples --features ${{ matrix.features }} + run: cargo clippy --examples --features ${{ matrix.tls }},${{ matrix.features }} - name: Format check run: cargo fmt -- --check diff --git a/.gitignore b/.gitignore index 6936990..22676ce 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target **/*.rs.bk Cargo.lock +.idea/ diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..cb587d7 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,19 @@ +# Changelog + +## 0.4.0 - 2022-02-22 + +### Added + +- Support for [`native-tls`](https://github.com/sfackler/rust-native-tls). + +### Changed + +- Either one of the `rustls` or `native-tls` features must now be enabled. +- The `TlsListener` stream now returns a `tls_listener::Error` instead of `std::io::Error` type. +- Signatures of `TcpListener::new()` and `builder()` have changed to now take an argument `T: Into`. + When passing a `rustls::ServerConfig` it should therefore be wrapped in an `Arc` first. + +### Fixed + +- Crate will now compile when linked against a target that doesn't explicitly enable the `tokio/time` and `hyper/tcp` + features. diff --git a/Cargo.toml b/Cargo.toml index 01bddd6..8cf5705 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "tls-listener" description = "wrap incoming Stream of connections in TLS" -version = "0.3.0" +version = "0.4.0" authors = ["Thayne McCombs "] repository = "https://github.com/tmccombs/tls-listener" edition = "2018" @@ -9,6 +9,8 @@ license = "Apache-2.0" [features] default = ["tokio-net"] +rustls = ["tokio-rustls"] +native-tls = ["tokio-native-tls"] tokio-net = ["tokio/net"] hyper-h1 = ["hyper", "hyper/http1"] @@ -16,18 +18,15 @@ hyper-h2 = ["hyper", "hyper/http2"] [dependencies] futures-util = "0.3.8" -tokio = "1.0" -tokio-rustls = "0.23.0" +hyper = { version = "0.14.1", features = ["server", "tcp"], optional = true } pin-project-lite = "0.2.8" -#tokio-native-tls = "0.3.0" - -[dependencies.hyper] -version = "0.14.1" -features = ["server"] -optional = true +thiserror = "1.0.30" +tokio = { version = "1.0", features = ["time"] } +tokio-native-tls = { version = "0.3.0", optional = true } +tokio-rustls = { version = "0.23.0", optional = true } [dev-dependencies] -hyper = { version = "0.14.1", features = ["server", "http1", "tcp", "stream"] } +hyper = { version = "0.14.1", features = ["http1", "stream"] } tokio = { version = "1.0", features = ["rt", "macros", "net", "io-util"] } [[example]] @@ -51,5 +50,5 @@ path = "examples/http-low-level.rs" required-features = ["hyper-h1"] [package.metadata.docs.rs] -all-features = true +features = ["rustls", "hyper-h1", "hyper-h2"] rustdoc-args = ["--cfg", "docsrs"] diff --git a/README.md b/README.md index 99bb5b8..aaaef51 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,10 @@ This library is intended to automatically initiate a TLS connection as for each new connection in a source of new streams (such as a listening TCP or unix domain socket). -In particular, the `TlsListener` can be used as the `incoming` argument to `hyper::server::Server::builder`. +In particular, the `TlsListener` can be used as the `incoming` argument to `hyper::server::Server::builder` (requires +one of the `hyper-h1` or `hyper-h2` features). See examples for examples of usage. + +You must enable either one of the `rustls` or `native-tls` features depending on which implementation you would +like to use. diff --git a/examples/echo.rs b/examples/echo.rs index b488aa7..471642b 100644 --- a/examples/echo.rs +++ b/examples/echo.rs @@ -3,10 +3,13 @@ use std::net::SocketAddr; use tls_listener::TlsListener; use tokio::io::{copy, split}; use tokio::net::{TcpListener, TcpStream}; +#[cfg(feature = "native-tls")] +use tokio_native_tls::TlsStream; +#[cfg(feature = "rustls")] use tokio_rustls::server::TlsStream; mod tls_config; -use tls_config::tls_config; +use tls_config::tls_acceptor; #[inline] async fn handle_stream(stream: TlsStream) { @@ -17,13 +20,15 @@ async fn handle_stream(stream: TlsStream) { }; } +/// For example try opening and closing a connection with: +/// `echo "Q" | openssl s_client -connect host:port` #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box> { let addr: SocketAddr = ([127, 0, 0, 1], 3000).into(); let listener = TcpListener::bind(&addr).await?; - TlsListener::new(tls_config(), listener) + TlsListener::new(tls_acceptor(), listener) .for_each_concurrent(None, |s| async { match s { Ok(stream) => { diff --git a/examples/http-low-level.rs b/examples/http-low-level.rs index cbdcf9e..4f6dd07 100644 --- a/examples/http-low-level.rs +++ b/examples/http-low-level.rs @@ -9,12 +9,12 @@ use hyper::{Body, Response}; use std::convert::Infallible; mod tls_config; -use tls_config::tls_config; +use tls_config::tls_acceptor; #[tokio::main(flavor = "current_thread")] async fn main() { let addr = ([127, 0, 0, 1], 3000).into(); - let listener = tls_listener::builder(tls_config()) + let listener = tls_listener::builder(tls_acceptor()) .max_handshakes(10) .listen(AddrIncoming::bind(&addr).unwrap()); diff --git a/examples/http-stream.rs b/examples/http-stream.rs index a7b08b8..bfe9b5c 100644 --- a/examples/http-stream.rs +++ b/examples/http-stream.rs @@ -9,7 +9,7 @@ use std::future::ready; use tls_listener::TlsListener; mod tls_config; -use tls_config::tls_config; +use tls_config::tls_acceptor; #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box> { @@ -22,7 +22,7 @@ async fn main() -> Result<(), Box> { }); // This uses a filter to handle errors with connecting - let incoming = TlsListener::new(tls_config(), AddrIncoming::bind(&addr)?).filter(|conn| { + let incoming = TlsListener::new(tls_acceptor(), AddrIncoming::bind(&addr)?).filter(|conn| { if let Err(err) = conn { eprintln!("Error: {:?}", err); ready(false) diff --git a/examples/http.rs b/examples/http.rs index 9c994a1..ea711e4 100644 --- a/examples/http.rs +++ b/examples/http.rs @@ -5,7 +5,7 @@ use std::convert::Infallible; use tls_listener::TlsListener; mod tls_config; -use tls_config::tls_config; +use tls_config::tls_acceptor; #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box> { @@ -21,7 +21,7 @@ async fn main() -> Result<(), Box> { // This could be handled by adding a filter to the stream to filter out // unwanted errors (and possibly log them), then use `hyper::server::accept::from_stream`, // or by doing something similar to the http-low-level.rs example. - let incoming = TlsListener::new(tls_config(), AddrIncoming::bind(&addr)?); + let incoming = TlsListener::new(tls_acceptor(), AddrIncoming::bind(&addr)?); let server = Server::builder(incoming).serve(new_svc); server.await?; diff --git a/examples/tls_config/local.pfx b/examples/tls_config/local.pfx new file mode 100644 index 0000000..703ced4 Binary files /dev/null and b/examples/tls_config/local.pfx differ diff --git a/examples/tls_config/mod.rs b/examples/tls_config/mod.rs index 2cda835..149c75d 100644 --- a/examples/tls_config/mod.rs +++ b/examples/tls_config/mod.rs @@ -1,15 +1,32 @@ -use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig}; +#[cfg(feature = "rustls")] +mod cert { + pub const CERT: &[u8] = include_bytes!("local.cert"); + pub const PKEY: &[u8] = include_bytes!("local.key"); +} +#[cfg(feature = "native-tls")] +const PFX: &[u8] = include_bytes!("local.pfx"); + +#[cfg(feature = "rustls")] +pub fn tls_acceptor() -> std::sync::Arc { + use std::sync::Arc; + use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig}; -const CERT: &[u8] = include_bytes!("local.cert"); -const PKEY: &[u8] = include_bytes!("local.key"); + let key = PrivateKey(cert::PKEY.into()); + let cert = Certificate(cert::CERT.into()); + + Arc::new( + ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(vec![cert], key) + .unwrap(), + ) +} -pub fn tls_config() -> ServerConfig { - let key = PrivateKey(PKEY.into()); - let cert = Certificate(CERT.into()); +#[cfg(feature = "native-tls")] +pub fn tls_acceptor() -> tokio_native_tls::native_tls::TlsAcceptor { + use tokio_native_tls::native_tls::{Identity, TlsAcceptor}; - ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(vec![cert], key) - .unwrap() + let identity = Identity::from_pkcs12(PFX, "").unwrap(); + TlsAcceptor::builder(identity).build().unwrap() } diff --git a/src/lib.rs b/src/lib.rs index ca1a2a0..c521d49 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,21 +7,29 @@ //! for each new connection in a source of new streams (such as a listening //! TCP or unix domain socket). +mod compile_time_checks { + #[cfg(not(any(feature = "rustls", feature = "native-tls")))] + compile_error!("tls-listener requires either the `rustls` or `native-tls` feature"); + + #[cfg(all(feature = "rustls", feature = "native-tls"))] + compile_error!("The `rustls` and `native-tls` features in tls-listener are mutually exclusive"); +} + use futures_util::stream::{FuturesUnordered, Stream, StreamExt}; use pin_project_lite::pin_project; -use std::error::Error; use std::future::Future; use std::io; use std::marker::Unpin; use std::pin::Pin; -use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; +use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::time::{timeout, Timeout}; -use tokio_rustls::rustls::ServerConfig; -use tokio_rustls::server::TlsStream; -use tokio_rustls::TlsAcceptor; +#[cfg(feature = "native-tls")] +use tokio_native_tls::{TlsAcceptor, TlsStream}; +#[cfg(feature = "rustls")] +use tokio_rustls::{server::TlsStream, TlsAcceptor}; /// Default number of concurrent handshakes pub const DEFAULT_MAX_HANDSHAKES: usize = 64; @@ -42,6 +50,12 @@ pub trait AsyncAccept { ) -> Poll>; } +#[cfg(feature = "rustls")] +type TlsAcceptFuture = tokio_rustls::Accept; +#[cfg(feature = "native-tls")] +type TlsAcceptFuture = + Pin>>>>; + pin_project! { /// /// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself @@ -69,7 +83,7 @@ pin_project! { #[pin] listener: A, tls: TlsAcceptor, - waiting: FuturesUnordered>>, + waiting: FuturesUnordered>>, max_handshakes: usize, timeout: Duration, } @@ -78,23 +92,39 @@ pin_project! { /// Builder for `TlsListener`. #[derive(Clone)] pub struct Builder { - server_config: Arc, + acceptor: TlsAcceptor, max_handshakes: usize, handshake_timeout: Duration, } +/// Wraps errors from either the listener or the TLS Acceptor +#[derive(Debug, Error)] +pub enum Error { + /// An error that arose from the listener ([AsyncAccept::Error]) + #[error("{0}")] + ListenerError(#[source] A), + /// An error that occurred during the TLS accept handshake + #[cfg(feature = "rustls")] + #[error("{0}")] + TlsAcceptError(#[source] std::io::Error), + /// An error that occurred during the TLS accept handshake + #[cfg(feature = "native-tls")] + #[error("{0}")] + TlsAcceptError(#[source] tokio_native_tls::native_tls::Error), +} + impl TlsListener { /// Create a `TlsListener` with default options. - pub fn new(server_config: ServerConfig, listener: A) -> Self { - builder(server_config).listen(listener) + pub fn new>(acceptor: T, listener: A) -> Self { + builder(acceptor).listen(listener) } } impl TlsListener where A: AsyncAccept, - A::Connection: AsyncRead + AsyncWrite + Unpin, - A::Error: Into>, + A::Connection: AsyncRead + AsyncWrite + Unpin + 'static, + A::Error: std::error::Error, Self: Unpin, { /// Accept the next connection @@ -108,10 +138,10 @@ where impl Stream for TlsListener where A: AsyncAccept, - A::Connection: AsyncRead + AsyncWrite + Unpin, - A::Error: Into>, + A::Connection: AsyncRead + AsyncWrite + Unpin + 'static, + A::Error: std::error::Error, { - type Item = io::Result>; + type Item = Result, Error>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); @@ -120,24 +150,29 @@ where match this.listener.as_mut().poll_accept(cx) { Poll::Pending => break, Poll::Ready(Ok(conn)) => { + #[cfg(feature = "rustls")] this.waiting .push(timeout(*this.timeout, this.tls.accept(conn))); + #[cfg(feature = "native-tls")] + { + let tls = this.tls.clone(); + this.waiting.push(timeout( + *this.timeout, + Box::pin(async move { tls.accept(conn).await }), + )); + } } Poll::Ready(Err(e)) => { - // Ideally we'd be able to do this match at compile time, but afaik, - // there isn't a way to do that with current rust. - let error = match e.into().downcast::() { - Ok(err) => *err, - Err(err) => io::Error::new(io::ErrorKind::ConnectionAborted, err), - }; - return Poll::Ready(Some(Err(error))); + return Poll::Ready(Some(Err(Error::ListenerError(e)))); } } } loop { return match this.waiting.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(conn))) => Poll::Ready(Some(conn)), + Poll::Ready(Some(Ok(conn))) => { + Poll::Ready(Some(conn.map_err(Error::TlsAcceptError))) + } // The handshake timed out, try getting another connection from the // queue Poll::Ready(Some(Err(_))) => continue, @@ -179,7 +214,7 @@ impl Builder { pub fn listen(&self, listener: A) -> TlsListener { TlsListener { listener, - tls: self.server_config.clone().into(), + tls: self.acceptor.clone(), waiting: FuturesUnordered::new(), max_handshakes: self.max_handshakes, timeout: self.handshake_timeout, @@ -190,9 +225,9 @@ impl Builder { /// Create a new Builder for a TlsListener /// /// `server_config` will be used to configure the TLS sessions. -pub fn builder(server_config: ServerConfig) -> Builder { +pub fn builder>(acceptor: T) -> Builder { Builder { - server_config: Arc::new(server_config), + acceptor: acceptor.into(), max_handshakes: DEFAULT_MAX_HANDSHAKES, handshake_timeout: DEFAULT_HANDSHAKE_TIMEOUT, } @@ -263,11 +298,11 @@ mod hyper_impl { impl HyperAccept for TlsListener where A: AsyncAccept, - A::Connection: AsyncRead + AsyncWrite + Unpin, - A::Error: Into>, + A::Connection: AsyncRead + AsyncWrite + Unpin + 'static, + A::Error: std::error::Error, { type Conn = TlsStream; - type Error = io::Error; + type Error = Error; fn poll_accept( self: Pin<&mut Self>,