Skip to content

Commit 88e12f9

Browse files
committed
Snapshot refactoring
1 parent 39f8dc2 commit 88e12f9

File tree

1 file changed

+69
-35
lines changed

1 file changed

+69
-35
lines changed

src/vmm/src/snapshot/mod.rs

Lines changed: 69 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ use std::io::{Read, Write};
3131

3232
use bincode::config;
3333
use bincode::config::{Configuration, Fixint, Limit, LittleEndian};
34+
use bincode::error::{DecodeError, EncodeError};
3435
use semver::Version;
3536
use serde::de::DeserializeOwned;
3637
use serde::{Deserialize, Serialize};
@@ -73,7 +74,7 @@ pub enum SnapshotError {
7374

7475
/// Firecracker snapshot header
7576
#[derive(Debug, Serialize, Deserialize)]
76-
pub struct SnapshotHdr {
77+
struct SnapshotHdr {
7778
/// Magic value
7879
magic: u64,
7980
/// Snapshot data version
@@ -82,21 +83,26 @@ pub struct SnapshotHdr {
8283

8384
impl SnapshotHdr {
8485
/// Create a new header for writing snapshots
85-
pub fn new(version: Version) -> Self {
86+
fn new(version: Version) -> Self {
8687
Self {
8788
magic: SNAPSHOT_MAGIC_ID,
8889
version,
8990
}
9091
}
9192

9293
/// Load and deserialize just the header (magic + version)
93-
pub fn load<R: Read>(reader: &mut R) -> Result<Self, SnapshotError> {
94+
fn load<R: Read>(reader: &mut R) -> Result<Self, SnapshotError> {
9495
let hdr: SnapshotHdr = deserialize(reader)?;
95-
Ok(hdr)
96+
if hdr.magic != SNAPSHOT_MAGIC_ID {
97+
Err(SnapshotError::InvalidMagic(hdr.magic))
98+
}
99+
else {
100+
Ok(hdr)
101+
}
96102
}
97103

98104
/// Serialize and write just the header
99-
pub fn store<W: Write>(&self, writer: &mut W) -> Result<(), SnapshotError> {
105+
fn store<W: Write>(&self, writer: &mut W) -> Result<(), SnapshotError> {
100106
serialize(writer, self)?;
101107
Ok(())
102108
}
@@ -109,44 +115,44 @@ pub struct Snapshot<Data> {
109115
pub data: Data,
110116
}
111117

118+
// Implementations for deserializing snapshots
119+
// Publicly exposed functions:
120+
// - load_unchecked()
121+
//- load()
112122
impl<Data: DeserializeOwned + Debug> Snapshot<Data> {
113-
/// Load without CRC or version‐check, but verify the magic
123+
/// Load without CRC or version‐check, but verify magic via `SnapshotHdr::load`.
114124
pub fn load_unchecked<R: Read + Debug>(reader: &mut R) -> Result<Self, SnapshotError> {
115-
let hdr: SnapshotHdr = deserialize(reader)?;
116-
if hdr.magic != SNAPSHOT_MAGIC_ID {
117-
return Err(SnapshotError::InvalidMagic(hdr.magic));
118-
}
119-
let data: Data = deserialize(reader)?;
125+
// this calls `deserialize` + checks magic internally
126+
let hdr: SnapshotHdr = SnapshotHdr::load(reader)?;
127+
let data: Data = deserialize(reader)?;
120128
Ok(Self { header: hdr, data })
121129
}
122130

123-
/// Load with CRC64 validation
124-
pub fn load<R: Read + Debug>(reader: &mut R, snapshot_len: usize) -> Result<Self, SnapshotError> {
131+
/// Load with CRC64 validation in one pass, using `load_unchecked` for header+data.
132+
pub fn load<R: Read + Debug>(reader: &mut R) -> Result<Self, SnapshotError> {
133+
// 1) Wrap in CRC reader
125134
let mut crc_reader = CRC64Reader::new(reader);
126135

127-
// Snapshot must be at least (len of CRC64)
128-
let raw_snapshot_len = snapshot_len
129-
.checked_sub(std::mem::size_of::<u64>())
130-
.ok_or(SnapshotError::InvalidSnapshotSize)?;
131-
132-
let mut snapshot_buf = vec![0u8; raw_snapshot_len];
133-
crc_reader
134-
.read_exact(&mut snapshot_buf)
135-
.map_err(|err| SnapshotError::Io(err.raw_os_error().unwrap_or(libc::EINVAL)))?;
136+
// 2) Parse header + payload & magic‐check
137+
let snapshot = Snapshot::load_unchecked(&mut crc_reader)?;
136138

137-
// Compute then read stored checksum
139+
// 3) Grab the computed CRC over everything read so far
138140
let computed = crc_reader.checksum();
141+
142+
// 4) Deserialize the trailing u64 and compare
139143
let stored: u64 = deserialize(&mut crc_reader)?;
140-
if computed != stored {
144+
if stored != computed {
141145
return Err(SnapshotError::Crc64(computed));
142146
}
143147

144-
// Now parse header+data from the buffered bytes
145-
let mut slice: &[u8] = snapshot_buf.as_slice();
146-
Snapshot::load_unchecked(&mut slice)
148+
Ok(snapshot)
147149
}
148150
}
149151

152+
// Implementations for serializing snapshots
153+
// Publicly-exposed *methods*:
154+
// - save(self,...)
155+
// - save_with_crc(self,...)
150156
impl<Data: Serialize + Debug> Snapshot<Data> {
151157
/// Save without CRC64
152158
pub fn save_without_crc<W: Write + Debug>(&self, writer: &mut W) -> Result<(), SnapshotError> {
@@ -166,9 +172,12 @@ impl<Data: Serialize + Debug> Snapshot<Data> {
166172
}
167173
}
168174

175+
// General methods for snapshots (related to serialization, see above, since an
176+
// instance is needed to serialize)
169177
impl<Data> Snapshot<Data> {
170178
/// Construct from a pre‐built header + payload
171-
pub fn new(header: SnapshotHdr, data: Data) -> Self {
179+
pub fn new(version: Version, data: Data) -> Self {
180+
header = SnapshotHdr::new(version);
172181
Snapshot { header, data }
173182
}
174183

@@ -178,31 +187,56 @@ impl<Data> Snapshot<Data> {
178187
}
179188
}
180189

181-
/// Deserialize any `O: DeserializeOwned + Debug` via bincode + our config
190+
/// Deserialize any `O: DeserializeOwned + Debug` via bincode + our config,
182191
fn deserialize<T, O>(reader: &mut T) -> Result<O, SnapshotError>
183192
where
184193
T: Read,
185194
O: DeserializeOwned + Debug,
186195
{
187196
bincode::serde::decode_from_std_read(reader, BINCODE_CONFIG)
188-
.map_err(|err| SnapshotError::Serde(err.to_string()))
197+
.map_err(|err| match err {
198+
// The reader hit an actual IO error.
199+
DecodeError::Io { inner, .. } =>
200+
SnapshotError::Io(inner.raw_os_error().unwrap_or(EIO)),
201+
202+
// Not enough bytes in the input for what we expected.
203+
DecodeError::UnexpectedEnd { .. } |
204+
DecodeError::LimitExceeded =>
205+
SnapshotError::InvalidSnapshotSize,
206+
207+
// Anything else is a ser/de format issue.
208+
other =>
209+
SnapshotError::Serde(other.to_string()),
210+
})
189211
}
190212

191-
/// Serialize any `O: Serialize + Debug` into a Vec, write it, and return the byte‐count
213+
/// Serialize any `O: Serialize + Debug` into a Vec, write it, and return the byte‐count,
192214
fn serialize<T, O>(writer: &mut T, data: &O) -> Result<usize, SnapshotError>
193215
where
194216
T: Write,
195217
O: Serialize + Debug,
196218
{
197-
// First serialize into an inmemory buffer using our config
219+
// 1) Encode into an in-memory buffer
198220
let mut buf = Vec::new();
199221
bincode::serde::encode_into_std_write(data, &mut buf, BINCODE_CONFIG)
200-
.map_err(|err| SnapshotError::Serde(err.to_string()))?;
222+
.map_err(|err| match err {
223+
// Ran out of room while encoding
224+
EncodeError::UnexpectedEnd =>
225+
SnapshotError::Io(libc::EIO),
226+
227+
// Underlying IO failure during encode (index tells how many bytes got written)
228+
EncodeError::Io { inner, .. } =>
229+
SnapshotError::Io(inner.raw_os_error().unwrap_or(libc::EIO)),
230+
231+
// Any other encode error we surface as Serde
232+
other =>
233+
SnapshotError::Serde(other.to_string()),
234+
})?;
201235

202-
// Then write it out
236+
// 2) Flush that buffer to the target writer
203237
writer
204238
.write_all(&buf)
205-
.map_err(|err| SnapshotError::Io(err.raw_os_error().unwrap_or(libc::EIO)))?;
239+
.map_err(|io_err| SnapshotError::Io(io_err.raw_os_error().unwrap_or(libc::EIO)))?;
206240

207241
Ok(buf.len())
208242
}

0 commit comments

Comments
 (0)