Skip to content

refactor: Remove dyn Any usage in BufferElem #1672

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ take-static = "0.1"
talc = { version = "4" }
time = { version = "0.3", default-features = false }
volatile = "0.6"
zerocopy = { version = "0.8", default-features = false }
zerocopy = { version = "0.8", features = ["derive"], default-features = false }
uhyve-interface = "0.1.3"

[dependencies.smoltcp]
Expand Down
25 changes: 12 additions & 13 deletions src/drivers/fs/virtio_fs.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use alloc::boxed::Box;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use core::str;
Expand All @@ -8,6 +7,7 @@ use virtio::FeatureBits;
use virtio::fs::ConfigVolatileFieldAccess;
use volatile::VolatileRef;
use volatile::access::ReadOnly;
use zerocopy::{FromBytes, Immutable, IntoBytes};

use crate::config::VIRTIO_MAX_QUEUE_SIZE;
use crate::drivers::Driver;
Expand Down Expand Up @@ -158,38 +158,37 @@ impl FuseInterface for VirtioFsDriver {
rsp_payload_len: u32,
) -> Result<fuse::Rsp<O>, VirtqError>
where
<O as fuse::ops::Op>::InStruct: Send,
<O as fuse::ops::Op>::OutStruct: Send,
<O as fuse::ops::Op>::InStruct: Send + IntoBytes + Immutable,
<O as fuse::ops::Op>::OutStruct: Send + FromBytes,
{
let fuse::Cmd {
headers: cmd_headers,
payload: cmd_payload_opt,
} = cmd;
let send = if let Some(cmd_payload) = cmd_payload_opt {
vec![
BufferElem::Sized(cmd_headers),
BufferElem::Vector(cmd_payload),
]
vec![BufferElem::from(cmd_headers), BufferElem(cmd_payload)]
} else {
vec![BufferElem::Sized(cmd_headers)]
vec![BufferElem::from(cmd_headers)]
};

let rsp_headers = Box::<RspHeader<O>, _>::new_uninit_in(DeviceAlloc);
let recv = if rsp_payload_len == 0 {
vec![BufferElem::Sized(rsp_headers)]
vec![BufferElem::new_uninit::<RspHeader<O>>()]
} else {
let rsp_payload = Vec::with_capacity_in(rsp_payload_len as usize, DeviceAlloc);
vec![
BufferElem::Sized(rsp_headers),
BufferElem::Vector(rsp_payload),
BufferElem::new_uninit::<RspHeader<O>>(),
BufferElem(rsp_payload),
]
};

let buffer_tkn = AvailBufferToken::new(send, recv).unwrap();
let mut transfer_result =
self.vqueues[1].dispatch_blocking(buffer_tkn, BufferType::Direct)?;

let headers = transfer_result.used_recv_buff.pop_front_downcast().unwrap();
let headers = transfer_result
.used_recv_buff
.pop_front_deserialize()
.unwrap();
let payload = transfer_result.used_recv_buff.pop_front_vec();
Ok(Rsp { headers, payload })
}
Expand Down
19 changes: 8 additions & 11 deletions src/drivers/net/virtio/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ cfg_if::cfg_if! {
}
}

use alloc::boxed::Box;
use alloc::vec::Vec;

