Skip to content

Commit 1eeec11

Browse files
committed
Add peek to TcpStream.
1 parent fb48ea4 commit 1eeec11

File tree

3 files changed

+206
-2
lines changed

3 files changed

+206
-2
lines changed

src/net/tcp/split_owned.rs

+13
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,19 @@ impl OwnedReadHalf {
3636
pub fn reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
3737
reunite(self, other)
3838
}
39+
40+
/// Attempt to receive data on the socket, without removing that data from the queue, registering the current task for wakeup if data is not yet available.
41+
pub fn poll_peek(
42+
mut self: Pin<&mut Self>,
43+
cx: &mut Context<'_>,
44+
buf: &mut ReadBuf,
45+
) -> Poll<io::Result<usize>> {
46+
Pin::new(&mut self.inner).poll_peek(cx, buf)
47+
}
48+
49+
pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
50+
self.inner.peek(buf).await
51+
}
3952
}
4053

4154
/// Owned write half of a `TcpStream`, created by `into_split`.

src/net/tcp/stream.rs

+54-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use bytes::{Buf, Bytes};
2+
use std::future::poll_fn;
13
use std::{
24
fmt::Debug,
35
io::{self, Error, Result},
@@ -6,8 +8,6 @@ use std::{
68
sync::Arc,
79
task::{ready, Context, Poll},
810
};
9-
10-
use bytes::{Buf, Bytes};
1111
use tokio::{
1212
io::{AsyncRead, AsyncWrite, ReadBuf},
1313
runtime::Handle,
@@ -165,6 +165,16 @@ impl TcpStream {
165165
pub fn set_nodelay(&self, _nodelay: bool) -> Result<()> {
166166
Ok(())
167167
}
168+
169+
/// Receives data on the socket from the remote address to which it is connected,
170+
/// without removing that data from the queue. On success, returns the number of bytes peeked.
171+
pub async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
172+
self.read_half.peek(buf).await
173+
}
174+
175+
pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<usize>> {
176+
self.read_half.poll_peek(cx, buf)
177+
}
168178
}
169179

170180
pub(crate) struct ReadHalf {
@@ -234,6 +244,48 @@ impl ReadHalf {
234244
Some(avail)
235245
}
236246
}
247+
248+
pub(crate) fn poll_peek(
249+
&mut self,
250+
cx: &mut Context<'_>,
251+
buf: &mut ReadBuf,
252+
) -> Poll<Result<usize>> {
253+
if self.is_closed || buf.capacity() == 0 {
254+
return Poll::Ready(Ok(0));
255+
}
256+
257+
// If we have buffered data, peek from it
258+
if let Some(bytes) = &self.rx.buffer {
259+
let len = std::cmp::min(bytes.len(), buf.remaining());
260+
buf.put_slice(&bytes[..len]);
261+
return Poll::Ready(Ok(len));
262+
}
263+
264+
match ready!(self.rx.recv.poll_recv(cx)) {
265+
Some(seg) => match seg {
266+
SequencedSegment::Data(bytes) => {
267+
let len = std::cmp::min(bytes.len(), buf.remaining());
268+
buf.put_slice(&bytes[..len]);
269+
self.rx.buffer = Some(bytes);
270+
271+
Poll::Ready(Ok(len))
272+
}
273+
SequencedSegment::Fin => {
274+
self.is_closed = true;
275+
Poll::Ready(Ok(0))
276+
}
277+
},
278+
None => Poll::Ready(Err(io::Error::new(
279+
io::ErrorKind::ConnectionReset,
280+
"Connection reset",
281+
))),
282+
}
283+
}
284+
285+
pub(crate) async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
286+
let mut buf = ReadBuf::new(buf);
287+
poll_fn(|cx| self.poll_peek(cx, &mut buf)).await
288+
}
237289
}
238290

