diff --git a/src/net/send_recv/mod.rs b/src/net/send_recv/mod.rs index 2bbf0cf7d..cf7b7fda2 100644 --- a/src/net/send_recv/mod.rs +++ b/src/net/send_recv/mod.rs @@ -10,6 +10,7 @@ use crate::net::SocketAddrUnix; use crate::net::{SocketAddr, SocketAddrAny, SocketAddrV4, SocketAddrV6}; use crate::{backend, io}; use backend::fd::{AsFd, BorrowedFd}; +use core::cmp::min; use core::mem::MaybeUninit; pub use backend::net::send_recv::{RecvFlags, SendFlags}; @@ -71,6 +72,11 @@ pub fn recv(fd: Fd, buf: &mut [u8], flags: RecvFlags) -> io::Result( fd: Fd, @@ -78,10 +84,12 @@ pub fn recv_uninit( flags: RecvFlags, ) -> io::Result<(&mut [u8], &mut [MaybeUninit])> { let length = unsafe { - backend::net::syscalls::recv(fd.as_fd(), buf.as_mut_ptr().cast::(), buf.len(), flags) + backend::net::syscalls::recv(fd.as_fd(), buf.as_mut_ptr().cast::(), buf.len(), flags)? }; - Ok(unsafe { split_init(buf, length?) }) + // If the `TRUNC` flag is set, the returned `length` may be longer than the + // buffer length. + Ok(unsafe { split_init(buf, min(length, buf.len())) }) } /// `send(fd, buf, flags)`—Writes data to a socket. @@ -160,6 +168,11 @@ pub fn recvfrom( /// This is equivalent to [`recvfrom`], except that it can read into /// uninitialized memory. It returns the slice that was initialized by this /// function and the slice that remains uninitialized. +/// +/// Because this interface returns the length via the returned slice, it's +/// unsable to return the untruncated length that would be returned when the +/// `RecvFlags::TRUNC` flag is used. If you need the untruncated length, use +/// [`recvfrom`]. #[allow(clippy::type_complexity)] #[inline] pub fn recvfrom_uninit( @@ -175,7 +188,10 @@ pub fn recvfrom_uninit( flags, )? }; - let (init, uninit) = unsafe { split_init(buf, length) }; + + // If the `TRUNC` flag is set, the returned `length` may be longer than the + // buffer length. + let (init, uninit) = unsafe { split_init(buf, min(length, buf.len())) }; Ok((init, uninit, addr)) } diff --git a/tests/net/main.rs b/tests/net/main.rs index 2c4b6a312..f2532c6ed 100644 --- a/tests/net/main.rs +++ b/tests/net/main.rs @@ -12,6 +12,8 @@ mod connect_bind_send; mod dgram; #[cfg(feature = "event")] mod poll; +#[cfg(unix)] +mod recv_trunc; mod sockopt; #[cfg(unix)] mod unix; diff --git a/tests/net/recv_trunc.rs b/tests/net/recv_trunc.rs new file mode 100644 index 000000000..c5b10468e --- /dev/null +++ b/tests/net/recv_trunc.rs @@ -0,0 +1,29 @@ +use rustix::net::{AddressFamily, RecvFlags, SendFlags, SocketAddrUnix, SocketType}; +use std::mem::MaybeUninit; + +/// Test `recv_uninit` with the `RecvFlags::Trunc` flag. +#[test] +fn net_recv_uninit_trunc() { + crate::init(); + + let tmpdir = tempfile::tempdir().unwrap(); + let path = tmpdir.path().join("recv_uninit_trunc"); + let name = SocketAddrUnix::new(&path).unwrap(); + + let receiver = rustix::net::socket(AddressFamily::UNIX, SocketType::DGRAM, None).unwrap(); + rustix::net::bind_unix(&receiver, &name).expect("bind"); + + let sender = rustix::net::socket(AddressFamily::UNIX, SocketType::DGRAM, None).unwrap(); + let request = b"Hello, World!!!"; + let n = rustix::net::sendto_unix(&sender, request, SendFlags::empty(), &name).expect("send"); + assert_eq!(n, request.len()); + drop(sender); + + let mut response = [MaybeUninit::::zeroed(); 5]; + let (init, uninit) = + rustix::net::recv_uninit(&receiver, &mut response, RecvFlags::TRUNC).expect("recv_uninit"); + + // We used the `TRUNC` flag, so we should have only gotten 5 bytes. + assert_eq!(init, b"Hello"); + assert!(uninit.is_empty()); +}