Skip to content

Commit d1ddfc8

Browse files
committed
Add support for Packet MMAP
1 parent 4ed2ea2 commit d1ddfc8

File tree

15 files changed

+1260
-14
lines changed

15 files changed

+1260
-14
lines changed

Cargo.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ once_cell = { version = "1.5.2", optional = true }
3636
# libc backend can be selected via adding `--cfg=rustix_use_libc` to
3737
# `RUSTFLAGS` or enabling the `use-libc` cargo feature.
3838
[target.'cfg(all(not(rustix_use_libc), not(miri), target_os = "linux", target_endian = "little", any(target_arch = "arm", all(target_arch = "aarch64", target_pointer_width = "64"), target_arch = "riscv64", all(rustix_use_experimental_asm, target_arch = "powerpc64"), all(rustix_use_experimental_asm, target_arch = "mips"), all(rustix_use_experimental_asm, target_arch = "mips32r6"), all(rustix_use_experimental_asm, target_arch = "mips64"), all(rustix_use_experimental_asm, target_arch = "mips64r6"), target_arch = "x86", all(target_arch = "x86_64", target_pointer_width = "64"))))'.dependencies]
39-
linux-raw-sys = { version = "0.4.12", default-features = false, features = ["general", "errno", "ioctl", "no_std", "elf"] }
39+
linux-raw-sys = { version = "0.6.4", default-features = false, features = ["general", "errno", "ioctl", "no_std", "elf"] }
4040
libc_errno = { package = "errno", version = "0.3.8", default-features = false, optional = true }
4141
libc = { version = "0.2.152", default-features = false, features = ["extra_traits"], optional = true }
4242

