Skip to content

Fix recv_uninit to handle the TRUNC flag. #1159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions src/net/send_recv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::net::SocketAddrUnix;
use crate::net::{SocketAddr, SocketAddrAny, SocketAddrV4, SocketAddrV6};
use crate::{backend, io};
use backend::fd::{AsFd, BorrowedFd};
use core::cmp::min;
use core::mem::MaybeUninit;

pub use backend::net::send_recv::{RecvFlags, SendFlags};
Expand Down Expand Up @@ -71,17 +72,24 @@ pub fn recv<Fd: AsFd>(fd: Fd, buf: &mut [u8], flags: RecvFlags) -> io::Result<us
/// This is equivalent to [`recv`], except that it can read into uninitialized
/// memory. It returns the slice that was initialized by this function and the
/// slice that remains uninitialized.
///
/// Because this interface returns the length via the returned slice, it's
/// unsable to return the untruncated length that would be returned when the
/// `RecvFlags::TRUNC` flag is used. If you need the untruncated length, use
/// [`recv`].
#[inline]
pub fn recv_uninit<Fd: AsFd>(
fd: Fd,
buf: &mut [MaybeUninit<u8>],
flags: RecvFlags,
) -> io::Result<(&mut [u8], &mut [MaybeUninit<u8>])> {
let length = unsafe {
backend::net::syscalls::recv(fd.as_fd(), buf.as_mut_ptr().cast::<u8>(), buf.len(), flags)
backend::net::syscalls::recv(fd.as_fd(), buf.as_mut_ptr().cast::<u8>(), buf.len(), flags)?
};

Ok(unsafe { split_init(buf, length?) })
// If the `TRUNC` flag is set, the returned `length` may be longer than the
// buffer length.
Ok(unsafe { split_init(buf, min(length, buf.len())) })
}

/// `send(fd, buf, flags)`—Writes data to a socket.
Expand Down Expand Up @@ -160,6 +168,11 @@ pub fn recvfrom<Fd: AsFd>(
/// This is equivalent to [`recvfrom`], except that it can read into
/// uninitialized memory. It returns the slice that was initialized by this
/// function and the slice that remains uninitialized.
///
/// Because this interface returns the length via the returned slice, it's
/// unsable to return the untruncated length that would be returned when the
/// `RecvFlags::TRUNC` flag is used. If you need the untruncated length, use
/// [`recvfrom`].
#[allow(clippy::type_complexity)]
#[inline]
pub fn recvfrom_uninit<Fd: AsFd>(
Expand All @@ -175,7 +188,10 @@ pub fn recvfrom_uninit<Fd: AsFd>(
flags,
)?
};
let (init, uninit) = unsafe { split_init(buf, length) };

// If the `TRUNC` flag is set, the returned `length` may be longer than the
// buffer length.
let (init, uninit) = unsafe { split_init(buf, min(length, buf.len())) };
Ok((init, uninit, addr))
}

Expand Down
2 changes: 2 additions & 0 deletions tests/net/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ mod connect_bind_send;
mod dgram;
#[cfg(feature = "event")]
mod poll;
#[cfg(unix)]
mod recv_trunc;
mod sockopt;
#[cfg(unix)]
mod unix;
Expand Down
29 changes: 29 additions & 0 deletions tests/net/recv_trunc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use rustix::net::{AddressFamily, RecvFlags, SendFlags, SocketAddrUnix, SocketType};
use std::mem::MaybeUninit;

/// Test `recv_uninit` with the `RecvFlags::Trunc` flag.
#[test]
fn net_recv_uninit_trunc() {
crate::init();

let tmpdir = tempfile::tempdir().unwrap();
let path = tmpdir.path().join("recv_uninit_trunc");
let name = SocketAddrUnix::new(&path).unwrap();

let receiver = rustix::net::socket(AddressFamily::UNIX, SocketType::DGRAM, None).unwrap();
rustix::net::bind_unix(&receiver, &name).expect("bind");

let sender = rustix::net::socket(AddressFamily::UNIX, SocketType::DGRAM, None).unwrap();
let request = b"Hello, World!!!";
let n = rustix::net::sendto_unix(&sender, request, SendFlags::empty(), &name).expect("send");
assert_eq!(n, request.len());
drop(sender);

let mut response = [MaybeUninit::<u8>::zeroed(); 5];
let (init, uninit) =
rustix::net::recv_uninit(&receiver, &mut response, RecvFlags::TRUNC).expect("recv_uninit");

// We used the `TRUNC` flag, so we should have only gotten 5 bytes.
assert_eq!(init, b"Hello");
assert!(uninit.is_empty());
}
Loading