Skip to content

Add new features runtime-{runtime}-notls to avoid tls dependency #2298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,18 @@ runtime-tokio-rustls = [
"_rt-tokio",
]

runtime-actix-notls = ["runtime-tokio-notls"]
runtime-async-std-notls = [
"sqlx-core/runtime-async-std-notls",
"sqlx-macros/runtime-async-std-notls",
"_rt-async-std",
]
runtime-tokio-notls = [
"sqlx-core/runtime-tokio-notls",
"sqlx-macros/runtime-tokio-notls",
"_rt-tokio",
]

# for conditional compilation
_rt-async-std = []
_rt-tokio = []
Expand Down
15 changes: 15 additions & 0 deletions sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,25 @@ runtime-tokio-rustls = [
"_rt-tokio"
]

runtime-actix-notls = ['runtime-tokio-notls']
runtime-async-std-notls = [
"sqlx-rt/runtime-async-std-notls",
"sqlx/runtime-async-std-notls",
"_tls-notls",
"_rt-async-std",
]
runtime-tokio-notls = [
"sqlx-rt/runtime-tokio-notls",
"sqlx/runtime-tokio-notls",
"_tls-notls",
"_rt-tokio"
]

# for conditional compilation
_rt-async-std = []
_rt-tokio = ["tokio-stream"]
_tls-native-tls = []
_tls-notls = []
_tls-rustls = ["rustls", "rustls-pemfile", "webpki-roots"]

# support offline/decoupled building (enables serialization of `Describe`)
Expand Down
87 changes: 56 additions & 31 deletions sqlx-core/src/net/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ use std::path::PathBuf;
use std::pin::Pin;
use std::task::{Context, Poll};

use sqlx_rt::{AsyncRead, AsyncWrite, TlsStream};
#[cfg(not(feature = "_tls-notls"))]
use sqlx_rt::TlsStream;
use sqlx_rt::{AsyncRead, AsyncWrite};

use crate::error::Error;
use std::mem::replace;
Expand Down Expand Up @@ -56,6 +58,9 @@ impl std::fmt::Display for CertificateInput {
#[cfg(feature = "_tls-rustls")]
mod rustls;

#[cfg(feature = "_tls-notls")]
pub struct MaybeTlsStream<S>(S);
#[cfg(not(feature = "_tls-notls"))]
pub enum MaybeTlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
Expand All @@ -69,11 +74,28 @@ impl<S> MaybeTlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
#[cfg(feature = "_tls-notls")]
#[inline]
pub fn is_tls(&self) -> bool {
false
}
#[cfg(not(feature = "_tls-notls"))]
#[inline]
pub fn is_tls(&self) -> bool {
matches!(self, Self::Tls(_))
}

#[cfg(feature = "_tls-notls")]
pub async fn upgrade(
&mut self,
host: &str,
accept_invalid_certs: bool,
accept_invalid_hostnames: bool,
root_cert_path: Option<&CertificateInput>,
) -> Result<(), Error> {
Ok(())
}
#[cfg(not(feature = "_tls-notls"))]
pub async fn upgrade(
&mut self,
host: &str,
Expand Down Expand Up @@ -112,6 +134,24 @@ where
}
}

#[cfg(feature = "_tls-notls")]
macro_rules! exec_on_stream {
($stream:ident, $fn_name:ident, $($arg:ident),*) => (
Pin::new(&mut $stream.0).$fn_name($($arg,)*)
)
}
#[cfg(not(feature = "_tls-notls"))]
macro_rules! exec_on_stream {
($stream:ident, $fn_name:ident, $($arg:ident),*) => (
match &mut *$stream {
MaybeTlsStream::Raw(s) => Pin::new(s).$fn_name($($arg,)*),
MaybeTlsStream::Tls(s) => Pin::new(s).$fn_name($($arg,)*),

MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
}
)
}

#[cfg(feature = "_tls-native-tls")]
async fn configure_tls_connector(
accept_invalid_certs: bool,
Expand Down Expand Up @@ -155,12 +195,7 @@ where
cx: &mut Context<'_>,
buf: &mut super::PollReadBuf<'_>,
) -> Poll<io::Result<super::PollReadOut>> {
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),

MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
}
exec_on_stream!(self, poll_read, cx, buf)
}
}

