Skip to content

Commit 2d86dbe

Browse files
committed
Support nonblocking reads,
1 parent 29a5509 commit 2d86dbe

File tree

3 files changed

+99
-47
lines changed

3 files changed

+99
-47
lines changed

src/receiver.rs

+43-7
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ pub struct Receiver<T: DeserializeOwned, E: Endian, R: Read = BufReader<TcpStrea
1616
config: Config,
1717
max_size: usize,
1818
_marker: PhantomData<(T, E)>,
19+
20+
// This buffer is used for storing the currently read bytes in case the stream is nonblocking.
21+
// Otherwise, bincode would deserialize only the currently read bytes.
22+
buffer: Vec<u8>,
23+
24+
bytes_read: usize,
25+
bytes_to_read: usize,
1926
}
2027

2128
/// A more convenient way of initializing receivers.
@@ -83,6 +90,9 @@ impl<T: DeserializeOwned, R: Read, E: Endian> TypedReceiverBuilder<T, R, E> {
8390
reader,
8491
config: E::config(),
8592
max_size: self.max_size,
93+
buffer: Vec::new(),
94+
bytes_read: 0,
95+
bytes_to_read: 0,
8696
}
8797
}
8898
}
@@ -98,6 +108,9 @@ impl<T: DeserializeOwned, E: Endian> TypedReceiverBuilder<T, BufReader<TcpStream
98108
_marker: PhantomData,
99109
reader: BufReader::new(stream),
100110
max_size: self.max_size,
111+
buffer: Vec::new(),
112+
bytes_read: 0,
113+
bytes_to_read: 0,
101114
})
102115
}
103116
}
@@ -107,13 +120,15 @@ impl<T: DeserializeOwned, E: Endian> TypedReceiverBuilder<T, TcpStream, E> {
107120
let listener = TcpListener::bind(address)?;
108121

109122
let (stream, _) = listener.accept()?;
110-
stream.set_nodelay(true)?;
111123

112124
Ok(Receiver {
113125
config: E::config(),
114126
_marker: PhantomData,
115127
reader: stream,
116128
max_size: self.max_size,
129+
buffer: Vec::new(),
130+
bytes_read: 0,
131+
bytes_to_read: 0,
117132
})
118133
}
119134
}
@@ -122,12 +137,33 @@ impl<T: DeserializeOwned, E: Endian, R: Read> ChannelRecv<T> for Receiver<T, E,
122137
type Error = RecvError;
123138

124139
fn recv(&mut self) -> Result<T, RecvError> {
125-
let length = self.reader.read_u64::<E>()? as usize;
126-
if length > self.max_size {
127-
return Err(RecvError::TooLarge(length))
140+
if self.bytes_to_read == 0 {
141+
let length = self.reader.read_u64::<E>()? as usize;
142+
if length > self.max_size {
143+
return Err(RecvError::TooLarge(length))
144+
}
145+
146+
if self.buffer.len() < length {
147+
self.buffer.extend(std::iter::repeat(0).take(length - self.buffer.len()));
148+
}
149+
150+
self.bytes_to_read = length;
151+
self.bytes_read = 0;
128152
}
129-
let mut buffer = vec! [0; length];
130-
self.reader.read_exact(&mut buffer)?;
131-
Ok(self.config.deserialize(&buffer)?)
153+
154+
loop {
155+
match self.reader.read(&mut self.buffer[self.bytes_read..self.bytes_to_read]) {
156+
Ok(size) => {
157+
self.bytes_read += size;
158+
if self.bytes_read >= self.bytes_to_read {
159+
let length = self.bytes_to_read;
160+
self.bytes_to_read = 0;
161+
return Ok(self.config.deserialize(&self.buffer[0..length])?)
162+
}
163+
},
164+
Err(error) => return Err(error.into()),
165+
}
166+
}
167+
132168
}
133169
}

tests/blob.rs

+35-24
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@ extern crate serde;
55

66
use std::any::Any;
77
use std::io::{BufReader, BufWriter, ErrorKind as IoErrorKind};
8+
use std::io::prelude::*;
89
use std::net::{TcpListener, TcpStream};
910
use std::thread::JoinHandle;
1011

