diff --git a/crates/file_store/src/entry_iter.rs b/crates/file_store/src/entry_iter.rs
index 770f264f3..6be3fd034 100644
--- a/crates/file_store/src/entry_iter.rs
+++ b/crates/file_store/src/entry_iter.rs
@@ -1,7 +1,7 @@
 use bincode::Options;
 use std::{
     fs::File,
-    io::{self, Seek},
+    io::{self, BufReader, Seek},
     marker::PhantomData,
 };
 
@@ -14,8 +14,9 @@ use crate::bincode_options;
 ///
 /// [`next`]: Self::next
 pub struct EntryIter<'t, T> {
-    db_file: Option<&'t mut File>,
-
+    /// Buffered reader around the file
+    db_file: BufReader<&'t mut File>,
+    finished: bool,
     /// The file position for the first read of `db_file`.
     start_pos: Option<u64>,
     types: PhantomData<T>,
 }
@@ -24,8 +25,9 @@ pub struct EntryIter<'t, T> {
 impl<'t, T> EntryIter<'t, T> {
     pub fn new(start_pos: u64, db_file: &'t mut File) -> Self {
         Self {
-            db_file: Some(db_file),
+            db_file: BufReader::new(db_file),
             start_pos: Some(start_pos),
+            finished: false,
             types: PhantomData,
         }
     }
 }
@@ -38,44 +40,44 @@ where
     type Item = Result<T, IterError>;
 
     fn next(&mut self) -> Option<Self::Item> {
-        // closure which reads a single entry starting from `self.pos`
-        let read_one = |f: &mut File, start_pos: Option<u64>| -> Result<Option<T>, IterError> {
-            let pos = match start_pos {
-                Some(pos) => f.seek(io::SeekFrom::Start(pos))?,
-                None => f.stream_position()?,
-            };
+        if self.finished {
+            return None;
+        }
+        (|| {
+            if let Some(start) = self.start_pos.take() {
+                self.db_file.seek(io::SeekFrom::Start(start))?;
+            }
 
-            match bincode_options().deserialize_from(&*f) {
-                Ok(changeset) => {
-                    f.stream_position()?;
-                    Ok(Some(changeset))
-                }
+            let pos_before_read = self.db_file.stream_position()?;
+            match bincode_options().deserialize_from(&mut self.db_file) {
+                Ok(changeset) => Ok(Some(changeset)),
                 Err(e) => {
+                    self.finished = true;
+                    let pos_after_read = self.db_file.stream_position()?;
+                    // allow unexpected EOF if 0 bytes were read
                     if let bincode::ErrorKind::Io(inner) = &*e {
-                        if inner.kind() == io::ErrorKind::UnexpectedEof {
-                            let eof = f.seek(io::SeekFrom::End(0))?;
-                            if pos == eof {
-                                return Ok(None);
-                            }
+                        if inner.kind() == io::ErrorKind::UnexpectedEof
+                            && pos_after_read == pos_before_read
+                        {
+                            return Ok(None);
                         }
                     }
-                    f.seek(io::SeekFrom::Start(pos))?;
+                    self.db_file.seek(io::SeekFrom::Start(pos_before_read))?;
                     Err(IterError::Bincode(*e))
                 }
             }
-        };
-
-        let result = read_one(self.db_file.as_mut()?, self.start_pos.take());
-        if result.is_err() {
-            self.db_file = None;
-        }
-        result.transpose()
+        })()
+        .transpose()
     }
 }
 
-impl From<io::Error> for IterError {
-    fn from(value: io::Error) -> Self {
-        IterError::Io(value)
+impl<'t, T> Drop for EntryIter<'t, T> {
+    fn drop(&mut self) {
+        // This syncs the underlying file's offset with the buffer's position. This way, we
+        // maintain the correct position to start the next read/write.
+        if let Ok(pos) = self.db_file.stream_position() {
+            let _ = self.db_file.get_mut().seek(io::SeekFrom::Start(pos));
+        }
     }
 }
 
@@ -97,4 +99,10 @@ impl core::fmt::Display for IterError {
     }
 }
 
+impl From<io::Error> for IterError {
+    fn from(value: io::Error) -> Self {
+        IterError::Io(value)
+    }
+}
+
 impl std::error::Error for IterError {}
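Why the new `Drop` impl is needed: `BufReader` reads ahead, so after `EntryIter` decodes an entry, the underlying `File`'s offset sits wherever the last buffer fill stopped, not at the end of the last decoded entry. Below is a minimal standalone sketch (not part of the patch) of that desync and of the same `stream_position()` plus `get_mut().seek(..)` resync the destructor performs; it assumes the `tempfile` crate, which the crate's tests already depend on:

```rust
use std::io::{BufReader, Read, Seek, SeekFrom, Write};

fn main() -> std::io::Result<()> {
    // Hypothetical demo file, not part of the patch.
    let mut file = tempfile::tempfile()?;
    file.write_all(b"0123456789")?;
    file.seek(SeekFrom::Start(0))?;

    {
        let mut reader = BufReader::new(&mut file);
        let mut buf = [0u8; 4];
        reader.read_exact(&mut buf)?;

        // We logically consumed 4 bytes, but the buffer fill dragged the
        // underlying file offset further ahead (to EOF here, since the
        // whole 10-byte file fits in the buffer).
        let logical_pos = reader.stream_position()?;
        assert_eq!(logical_pos, 4);

        // What `EntryIter::drop` does: push the logical position back into
        // the underlying `File` before the reader goes away.
        reader.get_mut().seek(SeekFrom::Start(logical_pos))?;
    }

    // Thanks to the re-seek, the next write lands at offset 4, not 10.
    assert_eq!(file.stream_position()?, 4);
    file.write_all(b"!")?;
    Ok(())
}
```

Without the re-seek, a write issued right after a partial iteration (the `write_after_short_read` test below) would land at the buffer's high-water mark rather than just after the last entry actually read.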
diff --git a/crates/file_store/src/store.rs b/crates/file_store/src/store.rs
index ebab2fd00..0dc45d28c 100644
--- a/crates/file_store/src/store.rs
+++ b/crates/file_store/src/store.rs
@@ -219,6 +219,7 @@ mod test {
     use bincode::DefaultOptions;
     use std::{
+        collections::BTreeSet,
         io::{Read, Write},
         vec::Vec,
     };
@@ -228,7 +229,7 @@ mod test {
     const TEST_MAGIC_BYTES: [u8; TEST_MAGIC_BYTES_LEN] =
         [98, 100, 107, 102, 115, 49, 49, 49, 49, 49, 49, 49];
 
-    type TestChangeSet = Vec<String>;
+    type TestChangeSet = BTreeSet<String>;
 
     #[derive(Debug)]
     struct TestTracker;
@@ -253,7 +254,7 @@ mod test {
     fn open_or_create_new() {
         let temp_dir = tempfile::tempdir().unwrap();
         let file_path = temp_dir.path().join("db_file");
-        let changeset = vec!["hello".to_string(), "world".to_string()];
+        let changeset = BTreeSet::from(["hello".to_string(), "world".to_string()]);
 
         {
            let mut db = Store::<TestChangeSet>::open_or_create_new(&TEST_MAGIC_BYTES, &file_path)
@@ -304,7 +305,7 @@ mod test {
         let mut data = [255_u8; 2000];
         data[..TEST_MAGIC_BYTES_LEN].copy_from_slice(&TEST_MAGIC_BYTES);
 
-        let changeset = vec!["one".into(), "two".into(), "three!".into()];
+        let changeset = TestChangeSet::from(["one".into(), "two".into(), "three!".into()]);
 
         let mut file = NamedTempFile::new().unwrap();
         file.write_all(&data).expect("should write");
@@ -340,4 +341,119 @@ mod test {
         assert_eq!(got_bytes, expected_bytes);
     }
+
+    #[test]
+    fn last_write_is_short() {
+        let temp_dir = tempfile::tempdir().unwrap();
+
+        let changesets = [
+            TestChangeSet::from(["1".into()]),
+            TestChangeSet::from(["2".into(), "3".into()]),
+            TestChangeSet::from(["4".into(), "5".into(), "6".into()]),
+        ];
+        let last_changeset = TestChangeSet::from(["7".into(), "8".into(), "9".into()]);
+        let last_changeset_bytes = bincode_options().serialize(&last_changeset).unwrap();
+
+        for short_write_len in 1..last_changeset_bytes.len() - 1 {
+            let file_path = temp_dir.path().join(format!("{}.dat", short_write_len));
+            println!("Test file: {:?}", file_path);
+
+            // simulate creating a file, writing data where the last write is incomplete
+            {
+                let mut db =
+                    Store::<TestChangeSet>::create_new(&TEST_MAGIC_BYTES, &file_path).unwrap();
+                for changeset in &changesets {
+                    db.append_changeset(changeset).unwrap();
+                }
+                // this is the incomplete write
+                db.db_file
+                    .write_all(&last_changeset_bytes[..short_write_len])
+                    .unwrap();
+            }
+
+            // load file again and aggregate changesets
+            // write the last changeset again (this time it succeeds)
+            {
+                let mut db = Store::<TestChangeSet>::open(&TEST_MAGIC_BYTES, &file_path).unwrap();
+                let err = db
+                    .aggregate_changesets()
+                    .expect_err("should return error as last read is short");
+                assert_eq!(
+                    err.changeset,
+                    changesets.iter().cloned().reduce(|mut acc, cs| {
+                        Append::append(&mut acc, cs);
+                        acc
+                    }),
+                    "should recover all changesets that are written in full",
+                );
+                db.db_file.write_all(&last_changeset_bytes).unwrap();
+            }
+
+            // load file again - this time we should successfully aggregate all changesets
+            {
+                let mut db = Store::<TestChangeSet>::open(&TEST_MAGIC_BYTES, &file_path).unwrap();
+                let aggregated_changesets = db
+                    .aggregate_changesets()
+                    .expect("aggregating all changesets should succeed");
+                assert_eq!(
+                    aggregated_changesets,
+                    changesets
+                        .iter()
+                        .cloned()
+                        .chain(core::iter::once(last_changeset.clone()))
+                        .reduce(|mut acc, cs| {
+                            Append::append(&mut acc, cs);
+                            acc
+                        }),
+                    "should recover all changesets",
+                );
+            }
+        }
+    }
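The `pos_after_read == pos_before_read` guard in `EntryIter::next` is what this test exercises: `UnexpectedEof` only counts as a clean end of file when zero bytes were consumed, while a truncated trailing entry consumes some bytes first and must surface as `IterError::Bincode`. A small sketch of the distinction, assuming bincode 1.x's `DefaultOptions` as a stand-in for the crate's `bincode_options()`:

```rust
use bincode::Options;
use std::io::{Cursor, Seek};

fn main() {
    // Stand-in for `bincode_options()`; assumes bincode 1.x.
    let opts = bincode::DefaultOptions::new();
    let bytes = opts
        .serialize(&vec!["7".to_string(), "8".to_string()])
        .unwrap();

    // Truncated entry: `UnexpectedEof`, but the reader *did* advance, so
    // this must be reported as an error, not as the end of the store.
    let mut short = Cursor::new(&bytes[..bytes.len() - 1]);
    assert!(opts.deserialize_from::<_, Vec<String>>(&mut short).is_err());
    assert!(short.stream_position().unwrap() > 0);

    // Clean EOF: also `UnexpectedEof`, but zero bytes were consumed, which
    // is the one case `EntryIter` tolerates by yielding `None`.
    let mut empty = Cursor::new(Vec::new());
    assert!(opts.deserialize_from::<_, Vec<String>>(&mut empty).is_err());
    assert_eq!(empty.stream_position().unwrap(), 0);
}
```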
+
+    #[test]
+    fn write_after_short_read() {
+        let temp_dir = tempfile::tempdir().unwrap();
+
+        let changesets = (0..20)
+            .map(|n| TestChangeSet::from([format!("{}", n)]))
+            .collect::<Vec<_>>();
+        let last_changeset = TestChangeSet::from(["last".into()]);
+
+        for read_count in 0..changesets.len() {
+            let file_path = temp_dir.path().join(format!("{}.dat", read_count));
+            println!("Test file: {:?}", file_path);
+
+            // First, we create the file with all the changesets!
+            let mut db = Store::<TestChangeSet>::create_new(&TEST_MAGIC_BYTES, &file_path).unwrap();
+            for changeset in &changesets {
+                db.append_changeset(changeset).unwrap();
+            }
+            drop(db);
+
+            // We re-open the file and read `read_count` number of changesets.
+            let mut db = Store::<TestChangeSet>::open(&TEST_MAGIC_BYTES, &file_path).unwrap();
+            let mut exp_aggregation = db
+                .iter_changesets()
+                .take(read_count)
+                .map(|r| r.expect("must read valid changeset"))
+                .fold(TestChangeSet::default(), |mut acc, v| {
+                    Append::append(&mut acc, v);
+                    acc
+                });
+            // We write after a short read.
+            db.write_changes(&last_changeset)
+                .expect("last write must succeed");
+            Append::append(&mut exp_aggregation, last_changeset.clone());
+            drop(db);
+
+            // We open the file again and check whether aggregate changeset is expected.
+            let aggregation = Store::<TestChangeSet>::open(&TEST_MAGIC_BYTES, &file_path)
+                .unwrap()
+                .aggregate_changesets()
+                .expect("must aggregate changesets")
+                .unwrap_or_default();
+            assert_eq!(aggregation, exp_aggregation);
+        }
+    }
 }
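One design choice worth spelling out: on a failed decode, `next` seeks back to `pos_before_read`, so the file is left positioned at the start of the corrupt tail and the caller can simply rewrite the entry in full, which is what `last_write_is_short` does with `db.db_file.write_all(&last_changeset_bytes)`. A standalone sketch of that recovery pattern over an in-memory `Cursor`, again assuming bincode 1.x's `DefaultOptions`:

```rust
use bincode::Options;
use std::io::{Cursor, Seek, SeekFrom, Write};

fn main() {
    let opts = bincode::DefaultOptions::new();
    let first = vec!["1".to_string()];
    let second = vec!["2".to_string(), "3".to_string()];
    let second_bytes = opts.serialize(&second).unwrap();

    // Write the first entry in full and only half of the second,
    // mimicking a crash mid-append.
    let mut file = Cursor::new(Vec::new());
    opts.serialize_into(&mut file, &first).unwrap();
    file.write_all(&second_bytes[..second_bytes.len() / 2]).unwrap();

    // Replay from the start: entry one decodes, entry two fails mid-read.
    file.seek(SeekFrom::Start(0)).unwrap();
    let got: Vec<String> = opts.deserialize_from(&mut file).unwrap();
    assert_eq!(got, first);

    let pos_before_read = file.stream_position().unwrap();
    assert!(opts.deserialize_from::<_, Vec<String>>(&mut file).is_err());

    // Mirror `EntryIter`: rewind to where the corrupt entry began, so that
    // rewriting the entry in full overwrites the partial tail.
    file.seek(SeekFrom::Start(pos_before_read)).unwrap();
    file.write_all(&second_bytes).unwrap();

    // The log now replays cleanly from that position.
    file.seek(SeekFrom::Start(pos_before_read)).unwrap();
    let got: Vec<String> = opts.deserialize_from(&mut file).unwrap();
    assert_eq!(got, second);
}
```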