Expand All @@ -173,41 +208,21 @@ where
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),

MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
}
exec_on_stream!(self, poll_write, cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),

MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
}
exec_on_stream!(self, poll_flush, cx)
}

#[cfg(feature = "_rt-tokio")]
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),

MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
}
exec_on_stream!(self, poll_shutdown, cx)
}

#[cfg(feature = "_rt-async-std")]
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_close(cx),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_close(cx),

MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
}
exec_on_stream!(self, poll_close, cx)
}
}

Expand All @@ -218,6 +233,11 @@ where
type Target = S;

fn deref(&self) -> &Self::Target {
#[cfg(feature = "_tls-notls")]
{
&self.0
}
#[cfg(not(feature = "_tls-notls"))]
match self {
MaybeTlsStream::Raw(s) => s,

Expand All @@ -242,6 +262,11 @@ where
S: Unpin + AsyncWrite + AsyncRead,
{
fn deref_mut(&mut self) -> &mut Self::Target {
#[cfg(feature = "_tls-notls")]
{
&mut self.0
}
#[cfg(not(feature = "_tls-notls"))]
match self {
MaybeTlsStream::Raw(s) => s,

Expand Down
12 changes: 12 additions & 0 deletions sqlx-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ runtime-tokio-rustls = [
"_rt-tokio",
]

runtime-actix-notls = ["runtime-tokio-notls"]
runtime-async-std-notls = [
"sqlx-core/runtime-async-std-notls",
"sqlx-rt/runtime-async-std-notls",
"_rt-async-std",
]
runtime-tokio-notls = [
"sqlx-core/runtime-tokio-notls",
"sqlx-rt/runtime-tokio-notls",
"_rt-tokio",
]

# for conditional compilation
_rt-async-std = []
_rt-tokio = []
Expand Down
5 changes: 5 additions & 0 deletions sqlx-rt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,15 @@ runtime-actix-rustls = ["runtime-tokio-rustls"]
runtime-async-std-rustls = ["_rt-async-std", "_tls-rustls", "futures-rustls"]
runtime-tokio-rustls = ["_rt-tokio", "_tls-rustls", "tokio-rustls"]

runtime-actix-notls = ["runtime-tokio-notls"]
runtime-async-std-notls = ["_rt-async-std", "_tls-notls"]
runtime-tokio-notls = ["_rt-tokio", "_tls-notls"]

# Not used directly and not re-exported from sqlx
_rt-async-std = ["async-std"]
_rt-tokio = ["tokio", "once_cell"]
_tls-native-tls = ["native-tls"]
_tls-notls = []
_tls-rustls = []

[dependencies]
Expand Down
11 changes: 9 additions & 2 deletions sqlx-rt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,30 @@
feature = "runtime-actix-rustls",
feature = "runtime-async-std-rustls",
feature = "runtime-tokio-rustls",
feature = "runtime-actix-notls",
feature = "runtime-async-std-notls",
feature = "runtime-tokio-notls",
)))]
compile_error!(
"one of the features ['runtime-actix-native-tls', 'runtime-async-std-native-tls', \
'runtime-tokio-native-tls', 'runtime-actix-rustls', 'runtime-async-std-rustls', \
'runtime-tokio-rustls'] must be enabled"
'runtime-tokio-rustls', 'runtime-actix-notls', 'runtime-async-std-notls', \
'runtime-tokio-notls'] must be enabled"
);

#[cfg(any(
all(feature = "_rt-actix", feature = "_rt-async-std"),
all(feature = "_rt-actix", feature = "_rt-tokio"),
all(feature = "_rt-async-std", feature = "_rt-tokio"),
all(feature = "_tls-native-tls", feature = "_tls-rustls"),
all(feature = "_tls-native-tls", feature = "_tls-notls"),
all(feature = "_tls-rustls", feature = "_tls-notls"),
))]
compile_error!(
"only one of ['runtime-actix-native-tls', 'runtime-async-std-native-tls', \
'runtime-tokio-native-tls', 'runtime-actix-rustls', 'runtime-async-std-rustls', \
'runtime-tokio-rustls'] can be enabled"
'runtime-tokio-rustls', 'runtime-actix-notls', 'runtime-async-std-notls', \
'runtime-tokio-notls'] can be enabled"
);

#[cfg(feature = "_rt-async-std")]
Expand Down