use smoltcp::phy::{Checksum, ChecksumCapabilities};
Expand Down Expand Up @@ -116,8 +115,8 @@ fn fill_queue(vq: &mut dyn Virtq, num_packets: u16, packet_size: u32) {
let buff_tkn = match AvailBufferToken::new(
vec![],
vec![
BufferElem::Sized(Box::<Hdr, _>::new_uninit_in(DeviceAlloc)),
BufferElem::Vector(Vec::with_capacity_in(
BufferElem::new_uninit::<Hdr>(),
BufferElem(Vec::with_capacity_in(
packet_size.try_into().unwrap(),
DeviceAlloc,
)),
Expand Down Expand Up @@ -256,7 +255,7 @@ impl NetworkDriver for VirtioNetDriver {
result
};

let mut header = Box::new_in(<Hdr as Default>::default(), DeviceAlloc);
let mut header = Hdr::default();
// If a checksum isn't necessary, we have inform the host within the header
// see Virtio specification 5.1.6.2
if !self.checksums.tcp.tx() || !self.checksums.udp.tx() {
Expand Down Expand Up @@ -291,11 +290,9 @@ impl NetworkDriver for VirtioNetDriver {
.into();
}

let buff_tkn = AvailBufferToken::new(
vec![BufferElem::Sized(header), BufferElem::Vector(packet)],
vec![],
)
.unwrap();
let buff_tkn =
AvailBufferToken::new(vec![BufferElem::from(header), BufferElem(packet)], vec![])
.unwrap();

self.send_vqs.vqs[0]
.dispatch(buff_tkn, false, BufferType::Direct)
Expand All @@ -309,7 +306,7 @@ impl NetworkDriver for VirtioNetDriver {
RxQueues::post_processing(&mut buffer_tkn)
.inspect_err(|vnet_err| warn!("Post processing failed. Err: {vnet_err:?}"))
.ok()?;
let first_header = buffer_tkn.used_recv_buff.pop_front_downcast::<Hdr>()?;
let first_header = buffer_tkn.used_recv_buff.pop_front_deserialize::<Hdr>()?;
let first_packet = buffer_tkn.used_recv_buff.pop_front_vec()?;
trace!("Header: {first_header:?}");

Expand All @@ -329,7 +326,7 @@ impl NetworkDriver for VirtioNetDriver {
RxQueues::post_processing(&mut buffer_tkn)
.inspect_err(|vnet_err| warn!("Post processing failed. Err: {vnet_err:?}"))
.ok()?;
let _header = buffer_tkn.used_recv_buff.pop_front_downcast::<Hdr>()?;
let _header = buffer_tkn.used_recv_buff.pop_front_deserialize::<Hdr>()?;
let packet = buffer_tkn.used_recv_buff.pop_front_vec()?;
packets.push(packet);
}
Expand Down
128 changes: 64 additions & 64 deletions src/drivers/virtio/virtqueue/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@ pub mod split;
use alloc::boxed::Box;
use alloc::collections::vec_deque::VecDeque;
use alloc::vec::Vec;
use core::any::Any;
use core::mem::MaybeUninit;
use core::{mem, ptr};
use core::mem;

use memory_addresses::VirtAddr;
use virtio::{le32, le64, pvirtq, virtq};
use zerocopy::{Immutable, IntoBytes};

use self::error::VirtqError;
#[cfg(not(feature = "pci"))]
Expand Down Expand Up @@ -265,7 +264,7 @@ trait VirtqPrivate {
.chain(recv_desc_iter)
.map(|(mem_descr, len, incomplete_flags)| {
Self::Descriptor::incomplete_desc(
paging::virt_to_phys(VirtAddr::from_ptr(mem_descr.addr()))
paging::virt_to_phys(VirtAddr::from_ptr(mem_descr.as_ptr()))
.as_u64()
.into(),
len.into(),
Expand Down Expand Up @@ -344,41 +343,49 @@ impl<Descriptor> TransferToken<Descriptor> {
}

#[derive(Debug)]
pub enum BufferElem {
Sized(Box<dyn Any + Send, DeviceAlloc>),
Vector(Vec<u8, DeviceAlloc>),
}
pub struct BufferElem(pub Vec<u8, DeviceAlloc>);

impl BufferElem {
// Returns the initialized length of the element. Assumes [Self::Sized] to
// be initialized, since the type of the object is erased and we cannot
// detect if the content is actually a [MaybeUninit]. However, this function
// should be only relevant for read buffer elements, which should not be uninit.
// If the element belongs to a write buffer, it is likely that [Self::capacity]
// is more appropriate.
/// Returns the initialized length of the element.
pub fn len(&self) -> u32 {
match self {
BufferElem::Sized(sized) => mem::size_of_val(sized.as_ref()),
BufferElem::Vector(vec) => vec.len(),
}
.try_into()
.unwrap()
self.0.len().try_into().unwrap()
}

/// Returns the allocated capacity of the element.
pub fn capacity(&self) -> u32 {
match self {
BufferElem::Sized(sized) => mem::size_of_val(sized.as_ref()),
BufferElem::Vector(vec) => vec.capacity(),
}
.try_into()
.unwrap()
self.0.capacity().try_into().unwrap()
}

pub fn addr(&self) -> *const u8 {
match self {
BufferElem::Sized(sized) => ptr::from_ref(sized.as_ref()).cast::<u8>(),
BufferElem::Vector(vec) => vec.as_ptr(),
}
/// Returns a pointer to the buffer.
pub fn as_ptr(&self) -> *const u8 {
self.0.as_ptr()
}

/// Helper method to create a [`BufferElem`] that pre-allocates capacity
/// for a given element of type `T`. This ensures the buffer is aligned
/// to the same boundaries as `T`.
pub fn new_uninit<T>() -> Self {
let uninit_mem = Box::<T, _>::new_uninit_in(DeviceAlloc);
// SAFETY: Length is 0 because it's uninit, capacity matches the memory amount that the Box allocated.
// The pointer was allocated with the same allocator: DeviceAlloc.
let uninit_vec = unsafe {
Vec::from_raw_parts_in(
Box::into_raw(uninit_mem).cast(),
0,
size_of::<T>(),
DeviceAlloc,
)
};
Self(uninit_vec)
}
}

impl<T> From<T> for BufferElem
where
T: IntoBytes + Immutable,
{
fn from(value: T) -> Self {
Self(value.as_bytes().to_vec_in(DeviceAlloc))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This introduces a copy that we were previously able to avoid.

}
}

Expand Down Expand Up @@ -419,46 +426,39 @@ pub(crate) struct UsedDeviceWritableBuffer {
}

impl UsedDeviceWritableBuffer {
pub fn pop_front_downcast<T>(&mut self) -> Option<Box<T, DeviceAlloc>>
where
T: Any,
{
pub fn pop_front_deserialize<T>(&mut self) -> Option<Box<T, DeviceAlloc>> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we are giving up some type safety here. As long as the alignment matches, we can cast any buffer into a box of any type (or a vector in the case of pop_front_vec), right?

Copy link
Member Author

@Gelbpunkt Gelbpunkt Apr 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We didn't really have any "type safety" before. The behaviour change is that previously, it was e.g. possible to have [BufferElem::Sized, BufferElem::Vector] (pseudo-code) and then calling pop_front_vector would return None, while now it won't complain and instead happily return you a Vec<u8>. That isn't necessarily wrong and all conversions here are safe. Deserializing into the required type by the caller is left to the caller, we only ensure that the deserialization is sound.

If you had the order wrong before, e.g. receive_packet would return None, which is arguably wrong since it would not really detect that there was an issue in the code. Now, it lets you treat the buffers as what they are (buffers), and you cannot really "mishandle" it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, the downcasting operation would check the variant (and the type ID in the case of the Sized variant) of the box during downcasting and would error if a dynamic box of incorrect type was provided (The vectors are a special case as Rust did not [and probably still does not] allow us to handle Box<[u8]>s as Box<dyn Any>).

I think it is better to catch early when we are using a buffer of incorrect type. For example, we could forget to pop the header of a read operation and try to pop a vector from it directly. Currently, this will result in an error as the read buffer should be of the Vector variant but the first buffer is BufferElement::Sized. With the proposed changes, a vector also containing the VIRTIO header will be returned. If that buffer is passed on to the network stack, for example, the error will happen there as what should only be the frame will contain the VIRTIO header content that is not expected and cannot be handled by the network stack. Errors in the networking stack are particularly annoying as transmission errors are expected and result in repeated attempts instead of outright errors.

Another scenario would be the queue returning the incorrect transfer to the caller because of a buffer ID mix-up. The type ID check can allow us to catch early when we are given the result of an open operation, for example, when we were expecting a write result. Because the header type IDs are different, the downcast would return None.

Of course, such hypotheticals depend on us making mistakes in our code but we do make mistakes and early error catching precautions make debugging much easier.

if self.remaining_written_len < u32::try_from(size_of::<T>()).unwrap() {
return None;
}

let elem = self.elems.pop_front()?;
if let BufferElem::Sized(sized) = elem {
match sized.downcast::<MaybeUninit<T>>() {
Ok(cast) => {
self.remaining_written_len -= u32::try_from(size_of::<T>()).unwrap();
Some(unsafe { cast.assume_init() })
}
Err(sized) => {
self.elems.push_front(BufferElem::Sized(sized));
None
}
}
} else {
self.elems.push_front(elem);
None
}
let BufferElem(buf) = self.elems.pop_front()?;
self.remaining_written_len -= u32::try_from(size_of::<T>()).unwrap();

// Ensure the buffer is aligned to T. This is the case if it was created via
// [`BufferElem::new_uninit::<T>`] and should always be the case, but since
// it is technically possible to construct an unaligned buffer and use that,
// we should check it.
assert!(
buf.as_ptr().addr() % align_of::<T>() == 0,
"Attempted to deserialize buffer as type with different alignment"
);

// SAFETY: Management of the memory is transferred from the Vec to the Box
// Both heap allocations were made with the same alloc: DeviceAlloc
// The alignment was checked manually before.
Some(unsafe { Box::from_raw_in(buf.into_raw_parts().0.cast(), DeviceAlloc) })
}

pub fn pop_front_vec(&mut self) -> Option<Vec<u8, DeviceAlloc>> {
let elem = self.elems.pop_front()?;
if let BufferElem::Vector(mut vector) = elem {
let new_len = u32::min(
vector.capacity().try_into().unwrap(),
self.remaining_written_len,
);
self.remaining_written_len -= new_len;
unsafe { vector.set_len(new_len.try_into().unwrap()) };
Some(vector)
} else {
self.elems.push_front(elem);
None
}
let BufferElem(mut vector) = self.elems.pop_front()?;
let new_len = u32::min(
vector.capacity().try_into().unwrap(),
self.remaining_written_len,
);
self.remaining_written_len -= new_len;
unsafe { vector.set_len(new_len.try_into().unwrap()) };

Some(vector)
}
}

Expand Down
9 changes: 4 additions & 5 deletions src/drivers/vsock/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#[cfg(feature = "pci")]
pub mod pci;

use alloc::boxed::Box;
use alloc::vec::Vec;
use core::mem;

Expand All @@ -29,8 +28,8 @@ fn fill_queue(vq: &mut dyn Virtq, num_packets: u16, packet_size: u32) {
let buff_tkn = match AvailBufferToken::new(
vec![],
vec![
BufferElem::Sized(Box::<Hdr, _>::new_uninit_in(DeviceAlloc)),
BufferElem::Vector(Vec::with_capacity_in(
BufferElem::new_uninit::<Hdr>(),
BufferElem(Vec::with_capacity_in(
packet_size.try_into().unwrap(),
DeviceAlloc,
)),
Expand Down Expand Up @@ -99,7 +98,7 @@ impl RxQueue {
while let Some(mut buffer_tkn) = self.get_next() {
let header = buffer_tkn
.used_recv_buff
.pop_front_downcast::<Hdr>()
.pop_front_deserialize::<Hdr>()
.unwrap();
let packet = buffer_tkn.used_recv_buff.pop_front_vec().unwrap();

Expand Down Expand Up @@ -170,7 +169,7 @@ impl TxQueue {
result
};

let buff_tkn = AvailBufferToken::new(vec![BufferElem::Vector(packet)], vec![]).unwrap();
let buff_tkn = AvailBufferToken::new(vec![BufferElem(packet)], vec![]).unwrap();

vq.dispatch(buff_tkn, false, BufferType::Direct).unwrap();

Expand Down
Loading
Loading