239291
impl Debug for ReadHalf {

tests/tcp.rs

+139
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,145 @@ fn split() -> Result {
749749
sim.run()
750750
}
751751

752+
#[test]
753+
fn peek_empty_buffer() -> Result {
754+
let mut sim = Builder::new().build();
755+
756+
sim.client("server", async move {
757+
let listener = bind().await?;
758+
let _ = listener.accept().await?;
759+
Ok(())
760+
});
761+
762+
sim.client("client", async move {
763+
let mut s = TcpStream::connect(("server", PORT)).await?;
764+
765+
// no-op peek with empty buffer
766+
let mut buf = [0; 0];
767+
let n = s.peek(&mut buf).await?;
768+
assert_eq!(0, n);
769+
770+
Ok(())
771+
});
772+
773+
sim.run()
774+
}
775+
776+
#[test]
777+
fn peek_then_read() -> Result {
778+
let mut sim = Builder::new().build();
779+
780+
sim.client("server", async move {
781+
let listener = bind().await?;
782+
let (mut s, _) = listener.accept().await?;
783+
784+
s.write_u64(1234).await?;
785+
Ok(())
786+
});
787+
788+
sim.client("client", async move {
789+
let mut s = TcpStream::connect(("server", PORT)).await?;
790+
791+
// peek full message
792+
let mut peek_buf = [0; 8];
793+
assert_eq!(8, s.peek(&mut peek_buf).await?);
794+
assert_eq!(1234u64, u64::from_be_bytes(peek_buf));
795+
796+
// peek again should see same data
797+
let mut peek_buf2 = [0; 8];
798+
assert_eq!(8, s.peek(&mut peek_buf2).await?);
799+
assert_eq!(1234u64, u64::from_be_bytes(peek_buf2));
800+
801+
// read should consume the data
802+
assert_eq!(1234, s.read_u64().await?);
803+
let mut buf = [0; 8];
804+
assert!(matches!(s.read(&mut buf).await, Ok(0)));
805+
806+
Ok(())
807+
});
808+
809+
sim.run()
810+
}
811+
812+
#[test]
813+
fn peek_partial() -> Result {
814+
let mut sim = Builder::new().build();
815+
816+
sim.client("server", async move {
817+
let listener = bind().await?;
818+
let (mut s, _) = listener.accept().await?;
819+
820+
s.write_all(&[0, 0, 1, 1]).await?;
821+
Ok(())
822+
});
823+
824+
sim.client("client", async move {
825+
let mut s = TcpStream::connect(("server", PORT)).await?;
826+
827+
// peek with smaller buffer
828+
let mut peek_buf = [0; 2];
829+
assert_eq!(2, s.peek(&mut peek_buf).await?);
830+
assert_eq!([0, 0], peek_buf);
831+
832+
// peek with larger buffer should still see all data
833+
let mut peek_buf2 = [0; 4];
834+
assert_eq!(4, s.peek(&mut peek_buf2).await?);
835+
assert_eq!([0, 0, 1, 1], peek_buf2);
836+
837+
// read partial
838+
let mut read_buf = [0; 2];
839+
assert_eq!(2, s.read(&mut read_buf).await?);
840+
assert_eq!([0, 0], read_buf);
841+
842+
// peek remaining
843+
let mut peek_buf3 = [0; 2];
844+
assert_eq!(2, s.peek(&mut peek_buf3).await?);
845+
assert_eq!([1, 1], peek_buf3);
846+
847+
Ok(())
848+
});
849+
850+
sim.run()
851+
}
852+
853+
#[test]
854+
fn peek_multiple_messages() -> Result {
855+
let mut sim = Builder::new().build();
856+
857+
sim.client("server", async move {
858+
let listener = bind().await?;
859+
let (mut s, _) = listener.accept().await?;
860+
861+
s.write_u64(1234).await?;
862+
s.write_u64(5678).await?;
863+
Ok(())
864+
});
865+
866+
sim.client("client", async move {
867+
let mut s = TcpStream::connect(("server", PORT)).await?;
868+
869+
// peek first message
870+
let mut peek_buf = [0; 8];
871+
assert_eq!(8, s.peek(&mut peek_buf).await?);
872+
assert_eq!(1234u64, u64::from_be_bytes(peek_buf));
873+
874+
// read first message
875+
assert_eq!(1234, s.read_u64().await?);
876+
877+
// peek second message
878+
let mut peek_buf2 = [0; 8];
879+
assert_eq!(8, s.peek(&mut peek_buf2).await?);
880+
assert_eq!(5678u64, u64::from_be_bytes(peek_buf2));
881+
882+
// read second message
883+
assert_eq!(5678, s.read_u64().await?);
884+
885+
Ok(())
886+
});
887+
888+
sim.run()
889+
}
890+
752891
// # IpVersion specific tests
753892

754893
#[test]

0 commit comments

Comments
 (0)