Skip to content

feat: support for both Send and !Send Framed impls #189

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions crates/ironrdp-async/src/framed.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
use std::io;
use std::pin::Pin;

use bytes::{Bytes, BytesMut};
use ironrdp_pdu::PduHint;

// TODO: use static async fn / return position impl trait in traits when stabiziled (https://github.com/rust-lang/rust/issues/91611)
// TODO: investigate if we could use static async fn / return position impl trait in traits when stabilized:
// https://github.com/rust-lang/rust/issues/91611

pub trait FramedRead {
/// Reads from stream and fills internal buffer
fn read<'a>(
&'a mut self,
buf: &'a mut BytesMut,
) -> Pin<Box<dyn std::future::Future<Output = io::Result<usize>> + 'a>>
type ReadFut<'read>: std::future::Future<Output = io::Result<usize>> + 'read
where
Self: 'a;
Self: 'read;

/// Reads from stream and fills internal buffer
fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a>;
}

pub trait FramedWrite {
/// Writes an entire buffer into this stream.
fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Pin<Box<dyn std::future::Future<Output = io::Result<()>> + 'a>>
type WriteAllFut<'write>: std::future::Future<Output = io::Result<()>> + 'write
where
Self: 'a;
Self: 'write;

/// Writes an entire buffer into this stream.
fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a>;
}