1112
use rand::{FromEntropy, RngCore, rngs::SmallRng};
12-
use tcp_channel::{SenderBuilder, ReceiverBuilder, ChannelSend, ChannelRecv, BigEndian, DEFAULT_MAX_SIZE};
13+
use serde::de::DeserializeOwned;
14+
use tcp_channel::{SenderBuilder, ReceiverBuilder, ChannelSend, ChannelRecv, BigEndian, Receiver as TcpReceiver, RecvError, DEFAULT_MAX_SIZE};
1315

14-
// This emulates a real TCP connection.
16+
// This emulates a real delayed TCP connection.
1517
mod slow_io;
1618
use slow_io::{SlowReader, SlowWriter};
1719

@@ -50,6 +52,18 @@ quick_error! {
5052
}
5153
}
5254

55+
fn pretend_blocking_read<T: DeserializeOwned, R: Read>(receiver: &mut TcpReceiver<T, BigEndian, R>) -> Result<T, RecvError> {
56+
loop {
57+
match receiver.recv() {
58+
Ok(value) => return Ok(value),
59+
Err(RecvError::IoError(ioerror)) => match ioerror.kind() {
60+
IoErrorKind::WouldBlock => continue,
61+
_ => return Err(RecvError::IoError(ioerror).into()),
62+
}
63+
Err(error) => return Err(error.into()),
64+
}
65+
}
66+
}
5367
fn blob(slow: bool, blocking: bool, max_size: usize) -> Result<(), Error> {
5468
const SIZE: usize = 262_144;
5569
// This test generates a random 256KiB BLOB, sends it, and then receives the BLOB, where every byte is
@@ -72,46 +86,46 @@ fn blob(slow: bool, blocking: bool, max_size: usize) -> Result<(), Error> {
7286
break Err(ioerror)
7387
}
7488
}
75-
}?;
89+
}.unwrap();
7690

77-
sender.send(port)?;
78-
let (stream, _) = listener.accept()?;
91+
sender.send(port).unwrap();
92+
let (stream, _) = listener.accept().unwrap();
7993

8094
let mut receiver = ReceiverBuilder::buffered()
8195
.with_type::<Request>()
8296
.with_endianness::<BigEndian>()
8397
.with_reader::<BufReader<SlowReader<TcpStream>>>()
8498
.with_max_size(max_size)
85-
.build(BufReader::new(SlowReader::new(stream.try_clone()?, slow, blocking)));
99+
.build(BufReader::new(SlowReader::new(stream.try_clone().unwrap(), slow, blocking)));
86100

87101
let mut sender = SenderBuilder::buffered()
88102
.with_type::<Response>()
89103
.with_endianness::<BigEndian>()
90104
.with_writer::<BufWriter<SlowWriter<TcpStream>>>()
91-
.build(BufWriter::new(SlowWriter::new(stream, slow, blocking)));
105+
.build(BufWriter::new(SlowWriter::new(stream, slow, true)));
92106

93-
while let Ok(command) = receiver.recv() {
107+
while let Ok(command) = pretend_blocking_read(&mut receiver) {
94108
match command {
95109
Request::SendBlob(mut blob) => {
96110
for byte in blob.iter_mut() {
97111
*byte = byte.wrapping_add(1)
98112
}
99-
sender.send(&Response::Respond(blob))?;
100-
sender.flush()?;
113+
sender.send(&Response::Respond(blob)).unwrap();
114+
sender.flush().unwrap();
101115
},
102116
Request::Stop => return Ok(())
103117
}
104118
}
105119

106120
Ok(())
107121
});
108-
let port = receiver.recv()?;
109-
let stream = TcpStream::connect(format!("127.0.0.1:{}", port))?;
122+
let port = receiver.recv().unwrap();
123+
let stream = TcpStream::connect(format!("127.0.0.1:{}", port)).unwrap();
110124
let mut sender = SenderBuilder::realtime()
111125
.with_type::<Request>()
112126
.with_writer::<SlowWriter<TcpStream>>()
113127
.with_endianness::<BigEndian>()
114-
.build(SlowWriter::new(stream.try_clone()?, slow, blocking));
128+
.build(SlowWriter::new(stream.try_clone().unwrap(), slow, true));
115129

116130
let mut receiver = ReceiverBuilder::buffered()
117131
.with_type::<Response>()
@@ -128,23 +142,24 @@ fn blob(slow: bool, blocking: bool, max_size: usize) -> Result<(), Error> {
128142
blob.into_boxed_slice()
129143
};
130144

