From 634988d15db3b6f8fce077c76a1d15ea67259c6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=AE=87=E9=80=B8?= Date: Tue, 3 Dec 2024 01:41:07 +0900 Subject: [PATCH] feat(io): reimplement vectored extensions --- compio-io/src/lib.rs | 1 + compio-io/src/read/ext.rs | 93 +++++++++++----------- compio-io/src/vectored.rs | 153 +++++++++++++++++++++++++++++++++++++ compio-io/src/write/ext.rs | 83 +++++++++----------- compio-io/tests/io.rs | 73 ++++++++++++++++++ 5 files changed, 309 insertions(+), 94 deletions(-) create mode 100644 compio-io/src/vectored.rs diff --git a/compio-io/src/lib.rs b/compio-io/src/lib.rs index 92ba98df..708fc3d4 100644 --- a/compio-io/src/lib.rs +++ b/compio-io/src/lib.rs @@ -109,6 +109,7 @@ pub mod compat; mod read; mod split; pub mod util; +mod vectored; mod write; pub(crate) type IoResult = std::io::Result; diff --git a/compio-io/src/read/ext.rs b/compio-io/src/read/ext.rs index 14d82ee5..8ac4c46e 100644 --- a/compio-io/src/read/ext.rs +++ b/compio-io/src/read/ext.rs @@ -1,9 +1,9 @@ #[cfg(feature = "allocator_api")] use std::alloc::Allocator; -use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBufMut, t_alloc}; +use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBufMut, SetBufInit, t_alloc}; -use crate::{AsyncRead, AsyncReadAt, IoResult, util::Take}; +use crate::{AsyncRead, AsyncReadAt, IoResult, util::Take, vectored::VectoredWrap}; /// Shared code for read a scalar value from the underlying reader. macro_rules! read_scalar { @@ -36,63 +36,35 @@ macro_rules! read_scalar { /// Shared code for loop reading until reaching a certain length. macro_rules! loop_read_exact { - ($buf:ident, $len:expr, $tracker:ident,loop $read_expr:expr) => { - let mut $tracker = 0; + ($buf:ident, $len:expr, $tracker:ident, $read_expr:expr, $update_expr:expr, $buf_expr:expr) => { + let mut $tracker = 0usize; let len = $len; - while $tracker < len { - match $read_expr.await.into_inner() { - BufResult(Ok(0), buf) => { + let BufResult(res, buf) = $read_expr; + $buf = buf; + match res { + Ok(0) => { return BufResult( Err(::std::io::Error::new( ::std::io::ErrorKind::UnexpectedEof, "failed to fill whole buffer", )), - buf, + $buf_expr, ); } - BufResult(Ok(n), buf) => { - $tracker += n; - $buf = buf; - } - BufResult(Err(ref e), buf) if e.kind() == ::std::io::ErrorKind::Interrupted => { - $buf = buf; + Ok(n) => { + $tracker += n as usize; + $update_expr; } - BufResult(Err(e), buf) => return BufResult(Err(e), buf), + Err(ref e) if e.kind() == ::std::io::ErrorKind::Interrupted => {} + Err(e) => return BufResult(Err(e), $buf_expr), } } - return BufResult(Ok(()), $buf) + return BufResult(Ok(()), $buf_expr) }; } macro_rules! loop_read_vectored { - ($buf:ident, $tracker:ident : $tracker_ty:ty, $iter:ident,loop $read_expr:expr) => {{ - use ::compio_buf::OwnedIterator; - - let mut $iter = match $buf.owned_iter() { - Ok(buf) => buf, - Err(buf) => return BufResult(Ok(()), buf), - }; - let mut $tracker: $tracker_ty = 0; - - loop { - let len = $iter.buf_capacity(); - if len > 0 { - match $read_expr.await { - BufResult(Ok(()), ret) => { - $iter = ret; - $tracker += len as $tracker_ty; - } - BufResult(Err(e), $iter) => return BufResult(Err(e), $iter.into_inner()), - }; - } - - match $iter.next() { - Ok(next) => $iter = next, - Err(buf) => return BufResult(Ok(()), buf), - } - } - }}; ($buf:ident, $iter:ident, $read_expr:expr) => {{ use ::compio_buf::OwnedIterator; @@ -158,7 +130,14 @@ pub trait AsyncReadExt: AsyncRead { /// Read the exact number of bytes required to fill the buf. async fn read_exact(&mut self, mut buf: T) -> BufResult<(), T> { - loop_read_exact!(buf, buf.buf_capacity(), read, loop self.read(buf.slice(read..))); + loop_read_exact!( + buf, + buf.buf_capacity(), + read, + self.read(buf.slice(read..)).await.into_inner(), + {}, + buf + ); } /// Read all bytes until underlying reader reaches `EOF`. @@ -171,7 +150,15 @@ pub trait AsyncReadExt: AsyncRead { /// Read the exact number of bytes required to fill the vectored buf. async fn read_vectored_exact(&mut self, buf: T) -> BufResult<(), T> { - loop_read_vectored!(buf, _total: usize, iter, loop self.read_exact(iter)) + let mut buf = VectoredWrap::new(buf); + loop_read_exact!( + buf, + buf.capacity(), + read, + self.read_vectored(buf).await, + unsafe { buf.set_buf_init(read) }, + buf.into_inner() + ); } /// Creates an adaptor which reads at most `limit` bytes from it. @@ -234,7 +221,11 @@ pub trait AsyncReadAtExt: AsyncReadAt { buf, buf.buf_capacity(), read, - loop self.read_at(buf.slice(read..), pos + read as u64) + self.read_at(buf.slice(read..), pos + read as u64) + .await + .into_inner(), + {}, + buf ); } @@ -262,7 +253,15 @@ pub trait AsyncReadAtExt: AsyncReadAt { buf: T, pos: u64, ) -> BufResult<(), T> { - loop_read_vectored!(buf, total: u64, iter, loop self.read_exact_at(iter, pos + total)) + let mut buf = VectoredWrap::new(buf); + loop_read_exact!( + buf, + buf.capacity(), + read, + self.read_vectored_at(buf, pos + read as u64).await, + unsafe { buf.set_buf_init(read) }, + buf.into_inner() + ); } } diff --git a/compio-io/src/vectored.rs b/compio-io/src/vectored.rs new file mode 100644 index 00000000..24376ac5 --- /dev/null +++ b/compio-io/src/vectored.rs @@ -0,0 +1,153 @@ +use std::pin::Pin; + +use compio_buf::{ + Indexable, IndexableMut, IndexedIter, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, + IoVectoredBufMut, MaybeOwned, MaybeOwnedMut, SetBufInit, +}; + +pub struct VectoredWrap { + buffers: Pin>, + wraps: Vec, + vec_off: usize, +} + +impl VectoredWrap { + pub fn new(buffers: T) -> Self { + let buffers = Box::pin(buffers); + let wraps = buffers.iter_buf().map(|buf| BufWrap::new(&*buf)).collect(); + Self { + buffers, + wraps, + vec_off: 0, + } + } + + pub fn len(&self) -> usize { + self.wraps.iter().map(|buf| buf.len).sum() + } + + pub fn capacity(&self) -> usize { + self.wraps.iter().map(|buf| buf.capacity).sum() + } +} + +impl IoVectoredBuf for VectoredWrap { + type Buf = BufWrap; + type OwnedIter = IndexedIter; + + fn iter_buf(&self) -> impl Iterator> { + self.wraps + .iter() + .skip(self.vec_off) + .map(MaybeOwned::Borrowed) + } + + fn owned_iter(self) -> Result + where + Self: Sized, + { + IndexedIter::new(self) + } +} + +impl IoVectoredBufMut for VectoredWrap { + fn iter_buf_mut(&mut self) -> impl Iterator> { + self.wraps + .iter_mut() + .skip(self.vec_off) + .map(MaybeOwnedMut::Borrowed) + } +} + +impl SetBufInit for VectoredWrap { + unsafe fn set_buf_init(&mut self, mut len: usize) { + self.buffers.as_mut().get_unchecked_mut().set_buf_init(len); + self.vec_off = 0; + for buf in self.wraps.iter_mut().skip(self.vec_off) { + let capacity = (*buf).buf_capacity(); + let buf_new_len = len.min(capacity); + buf.set_buf_init(buf_new_len); + *buf = buf.offset(buf_new_len); + if len >= capacity { + len -= capacity; + } else { + break; + } + self.vec_off += 1; + } + } +} + +impl Indexable for VectoredWrap { + type Output = BufWrap; + + fn index(&self, n: usize) -> Option<&Self::Output> { + self.wraps.get(n + self.vec_off) + } +} + +impl IndexableMut for VectoredWrap { + fn index_mut(&mut self, n: usize) -> Option<&mut Self::Output> { + self.wraps.get_mut(n + self.vec_off) + } +} + +impl IntoInner for VectoredWrap { + type Inner = T; + + fn into_inner(self) -> Self::Inner { + // Safety: no pointers still maintaining + *unsafe { Pin::into_inner_unchecked(self.buffers) } + } +} + +pub struct BufWrap { + ptr: *mut u8, + len: usize, + capacity: usize, +} + +impl BufWrap { + fn new(buf: &T) -> Self { + Self { + ptr: buf.as_buf_ptr().cast_mut(), + len: buf.buf_len(), + capacity: buf.buf_capacity(), + } + } + + fn offset(&self, off: usize) -> Self { + Self { + ptr: unsafe { self.ptr.add(off) }, + len: self.len.saturating_sub(off), + capacity: self.capacity.saturating_sub(off), + } + } +} + +unsafe impl IoBuf for BufWrap { + fn as_buf_ptr(&self) -> *const u8 { + self.ptr.cast_const() + } + + fn buf_len(&self) -> usize { + self.len + } + + fn buf_capacity(&self) -> usize { + self.capacity + } +} + +unsafe impl IoBufMut for BufWrap { + fn as_buf_mut_ptr(&mut self) -> *mut u8 { + self.ptr + } +} + +impl SetBufInit for BufWrap { + unsafe fn set_buf_init(&mut self, len: usize) { + debug_assert!(len <= self.capacity, "{} > {}", len, self.capacity); + self.len = self.len.max(len); + } +} diff --git a/compio-io/src/write/ext.rs b/compio-io/src/write/ext.rs index 0f5a3fdd..a4b07eb4 100644 --- a/compio-io/src/write/ext.rs +++ b/compio-io/src/write/ext.rs @@ -1,6 +1,6 @@ use compio_buf::{BufResult, IntoInner, IoBuf, IoVectoredBuf}; -use crate::{AsyncWrite, AsyncWriteAt, IoResult}; +use crate::{AsyncWrite, AsyncWriteAt, IoResult, vectored::VectoredWrap}; /// Shared code for write a scalar value into the underlying writer. macro_rules! write_scalar { @@ -33,64 +33,35 @@ macro_rules! write_scalar { /// Shared code for loop writing until all contents are written. macro_rules! loop_write_all { - ($buf:ident, $len:expr, $needle:ident,loop $expr_expr:expr) => { + ($buf:ident, $len:expr, $tracker:ident, $write_expr:expr, $buf_expr:expr) => { + let mut $tracker = 0usize; let len = $len; - let mut $needle = 0; - - while $needle < len { - match $expr_expr.await.into_inner() { - BufResult(Ok(0), buf) => { + while $tracker < len { + let BufResult(res, buf) = $write_expr; + $buf = buf; + match res { + Ok(0) => { return BufResult( Err(::std::io::Error::new( ::std::io::ErrorKind::WriteZero, "failed to write whole buffer", )), - buf, + $buf_expr, ); } - BufResult(Ok(n), buf) => { - $needle += n; - $buf = buf; - } - BufResult(Err(ref e), buf) if e.kind() == ::std::io::ErrorKind::Interrupted => { - $buf = buf; + Ok(n) => { + $tracker += n as usize; } - BufResult(Err(e), buf) => return BufResult(Err(e), buf), + Err(ref e) if e.kind() == ::std::io::ErrorKind::Interrupted => {} + Err(e) => return BufResult(Err(e), $buf_expr), } } - return BufResult(Ok(()), $buf); + return BufResult(Ok(()), $buf_expr); }; } macro_rules! loop_write_vectored { - ($buf:ident, $tracker:ident : $tracker_ty:ty, $iter:ident,loop $read_expr:expr) => {{ - use ::compio_buf::OwnedIterator; - - let mut $iter = match $buf.owned_iter() { - Ok(buf) => buf, - Err(buf) => return BufResult(Ok(()), buf), - }; - let mut $tracker: $tracker_ty = 0; - - loop { - let len = $iter.buf_len(); - if len > 0 { - match $read_expr.await { - BufResult(Ok(()), ret) => { - $iter = ret; - $tracker += len as $tracker_ty; - } - BufResult(Err(e), $iter) => return BufResult(Err(e), $iter.into_inner()), - }; - } - - match $iter.next() { - Ok(next) => $iter = next, - Err(buf) => return BufResult(Ok(()), buf), - } - } - }}; ($buf:ident, $iter:ident, $read_expr:expr) => {{ use ::compio_buf::OwnedIterator; @@ -133,7 +104,8 @@ pub trait AsyncWriteExt: AsyncWrite { buf, buf.buf_len(), needle, - loop self.write(buf.slice(needle..)) + self.write(buf.slice(needle..)).await.into_inner(), + buf ); } @@ -141,7 +113,14 @@ pub trait AsyncWriteExt: AsyncWrite { /// [`AsyncWrite::write_vectored`], except that it tries to write the entire /// contents of the buffer into this writer. async fn write_vectored_all(&mut self, buf: T) -> BufResult<(), T> { - loop_write_vectored!(buf, _total: usize, iter, loop self.write_all(iter)) + let mut buf = VectoredWrap::new(buf); + loop_write_all!( + buf, + buf.len(), + needle, + self.write_vectored(buf).await, + buf.into_inner() + ); } write_scalar!(u8, to_be_bytes, to_le_bytes); @@ -171,7 +150,10 @@ pub trait AsyncWriteAtExt: AsyncWriteAt { buf, buf.buf_len(), needle, - loop self.write_at(buf.slice(needle..), pos + needle as u64) + self.write_at(buf.slice(needle..), pos + needle as u64) + .await + .into_inner(), + buf ); } @@ -182,7 +164,14 @@ pub trait AsyncWriteAtExt: AsyncWriteAt { buf: T, pos: u64, ) -> BufResult<(), T> { - loop_write_vectored!(buf, total: u64, iter, loop self.write_all_at(iter, pos + total)) + let mut buf = VectoredWrap::new(buf); + loop_write_all!( + buf, + buf.len(), + needle, + self.write_vectored_at(buf, pos + needle as u64).await, + buf.into_inner() + ); } } diff --git a/compio-io/tests/io.rs b/compio-io/tests/io.rs index fc3d12bd..0f3918b0 100644 --- a/compio-io/tests/io.rs +++ b/compio-io/tests/io.rs @@ -126,6 +126,30 @@ fn readv() { assert_eq!(len, 13); assert!(buf[0].is_empty()); assert_eq!(buf[1], [1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0]); + + let mut src = &[1u8, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0][..]; + let ((), buf) = src + .read_vectored_exact([Vec::with_capacity(3), Vec::with_capacity(3)]) + .await + .unwrap(); + assert_eq!(buf[0], [1, 1, 4]); + assert_eq!(buf[1], [5, 1, 4]); + + let mut src = &[1u8, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0][..]; + let ((), buf) = src + .read_vectored_exact([vec![], Vec::with_capacity(3)]) + .await + .unwrap(); + assert!(buf[0].is_empty()); + assert_eq!(buf[1], [1, 1, 4]); + + let mut src = &[1u8, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0][..]; + let BufResult(res, buf) = src + .read_vectored_exact([Vec::with_capacity(10), Vec::with_capacity(10)]) + .await; + assert!(res.is_err()); + assert_eq!(buf[0], [1, 1, 4, 5, 1, 4, 1, 9, 1, 9]); + assert_eq!(buf[1], [8, 1, 0]); }) } @@ -160,6 +184,23 @@ fn writev() { assert_eq!(len, 6); assert_eq!(dst.len(), 6); assert_eq!(dst, [1, 1, 4, 5, 1, 4]); + + let mut dst = Cursor::new([0u8; 10]); + let ((), _) = dst + .write_vectored_all([vec![1, 1, 4], vec![5, 1, 4]]) + .await + .unwrap(); + + assert_eq!(dst.position(), 6); + assert_eq!(dst.into_inner(), [1, 1, 4, 5, 1, 4, 0, 0, 0, 0]); + + let mut dst = Cursor::new([0u8; 10]); + let BufResult(res, _) = dst + .write_vectored_all([vec![1, 1, 4, 5, 1, 4], vec![1, 9, 1, 9, 8, 1, 0]]) + .await; + + assert!(res.is_err()); + assert_eq!(dst.into_inner(), [1, 1, 4, 5, 1, 4, 1, 9, 1, 9]); }) } @@ -185,6 +226,22 @@ fn readv_at() { assert_eq!(len, 4); assert_eq!(buf[0].as_slice(), [4, 5, 1]); assert_eq!(buf[1].as_slice(), [4]); + + let ((), buf) = SRC + .read_vectored_exact_at([vec![0; 3], Vec::with_capacity(1)], 2) + .await + .unwrap(); + + assert_eq!(buf[0].as_slice(), [4, 5, 1]); + assert_eq!(buf[1].as_slice(), [4]); + + let BufResult(res, buf) = SRC + .read_vectored_exact_at([Vec::with_capacity(6), Vec::with_capacity(6)], 2) + .await; + + assert!(res.is_err()); + assert_eq!(buf[0].as_slice(), &SRC[2..]); + assert!(buf[1].is_empty()); }) } @@ -218,6 +275,22 @@ fn writev_at() { assert_eq!(len, 6); assert_eq!(dst.len(), 8); assert_eq!(dst, [0, 0, 1, 1, 4, 5, 1, 4]); + + let mut dst = [0u8; 10]; + let ((), _) = dst + .write_vectored_all_at([vec![1, 1, 4], vec![5, 1, 4]], 2) + .await + .unwrap(); + + assert_eq!(dst, [0, 0, 1, 1, 4, 5, 1, 4, 0, 0]); + + let mut dst = [0u8; 5]; + let BufResult(res, _) = dst + .write_vectored_all_at([vec![1, 1, 4], vec![5, 1, 4]], 2) + .await; + + assert!(res.is_err()); + assert_eq!(dst, [0, 0, 1, 1, 4]); }) }