@@ -53,7 +53,7 @@ libc = { version = "0.2.152", default-features = false, features = ["extra_trait
5353
# Some syscalls do not have libc wrappers, such as in `io_uring`. For these,
5454
# the libc backend uses the linux-raw-sys ABI and `libc::syscall`.
5555
[target.'cfg(all(any(target_os = "android", target_os = "linux"), any(rustix_use_libc, miri, not(all(target_os = "linux", target_endian = "little", any(target_arch = "arm", all(target_arch = "aarch64", target_pointer_width = "64"), target_arch = "riscv64", all(rustix_use_experimental_asm, target_arch = "powerpc64"), all(rustix_use_experimental_asm, target_arch = "mips"), all(rustix_use_experimental_asm, target_arch = "mips32r6"), all(rustix_use_experimental_asm, target_arch = "mips64"), all(rustix_use_experimental_asm, target_arch = "mips64r6"), target_arch = "x86", all(target_arch = "x86_64", target_pointer_width = "64")))))))'.dependencies]
56-
linux-raw-sys = { version = "0.4.12", default-features = false, features = ["general", "ioctl", "no_std"] }
56+
linux-raw-sys = { version = "0.6.4", default-features = false, features = ["general", "ioctl", "no_std"] }
5757

5858
# For the libc backend on Windows, use the Winsock API in windows-sys.
5959
[target.'cfg(windows)'.dependencies.windows-sys]
@@ -141,7 +141,7 @@ io_uring = ["event", "fs", "net", "linux-raw-sys/io_uring"]
141141
mount = []
142142

143143
# Enable `rustix::net::*`.
144-
net = ["linux-raw-sys/net", "linux-raw-sys/netlink", "linux-raw-sys/if_ether", "linux-raw-sys/xdp"]
144+
net = ["linux-raw-sys/net", "linux-raw-sys/netlink", "linux-raw-sys/if_ether", "linux-raw-sys/if_packet", "linux-raw-sys/xdp"]
145145

146146
# Enable `rustix::thread::*`.
147147
thread = ["linux-raw-sys/prctl"]

examples/packet/inner.rs

Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,366 @@
1+
use rustix::event::{poll, PollFd, PollFlags};
2+
use rustix::fd::OwnedFd;
3+
use rustix::mm::{mmap, munmap, MapFlags, ProtFlags};
4+
use rustix::net::{
5+
bind_link, eth,
6+
netdevice::name_to_index,
7+
packet::{PacketHeader2, PacketReq, PacketReqAny, PacketStatus, SocketAddrLink},
8+
send, socket_with,
9+
sockopt::{set_packet_rx_ring, set_packet_tx_ring, set_packet_version, PacketVersion},
10+
AddressFamily, SendFlags, SocketFlags, SocketType,
11+
};
12+
use std::{cell::Cell, collections::VecDeque, env, ffi::c_void, io, ptr, slice, str};
13+
14+
#[derive(Debug)]
15+
pub struct Socket {
16+
fd: OwnedFd,
17+
block_size: usize,
18+
block_count: usize,
19+
frame_size: usize,
20+
frame_count: usize,
21+
rx: Cell<*mut c_void>,
22+
tx: Cell<*mut c_void>,
23+
}
24+
25+
impl Socket {
26+
fn new(
27+
name: &str,
28+
block_size: usize,
29+
block_count: usize,
30+
frame_size: usize,
31+
) -> io::Result<Self> {
32+
let family = AddressFamily::PACKET;
33+
let type_ = SocketType::RAW;
34+
let flags = SocketFlags::empty();
35+
let fd = socket_with(family, type_, flags, None)?;
36+
37+
let index = name_to_index(&fd, name)?;
38+
39+
set_packet_version(&fd, PacketVersion::V2)?;
40+
41+
let frame_count = (block_size * block_count) / frame_size;
42+
let req = PacketReq {
43+
block_size: block_size as u32,
44+
block_nr: block_count as u32,
45+
frame_size: frame_size as u32,
46+
frame_nr: frame_count as u32,
47+
};
48+
49+
let req = PacketReqAny::V2(req);
50+
set_packet_rx_ring(&fd, &req)?;
51+
set_packet_tx_ring(&fd, &req)?;
52+
53+
let addr = SocketAddrLink::new(eth::ALL, index);
54+
bind_link(&fd, &addr)?;
55+
56+
let rx = unsafe {
57+
mmap(
58+
ptr::null_mut(),
59+
block_size * block_count * 2,
60+
ProtFlags::READ | ProtFlags::WRITE,
61+
MapFlags::SHARED,
62+
&fd,
63+
0,
64+
)
65+
}?;
66+
let tx = unsafe { rx.add(block_size * block_count) };
67+
68+
Ok(Self {
69+
fd,
70+
block_size,
71+
block_count,
72+
frame_size,
73+
frame_count,
74+
rx: Cell::new(rx),
75+
tx: Cell::new(tx),
76+
})
77+
}
78+
79+
/// Returns a reader object for receiving packets.
80+
pub fn reader(&self) -> Reader<'_> {
81+
assert!(!self.rx.get().is_null());
82+
Reader {
83+
socket: self,
84+
// Take ring pointer.
85+
ring: self.rx.replace(ptr::null_mut()),
86+
}
87+
}
88+
89+
/// Returns a writer object for transmitting packets.
90+
pub fn writer(&self) -> Writer<'_> {
91+
assert!(!self.tx.get().is_null());
92+
Writer {
93+
socket: self,
94+
// Take ring pointer.
95+
ring: self.tx.replace(ptr::null_mut()),
96+
}
97+
}
98+
99+
/// Flushes the transmit buffer.
100+
pub fn flush(&self) -> io::Result<()> {
101+
send(&self.fd, &[], SendFlags::empty())?;
102+
Ok(())
103+
}
104+
}
105+
106+
impl Drop for Socket {
107+
fn drop(&mut self) {
108+
debug_assert!(!self.rx.get().is_null());
109+
debug_assert!(!self.tx.get().is_null());
110+
unsafe {
111+
let _ = munmap(self.rx.get(), self.block_size * self.block_count * 2);
112+
}
113+
}
114+
}
115+
116+
/// TODO
117+
#[derive(Debug)]
118+
pub struct Packet<'r> {
119+
header: &'r mut PacketHeader2,
120+
}
121+
122+
impl<'r> Packet<'r> {
123+
pub fn payload(&self) -> &[u8] {
124+
let ptr = self.header.payload_rx();
125+
let len = self.header.len as usize;
126+
unsafe { slice::from_raw_parts(ptr, len) }
127+
}
128+
}
129+
130+
impl<'r> Drop for Packet<'r> {
131+
fn drop(&mut self) {
132+
self.header.status = PacketStatus::empty();
133+
}
134+
}
135+
136+
/// TODO
137+
#[derive(Debug)]
138+
pub struct Slot<'w> {
139+
header: &'w mut PacketHeader2,
140+
}
141+
142+
impl<'w> Slot<'w> {
143+
pub fn write(&mut self, payload: &[u8]) {
144+
let ptr = self.header.payload_tx();
145+
// TODO verify length
146+
let len = payload.len();
147+
unsafe {
148+
ptr.copy_from_nonoverlapping(payload.as_ptr(), len);
149+
self.header.len = len as u32;
150+
}
151+
}
152+
}
153+
154+
impl<'w> Drop for Slot<'w> {
155+
fn drop(&mut self) {
156+
self.header.status = PacketStatus::SEND_REQUEST;
157+
}
158+
}
159+
160+
/// A reader object for receiving packets.
161+
#[derive(Debug)]
162+
pub struct Reader<'s> {
163+
socket: &'s Socket,
164+
ring: *mut c_void, // Owned
165+
}
166+
167+
impl<'s> Reader<'s> {
168+
/// Returns an iterator over received packets.
169+
/// The iterator blocks until at least one packet is received.
170+
///
171+
/// # Lifetimes
172+
///
173+
/// - `'s`: The lifetime of the socket.
174+
/// - `'r`: The lifetime of the received packets.
175+
pub fn wait<'r>(&'r mut self) -> io::Result<ReadIter<'s, 'r>>
176+
where
177+
's: 'r,
178+
{
179+
let flags = PollFlags::IN | PollFlags::RDNORM | PollFlags::ERR;
180+
let pfd = PollFd::new(&self.socket.fd, flags);
181+
let pfd = &mut [pfd];
182+
let n = poll(pfd, -1)?;
183+
assert_eq!(n, 1);
184+
Ok(ReadIter {
185+
reader: self,
186+
index: 0,
187+
})
188+
}
189+
}
190+
191+
impl<'s> Drop for Reader<'s> {
192+
fn drop(&mut self) {
193+
// Give back ring pointer.
194+
self.socket.rx.set(self.ring);
195+
}
196+
}
197+
198+
/// A writer object for transmitting packets.
199+
#[derive(Debug)]
200+
pub struct Writer<'s> {
201+
socket: &'s Socket,
202+
ring: *mut c_void, // Owned
203+
}
204+
205+
impl<'s> Writer<'s> {
206+
/// Returns an iterator over available slots for transmitting packets.
207+
/// The iterator blocks until at least one slot is available.
208+
///
209+
/// # Lifetimes
210+
///
211+
/// - `'s`: The lifetime of the socket.
212+
/// - `'w`: The lifetime of the slots.
213+
pub fn wait<'w>(&'w mut self) -> io::Result<WriteIter<'s, 'w>>
214+
where
215+
's: 'w,
216+
{
217+
let flags = PollFlags::OUT | PollFlags::WRNORM | PollFlags::ERR;
218+
let pfd = PollFd::new(&self.socket.fd, flags);
219+
let pfd = &mut [pfd];
220+
let n = poll(pfd, -1)?;
221+
assert_eq!(n, 1);
222+
Ok(WriteIter {
223+
writer: self,
224+
index: 0,
225+
})
226+
}
227+
}
228+
229+
impl<'s> Drop for Writer<'s> {
230+
fn drop(&mut self) {
231+
// Give back ring pointer.
232+
self.socket.tx.set(self.ring);
233+
}
234+
}
235+
236+
/// An iterator over received packets.
237+
#[derive(Debug)]
238+
pub struct ReadIter<'s, 'r> {
239+
reader: &'r mut Reader<'s>,
240+
index: usize,
241+
}
242+
243+
impl<'s, 'r> Iterator for ReadIter<'s, 'r> {
244+
type Item = Packet<'r>;
245+
246+
fn next(&mut self) -> Option<Self::Item> {
247+
while self.index < self.reader.socket.frame_count {
248+
let base = unsafe {
249+
self.reader
250+
.ring
251+
.add(self.index * self.reader.socket.frame_size)
252+
};
253+
self.index += 1;
254+
255+
if let Some(header) = unsafe { PacketHeader2::from_rx_ptr(base) } {
256+
return Some(Packet { header });
257+
}
258+
}
259+
None
260+
}
261+
}
262+
263+
/// An iterator over available slots for transmitting packets.
264+
#[derive(Debug)]
265+
pub struct WriteIter<'s, 'w> {
266+
writer: &'w mut Writer<'s>,
267+
index: usize,
268+
}
269+
270+
impl<'s, 'w> Iterator for WriteIter<'s, 'w> {
271+
type Item = Slot<'w>;
272+
273+
fn next(&mut self) -> Option<Self::Item> {
274+
while self.index < self.writer.socket.frame_count {
275+
let base = unsafe {
276+
self.writer
277+
.ring
278+
.add(self.index * self.writer.socket.frame_size)
279+
};
280+
self.index += 1;
281+
282+
if let Some(header) = unsafe { PacketHeader2::from_tx_ptr(base) } {
283+
return Some(Slot { header });
284+
}
285+
}
286+
None
287+
}
288+
}
289+
290+
// ECHO server
291+
fn server(socket: Socket, mut count: usize) -> io::Result<()> {
292+
let mut reader = socket.reader();
293+
let mut writer = socket.writer();
294+
295+
while count > 0 {
296+
let mut queue = VecDeque::new();
297+
298+
for packet in reader.wait()? {
299+
queue.push_back(packet);
300+
}
301+
302+
while let Some(packet) = queue.pop_front() {
303+
let mut iter = writer.wait()?.take(count);
304+
while let Some(mut slot) = iter.next() {
305+
let mut payload = packet.payload().to_vec();
306+
assert_eq!(payload[12..14], [0x08, 0x00]);
307+
payload.swap(14, 15);
308+
309+
slot.write(&payload);
310+
drop(slot);
311+
count -= 1;
312+
}
313+
drop(packet);
314+
}
315+
316+
socket.flush()?;
317+
}
318+
319+
Ok(())
320+
}
321+
322+
// ECHO client
323+
fn client(socket: Socket, mut count: usize) -> io::Result<()> {
324+
let mut reader = socket.reader();
325+
let mut writer = socket.writer();
326+
327+
while count > 0 {
328+
let mut iter = writer.wait()?.take(count);
329+
while let Some(mut slot) = iter.next() {
330+
let payload = &[
331+
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, // Destination
332+
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Source
333+
0x08, 0x00, // Type (IPv4, but not really)
334+
0x13, 0x37, // Payload (some value)
335+
];
336+
337+
slot.write(payload);
338+
drop(slot);
339+
count -= 1;
340+
}
341+
342+
socket.flush()?;
343+
344+
for packet in reader.wait()? {
345+
assert_eq!(packet.payload()[14..16], [0x37, 0x13]);
346+
}
347+
}
348+
349+
Ok(())
350+
}
351+
352+
pub fn main() -> io::Result<()> {
353+
let mut args = env::args().skip(1);
354+
let name = args.next().expect("name");
355+
let mode = args.next().expect("mode");
356+
let count = args.next().expect("count");
357+
358+
let socket = Socket::new(&name, 4096, 4, 2048)?;
359+
let count = count.parse().unwrap();
360+
361+
match mode.as_str() {
362+
"server" => server(socket, count),
363+
"client" => client(socket, count),
364+
_ => panic!("invalid mode"),
365+
}
366+
}

0 commit comments

Comments
 (0)