diff --git a/crates/ironrdp-async/src/framed.rs b/crates/ironrdp-async/src/framed.rs index e3b32b247..7aedb5387 100644 --- a/crates/ironrdp-async/src/framed.rs +++ b/crates/ironrdp-async/src/framed.rs @@ -1,26 +1,27 @@ use std::io; -use std::pin::Pin; use bytes::{Bytes, BytesMut}; use ironrdp_pdu::PduHint; -// TODO: use static async fn / return position impl trait in traits when stabiziled (https://github.com/rust-lang/rust/issues/91611) +// TODO: investigate if we could use static async fn / return position impl trait in traits when stabilized: +// https://github.com/rust-lang/rust/issues/91611 pub trait FramedRead { - /// Reads from stream and fills internal buffer - fn read<'a>( - &'a mut self, - buf: &'a mut BytesMut, - ) -> Pin> + 'a>> + type ReadFut<'read>: std::future::Future> + 'read where - Self: 'a; + Self: 'read; + + /// Reads from stream and fills internal buffer + fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a>; } pub trait FramedWrite { - /// Writes an entire buffer into this stream. - fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Pin> + 'a>> + type WriteAllFut<'write>: std::future::Future> + 'write where - Self: 'a; + Self: 'write; + + /// Writes an entire buffer into this stream. + fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a>; } pub trait StreamWrapper: Sized { diff --git a/crates/ironrdp-connector/src/lib.rs b/crates/ironrdp-connector/src/lib.rs index da915d258..b0dcdad44 100644 --- a/crates/ironrdp-connector/src/lib.rs +++ b/crates/ironrdp-connector/src/lib.rs @@ -73,7 +73,9 @@ pub struct Config { pub no_server_pointer: bool, } -pub trait State: Send + Sync + core::fmt::Debug { +ironrdp_pdu::assert_impl!(Config: Send, Sync); + +pub trait State: Send + Sync + core::fmt::Debug + 'static { fn name(&self) -> &'static str; fn is_terminal(&self) -> bool; fn as_any(&self) -> &dyn Any; @@ -81,11 +83,11 @@ pub trait State: Send + Sync + core::fmt::Debug { ironrdp_pdu::assert_obj_safe!(State); -pub fn state_downcast(state: &dyn State) -> Option<&T> { +pub fn state_downcast(state: &dyn State) -> Option<&T> { state.as_any().downcast_ref() } -pub fn state_is(state: &dyn State) -> bool { +pub fn state_is(state: &dyn State) -> bool { state.as_any().is::() } diff --git a/crates/ironrdp-futures/src/lib.rs b/crates/ironrdp-futures/src/lib.rs index 567041b59..7a257d3c2 100644 --- a/crates/ironrdp-futures/src/lib.rs +++ b/crates/ironrdp-futures/src/lib.rs @@ -1,12 +1,12 @@ +#[rustfmt::skip] // do not re-order this pub use +pub use ironrdp_async::*; + use std::io; use std::pin::Pin; use bytes::BytesMut; use futures_util::io::{AsyncRead, AsyncWrite}; -#[rustfmt::skip] // do not re-order this pub use -pub use ironrdp_async::*; - pub type FuturesFramed = Framed>; pub struct FuturesStream { @@ -35,15 +35,13 @@ impl StreamWrapper for FuturesStream { impl FramedRead for FuturesStream where - S: Unpin + AsyncRead, + S: Send + Sync + Unpin + AsyncRead, { - fn read<'a>( - &'a mut self, - buf: &'a mut BytesMut, - ) -> Pin> + 'a>> + type ReadFut<'read> = Pin> + Send + Sync + 'read>> where - Self: 'a, - { + Self: 'read; + + fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a> { use futures_util::io::AsyncReadExt as _; Box::pin(async { @@ -58,13 +56,82 @@ where } impl FramedWrite for FuturesStream +where + S: Send + Sync + Unpin + AsyncWrite, +{ + type WriteAllFut<'write> = Pin> + Send + Sync + 'write>> + where + Self: 'write; + + fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> { + use futures_util::io::AsyncWriteExt as _; + + Box::pin(async { + self.inner.write_all(buf).await?; + self.inner.flush().await?; + + Ok(()) + }) + } +} + +pub type SingleThreadedFuturesFramed = Framed>; + +pub struct SingleThreadedFuturesStream { + inner: S, +} + +impl StreamWrapper for SingleThreadedFuturesStream { + type InnerStream = S; + + fn from_inner(stream: Self::InnerStream) -> Self { + Self { inner: stream } + } + + fn into_inner(self) -> Self::InnerStream { + self.inner + } + + fn get_inner(&self) -> &Self::InnerStream { + &self.inner + } + + fn get_inner_mut(&mut self) -> &mut Self::InnerStream { + &mut self.inner + } +} + +impl FramedRead for SingleThreadedFuturesStream +where + S: Unpin + AsyncRead, +{ + type ReadFut<'read> = Pin> + 'read>> + where + Self: 'read; + + fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a> { + use futures_util::io::AsyncReadExt as _; + + Box::pin(async { + // NOTE(perf): tokio implementation is more efficient + let mut read_bytes = [0u8; 1024]; + let len = self.inner.read(&mut read_bytes[..]).await?; + buf.extend_from_slice(&read_bytes[..len]); + + Ok(len) + }) + } +} + +impl FramedWrite for SingleThreadedFuturesStream where S: Unpin + AsyncWrite, { - fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Pin> + 'a>> + type WriteAllFut<'write> = Pin> + 'write>> where - Self: 'a, - { + Self: 'write; + + fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> { use futures_util::io::AsyncWriteExt as _; Box::pin(async { diff --git a/crates/ironrdp-pdu/src/lib.rs b/crates/ironrdp-pdu/src/lib.rs index 32d46e87c..4e3e28fe5 100644 --- a/crates/ironrdp-pdu/src/lib.rs +++ b/crates/ironrdp-pdu/src/lib.rs @@ -315,7 +315,7 @@ pub fn find_size(bytes: &[u8]) -> PduResult> { } } -pub trait PduHint: core::fmt::Debug { +pub trait PduHint: Send + Sync + core::fmt::Debug + 'static { /// Finds next PDU size by reading the next few bytes. fn find_size(&self, bytes: &[u8]) -> PduResult>; } diff --git a/crates/ironrdp-pdu/src/macros.rs b/crates/ironrdp-pdu/src/macros.rs index 1eb97eeef..f70462849 100644 --- a/crates/ironrdp-pdu/src/macros.rs +++ b/crates/ironrdp-pdu/src/macros.rs @@ -174,7 +174,7 @@ macro_rules! cast_int { /// Asserts that the traits support dynamic dispatch. /// -/// From +/// From #[macro_export] macro_rules! assert_obj_safe { ($($xs:path),+ $(,)?) => { @@ -182,6 +182,20 @@ macro_rules! assert_obj_safe { }; } +/// Asserts that the type implements _all_ of the given traits. +/// +/// From +#[macro_export] +macro_rules! assert_impl { + ($type:ty: $($trait:path),+ $(,)?) => { + const _: fn() = || { + // Only callable when `$type` implements all traits in `$($trait)+`. + fn assert_impl_all() {} + assert_impl_all::<$type>(); + }; + }; +} + /// Implements additional traits for a plain old data structure (POD). #[macro_export] macro_rules! impl_pdu_pod { diff --git a/crates/ironrdp-tokio/src/lib.rs b/crates/ironrdp-tokio/src/lib.rs index eb193a08d..7c6f4ea6a 100644 --- a/crates/ironrdp-tokio/src/lib.rs +++ b/crates/ironrdp-tokio/src/lib.rs @@ -1,8 +1,10 @@ +#[rustfmt::skip] // do not re-order this pub use +pub use ironrdp_async::*; + use std::io; use std::pin::Pin; use bytes::BytesMut; -pub use ironrdp_async::*; use tokio::io::{AsyncRead, AsyncWrite}; pub type TokioFramed = Framed>; @@ -33,15 +35,13 @@ impl StreamWrapper for TokioStream { impl FramedRead for TokioStream where - S: Unpin + AsyncRead, + S: Send + Sync + Unpin + AsyncRead, { - fn read<'a>( - &'a mut self, - buf: &'a mut BytesMut, - ) -> Pin> + 'a>> + type ReadFut<'read> = Pin> + Send + Sync + 'read>> where - Self: 'a, - { + Self: 'read; + + fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a> { use tokio::io::AsyncReadExt as _; Box::pin(async { self.inner.read_buf(buf).await }) @@ -49,13 +49,75 @@ where } impl FramedWrite for TokioStream +where + S: Send + Sync + Unpin + AsyncWrite, +{ + type WriteAllFut<'write> = Pin> + Send + Sync + 'write>> + where + Self: 'write; + + fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> { + use tokio::io::AsyncWriteExt as _; + + Box::pin(async { + self.inner.write_all(buf).await?; + self.inner.flush().await?; + + Ok(()) + }) + } +} + +pub type SingleThreadedTokioFramed = Framed>; + +pub struct SingleThreadedTokioStream { + inner: S, +} + +impl StreamWrapper for SingleThreadedTokioStream { + type InnerStream = S; + + fn from_inner(stream: Self::InnerStream) -> Self { + Self { inner: stream } + } + + fn into_inner(self) -> Self::InnerStream { + self.inner + } + + fn get_inner(&self) -> &Self::InnerStream { + &self.inner + } + + fn get_inner_mut(&mut self) -> &mut Self::InnerStream { + &mut self.inner + } +} + +impl FramedRead for SingleThreadedTokioStream +where + S: Unpin + AsyncRead, +{ + type ReadFut<'read> = Pin> + 'read>> + where + Self: 'read; + + fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a> { + use tokio::io::AsyncReadExt as _; + + Box::pin(async { self.inner.read_buf(buf).await }) + } +} + +impl FramedWrite for SingleThreadedTokioStream where S: Unpin + AsyncWrite, { - fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Pin> + 'a>> + type WriteAllFut<'write> = Pin> + 'write>> where - Self: 'a, - { + Self: 'write; + + fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> { use tokio::io::AsyncWriteExt as _; Box::pin(async { diff --git a/crates/ironrdp-web/src/session.rs b/crates/ironrdp-web/src/session.rs index 77fadc679..b15f7e4c0 100644 --- a/crates/ironrdp-web/src/session.rs +++ b/crates/ironrdp-web/src/session.rs @@ -325,7 +325,7 @@ impl Session { .take() .expect("run called only once"); - let mut framed = ironrdp_futures::FuturesFramed::new(rdp_reader); + let mut framed = ironrdp_futures::SingleThreadedFuturesFramed::new(rdp_reader); info!("Start RDP session"); @@ -565,7 +565,7 @@ async fn connect( destination: String, pcb: Option, ) -> Result<(connector::ConnectionResult, WebSocketCompat), IronRdpError> { - let mut framed = ironrdp_futures::FuturesFramed::new(ws); + let mut framed = ironrdp_futures::SingleThreadedFuturesFramed::new(ws); let mut connector = connector::ClientConnector::new(config) .with_server_name(&destination)