pub trait StreamWrapper: Sized {
Expand Down
8 changes: 5 additions & 3 deletions crates/ironrdp-connector/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,21 @@ pub struct Config {
pub no_server_pointer: bool,
}

pub trait State: Send + Sync + core::fmt::Debug {
ironrdp_pdu::assert_impl!(Config: Send, Sync);

pub trait State: Send + Sync + core::fmt::Debug + 'static {
fn name(&self) -> &'static str;
fn is_terminal(&self) -> bool;
fn as_any(&self) -> &dyn Any;
}

ironrdp_pdu::assert_obj_safe!(State);

pub fn state_downcast<T: State + Any>(state: &dyn State) -> Option<&T> {
pub fn state_downcast<T: State>(state: &dyn State) -> Option<&T> {
state.as_any().downcast_ref()
}

pub fn state_is<T: State + Any>(state: &dyn State) -> bool {
pub fn state_is<T: State>(state: &dyn State) -> bool {
state.as_any().is::<T>()
}

Expand Down
93 changes: 80 additions & 13 deletions crates/ironrdp-futures/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#[rustfmt::skip] // do not re-order this pub use
pub use ironrdp_async::*;

use std::io;
use std::pin::Pin;

use bytes::BytesMut;
use futures_util::io::{AsyncRead, AsyncWrite};

#[rustfmt::skip] // do not re-order this pub use
pub use ironrdp_async::*;

pub type FuturesFramed<S> = Framed<FuturesStream<S>>;

pub struct FuturesStream<S> {
Expand Down Expand Up @@ -35,15 +35,13 @@ impl<S> StreamWrapper for FuturesStream<S> {

impl<S> FramedRead for FuturesStream<S>
where
S: Unpin + AsyncRead,
S: Send + Sync + Unpin + AsyncRead,
{
fn read<'a>(
&'a mut self,
buf: &'a mut BytesMut,
) -> Pin<Box<dyn std::future::Future<Output = io::Result<usize>> + 'a>>
type ReadFut<'read> = Pin<Box<dyn std::future::Future<Output = io::Result<usize>> + Send + Sync + 'read>>
where
Self: 'a,
{
Self: 'read;

fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a> {
use futures_util::io::AsyncReadExt as _;

Box::pin(async {
Expand All @@ -58,13 +56,82 @@ where
}

impl<S> FramedWrite for FuturesStream<S>
where
S: Send + Sync + Unpin + AsyncWrite,
{
type WriteAllFut<'write> = Pin<Box<dyn std::future::Future<Output = io::Result<()>> + Send + Sync + 'write>>
where
Self: 'write;

fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
use futures_util::io::AsyncWriteExt as _;

Box::pin(async {
self.inner.write_all(buf).await?;
self.inner.flush().await?;

Ok(())
})
}
}

pub type SingleThreadedFuturesFramed<S> = Framed<SingleThreadedFuturesStream<S>>;

pub struct SingleThreadedFuturesStream<S> {
inner: S,
}

impl<S> StreamWrapper for SingleThreadedFuturesStream<S> {
type InnerStream = S;

fn from_inner(stream: Self::InnerStream) -> Self {
Self { inner: stream }
}

fn into_inner(self) -> Self::InnerStream {
self.inner
}

fn get_inner(&self) -> &Self::InnerStream {
&self.inner
}

fn get_inner_mut(&mut self) -> &mut Self::InnerStream {
&mut self.inner
}
}

impl<S> FramedRead for SingleThreadedFuturesStream<S>
where
S: Unpin + AsyncRead,
{
type ReadFut<'read> = Pin<Box<dyn std::future::Future<Output = io::Result<usize>> + 'read>>
where
Self: 'read;

fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a> {
use futures_util::io::AsyncReadExt as _;

Box::pin(async {
// NOTE(perf): tokio implementation is more efficient
let mut read_bytes = [0u8; 1024];
let len = self.inner.read(&mut read_bytes[..]).await?;
buf.extend_from_slice(&read_bytes[..len]);

Ok(len)
})
}
}

impl<S> FramedWrite for SingleThreadedFuturesStream<S>
where
S: Unpin + AsyncWrite,
{
fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Pin<Box<dyn std::future::Future<Output = io::Result<()>> + 'a>>
type WriteAllFut<'write> = Pin<Box<dyn std::future::Future<Output = io::Result<()>> + 'write>>
where
Self: 'a,
{
Self: 'write;

fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
use futures_util::io::AsyncWriteExt as _;

Box::pin(async {
Expand Down
2 changes: 1 addition & 1 deletion crates/ironrdp-pdu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ pub fn find_size(bytes: &[u8]) -> PduResult<Option<PduInfo>> {
}
}

pub trait PduHint: core::fmt::Debug {
pub trait PduHint: Send + Sync + core::fmt::Debug + 'static {
/// Finds next PDU size by reading the next few bytes.
fn find_size(&self, bytes: &[u8]) -> PduResult<Option<usize>>;
}
Expand Down
16 changes: 15 additions & 1 deletion crates/ironrdp-pdu/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,28 @@ macro_rules! cast_int {

/// Asserts that the traits support dynamic dispatch.
///
/// From <https://docs.rs/static_assertions/latest/src/static_assertions/assert_obj_safe.rs.html#72-76>
/// From <https://docs.rs/static_assertions/1.1.0/src/static_assertions/assert_obj_safe.rs.html#72-76>
#[macro_export]
macro_rules! assert_obj_safe {
($($xs:path),+ $(,)?) => {
$(const _: Option<&dyn $xs> = None;)+
};
}

/// Asserts that the type implements _all_ of the given traits.
///
/// From <https://docs.rs/static_assertions/1.1.0/src/static_assertions/assert_impl.rs.html#113-121>
#[macro_export]
macro_rules! assert_impl {
($type:ty: $($trait:path),+ $(,)?) => {
const _: fn() = || {
// Only callable when `$type` implements all traits in `$($trait)+`.
fn assert_impl_all<T: ?Sized $(+ $trait)+>() {}
assert_impl_all::<$type>();
};
};
}

/// Implements additional traits for a plain old data structure (POD).
#[macro_export]
macro_rules! impl_pdu_pod {
Expand Down
84 changes: 73 additions & 11 deletions crates/ironrdp-tokio/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#[rustfmt::skip] // do not re-order this pub use
pub use ironrdp_async::*;

use std::io;
use std::pin::Pin;

use bytes::BytesMut;
pub use ironrdp_async::*;
use tokio::io::{AsyncRead, AsyncWrite};

pub type TokioFramed<S> = Framed<TokioStream<S>>;
Expand Down Expand Up @@ -33,29 +35,89 @@ impl<S> StreamWrapper for TokioStream<S> {

impl<S> FramedRead for TokioStream<S>
where
S: Unpin + AsyncRead,
S: Send + Sync + Unpin + AsyncRead,
{
fn read<'a>(
&'a mut self,
buf: &'a mut BytesMut,
) -> Pin<Box<dyn std::future::Future<Output = io::Result<usize>> + 'a>>
type ReadFut<'read> = Pin<Box<dyn std::future::Future<Output = io::Result<usize>> + Send + Sync + 'read>>
where
Self: 'a,
{
Self: 'read;

fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a> {
use tokio::io::AsyncReadExt as _;

Box::pin(async { self.inner.read_buf(buf).await })
}
}

impl<S> FramedWrite for TokioStream<S>
where
S: Send + Sync + Unpin + AsyncWrite,
{
type WriteAllFut<'write> = Pin<Box<dyn std::future::Future<Output = io::Result<()>> + Send + Sync + 'write>>
where
Self: 'write;

fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
use tokio::io::AsyncWriteExt as _;

Box::pin(async {
self.inner.write_all(buf).await?;
self.inner.flush().await?;

Ok(())
})
}
}

pub type SingleThreadedTokioFramed<S> = Framed<SingleThreadedTokioStream<S>>;

pub struct SingleThreadedTokioStream<S> {
inner: S,
}

impl<S> StreamWrapper for SingleThreadedTokioStream<S> {
type InnerStream = S;

fn from_inner(stream: Self::InnerStream) -> Self {
Self { inner: stream }
}

fn into_inner(self) -> Self::InnerStream {
self.inner
}

fn get_inner(&self) -> &Self::InnerStream {
&self.inner
}

fn get_inner_mut(&mut self) -> &mut Self::InnerStream {
&mut self.inner
}
}

impl<S> FramedRead for SingleThreadedTokioStream<S>
where
S: Unpin + AsyncRead,
{
type ReadFut<'read> = Pin<Box<dyn std::future::Future<Output = io::Result<usize>> + 'read>>
where
Self: 'read;

fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a> {
use tokio::io::AsyncReadExt as _;

Box::pin(async { self.inner.read_buf(buf).await })
}
}

impl<S> FramedWrite for SingleThreadedTokioStream<S>
where
S: Unpin + AsyncWrite,
{
fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Pin<Box<dyn std::future::Future<Output = io::Result<()>> + 'a>>
type WriteAllFut<'write> = Pin<Box<dyn std::future::Future<Output = io::Result<()>> + 'write>>
where
Self: 'a,
{
Self: 'write;

fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
use tokio::io::AsyncWriteExt as _;

Box::pin(async {
Expand Down
4 changes: 2 additions & 2 deletions crates/ironrdp-web/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ impl Session {
.take()
.expect("run called only once");

let mut framed = ironrdp_futures::FuturesFramed::new(rdp_reader);
let mut framed = ironrdp_futures::SingleThreadedFuturesFramed::new(rdp_reader);

info!("Start RDP session");

Expand Down Expand Up @@ -565,7 +565,7 @@ async fn connect(
destination: String,
pcb: Option<String>,
) -> Result<(connector::ConnectionResult, WebSocketCompat), IronRdpError> {
let mut framed = ironrdp_futures::FuturesFramed::new(ws);
let mut framed = ironrdp_futures::SingleThreadedFuturesFramed::new(ws);

let mut connector = connector::ClientConnector::new(config)
.with_server_name(&destination)
Expand Down