131-
sender.send(&Request::SendBlob(blob.clone()))?;
132-
sender.flush()?;
145+
sender.send(&Request::SendBlob(blob.clone())).unwrap();
146+
sender.flush().unwrap();
133147

134-
let new_blob = match receiver.recv()? {
135-
Response::Respond(new_blob) => new_blob,
148+
let new_blob = match pretend_blocking_read(&mut receiver).unwrap() {
149+
Response::Respond(blob) => blob,
136150
};
137151
let precalculated_new_blob = blob.into_iter()
138152
.map(|byte| byte.wrapping_add(1))
139153
.collect::<Box<[u8]>>();
140154

141155
assert_ne!(blob, new_blob);
142156
assert_eq!(new_blob, precalculated_new_blob);
157+
println!("Asserted");
143158

144-
sender.send(&Request::Stop)?;
145-
sender.flush()?;
159+
sender.send(&Request::Stop).unwrap();
160+
sender.flush().unwrap();
146161

147-
thread.join()??;
162+
thread.join().unwrap().unwrap();
148163

149164
Ok(())
150165
}
@@ -153,10 +168,6 @@ fn fast_blob() -> Result<(), Error> {
153168
blob(false, true, DEFAULT_MAX_SIZE)
154169
}
155170
#[test]
156-
fn fast_nonblocking_blob() -> Result<(), Error> {
157-
blob(false, false, DEFAULT_MAX_SIZE)
158-
}
159-
#[test]
160171
fn slow_blob() -> Result<(), Error> {
161172
blob(true, true, DEFAULT_MAX_SIZE)
162173
}

tests/slow_io.rs

+21-16
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
use std::io::{Read, Result, Write, ErrorKind as IoErrorKind};
1+
use std::io::{Result, ErrorKind as IoErrorKind};
2+
use std::io::prelude::*;
23
use std::time::{Instant, Duration};
34

45
// In milliseconds.
@@ -20,26 +21,36 @@ impl<T: Write> SlowWriter<T> {
2021
}
2122
}
2223
}
24+
25+
fn emulate_nonblocking(last_io: &mut Option<Instant>) -> Result<()> {
26+
match *last_io {
27+
Some(last_io_some) => if last_io_some + Duration::from_millis(DELAY) < Instant::now() {
28+
*last_io = None;
29+
return Err(IoErrorKind::WouldBlock.into())
30+
},
31+
None => *last_io = Some(Instant::now()),
32+
}
33+
Ok(())
34+
}
35+
2336
impl<T: Write> Write for SlowWriter<T> {
2437
fn write(&mut self, data: &[u8]) -> Result<usize> {
2538
if self.slow {
2639
if self.blocking {
2740
std::thread::sleep(Duration::from_millis(DELAY));
2841
} else {
29-
match self.last_write {
30-
Some(last_write) => if last_write + Duration::from_millis(DELAY) > Instant::now() {
31-
} else {
32-
return Err(IoErrorKind::WouldBlock.into())
33-
},
34-
None => self.last_write = Some(Instant::now()),
35-
}
42+
emulate_nonblocking(&mut self.last_write)?
3643
}
3744
}
3845
self.inner.write(data)
3946
}
4047
fn flush(&mut self) -> Result<()> {
4148
if self.slow {
42-
std::thread::sleep(Duration::from_millis(DELAY));
49+
if self.blocking {
50+
std::thread::sleep(Duration::from_millis(DELAY));
51+
} else {
52+
emulate_nonblocking(&mut self.last_write)?;
53+
}
4354
}
4455
self.inner.flush()
4556
}
@@ -66,13 +77,7 @@ impl<T: Read> Read for SlowReader<T> {
6677
if self.blocking {
6778
std::thread::sleep(Duration::from_millis(DELAY));
6879
} else {
69-
match self.last_read {
70-
Some(last_write) => if last_write + Duration::from_millis(DELAY) > Instant::now() {
71-
} else {
72-
return Err(IoErrorKind::WouldBlock.into())
73-
},
74-
None => self.last_read = Some(Instant::now()),
75-
}
80+
emulate_nonblocking(&mut self.last_read)?
7681
}
7782
}
7883
self.inner.read(buffer)

0 commit comments

Comments
 (0)