diff --git a/serde-reflection/src/de.rs b/serde-reflection/src/de.rs index c08baa634..d7707f4b8 100644 --- a/serde-reflection/src/de.rs +++ b/serde-reflection/src/de.rs @@ -20,18 +20,15 @@ use std::collections::btree_map::{BTreeMap, Entry}; /// `&'a mut` references used to return tracing results. /// * The lifetime 'de is fixed and the `&'de` reference meant to let us /// borrow values from previous serialization runs. -pub(crate) struct Deserializer<'de, 'a> { +pub struct Deserializer<'de, 'a> { tracer: &'a mut Tracer, samples: &'de Samples, format: &'a mut Format, } impl<'de, 'a> Deserializer<'de, 'a> { - pub(crate) fn new( - tracer: &'a mut Tracer, - samples: &'de Samples, - format: &'a mut Format, - ) -> Self { + /// Create a new Deserializer + pub fn new(tracer: &'a mut Tracer, samples: &'de Samples, format: &'a mut Format) -> Self { Deserializer { tracer, samples, @@ -422,9 +419,11 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { _ => unreachable!(), }; - // If the enum is already marked as incomplete, visit the first index, hoping - // to avoid recursion. - if self.tracer.incomplete_enums.contains_key(enum_name) { + // If the enum is already marked as incomplete and not pending, visit the first index, + // hoping to avoid recursion. + if let Some(EnumProgress::IndexedVariantsRemaining | EnumProgress::NamedVariantsRemaining) = + self.tracer.incomplete_enums.get(enum_name) + { return visitor.visit_enum(EnumDeserializer::new( self.tracer, self.samples, diff --git a/serde-reflection/src/lib.rs b/serde-reflection/src/lib.rs index 4de1c3332..580d3847f 100644 --- a/serde-reflection/src/lib.rs +++ b/serde-reflection/src/lib.rs @@ -361,7 +361,9 @@ mod ser; mod trace; mod value; +pub use de::Deserializer; pub use error::{Error, Result}; pub use format::{ContainerFormat, Format, FormatHolder, Named, Variable, VariantFormat}; -pub use trace::{Registry, Samples, Tracer, TracerConfig}; +pub use ser::Serializer; +pub use trace::{EnumProgress, Registry, Samples, Tracer, TracerConfig}; pub use value::Value; diff --git a/serde-reflection/src/ser.rs b/serde-reflection/src/ser.rs index f221878dc..7b15e28d4 100644 --- a/serde-reflection/src/ser.rs +++ b/serde-reflection/src/ser.rs @@ -12,13 +12,14 @@ use serde::{ser, Serialize}; /// Serialize a single value. /// The lifetime 'a is set by the serialization call site and the `&'a mut` /// references used to return tracing results and serialization samples. -pub(crate) struct Serializer<'a> { +pub struct Serializer<'a> { tracer: &'a mut Tracer, samples: &'a mut Samples, } impl<'a> Serializer<'a> { - pub(crate) fn new(tracer: &'a mut Tracer, samples: &'a mut Samples) -> Self { + /// Create a new Serializer + pub fn new(tracer: &'a mut Tracer, samples: &'a mut Samples) -> Self { Self { tracer, samples } } } diff --git a/serde-reflection/src/trace.rs b/serde-reflection/src/trace.rs index b3fcf3b46..480c278b3 100644 --- a/serde-reflection/src/trace.rs +++ b/serde-reflection/src/trace.rs @@ -36,12 +36,15 @@ pub struct Tracer { pub(crate) discriminants: BTreeMap<(TypeId, VariantId<'static>), Discriminant>, } +/// Type of untraced enum variants #[derive(Copy, Clone, Debug)] -pub(crate) enum EnumProgress { +pub enum EnumProgress { /// There are variant names that have not yet been traced. NamedVariantsRemaining, /// There are variant numbers that have not yet been traced. IndexedVariantsRemaining, + /// Tracing of further variants is pending. + Pending, } #[derive(Eq, PartialEq, Ord, PartialOrd, Debug)] @@ -243,6 +246,24 @@ impl Tracer { Ok((format, value)) } + /// Enable tracing of further variants of a incomplete enum. + /// + /// Marks an enum name as pending in the map of incomplete enums + /// and returns which type of variant tracing still needs to be performed. + /// + /// Call this in order to (simultaneously): + /// + /// * determine whether all variants of an enum have been traced, + /// * determine which type of variant tracing ([`EnumProgress`]) still needs to be + /// performed, and + /// * allow `Deserializer`/`trace_type_once` to make progress on a top level enum by + /// enabling tracing the next variant. + pub fn pend_enum(&mut self, name: &str) -> Option { + self.incomplete_enums + .get_mut(name) + .map(|p| std::mem::replace(p, EnumProgress::Pending)) + } + /// Same as `trace_type_once` but if `T` is an enum, we repeat the process /// until all variants of `T` are covered. /// We accumulate and return all the sampled values at the end. @@ -255,9 +276,12 @@ impl Tracer { let (format, value) = self.trace_type_once::(samples)?; values.push(value); if let Format::TypeName(name) = &format { - if let Some(&progress) = self.incomplete_enums.get(name) { + if let Some(progress) = self.pend_enum(name) { + debug_assert!( + !matches!(progress, EnumProgress::Pending), + "failed to make progress tracing enum {name}" + ); // Restart the analysis to find more variants of T. - self.incomplete_enums.remove(name); if let EnumProgress::NamedVariantsRemaining = progress { values.pop().unwrap(); } @@ -294,9 +318,12 @@ impl Tracer { let (format, value) = self.trace_type_once_with_seed(samples, seed.clone())?; values.push(value); if let Format::TypeName(name) = &format { - if let Some(&progress) = self.incomplete_enums.get(name) { + if let Some(progress) = self.pend_enum(name) { + debug_assert!( + !matches!(progress, EnumProgress::Pending), + "failed to make progress tracing enum {name}" + ); // Restart the analysis to find more variants of T. - self.incomplete_enums.remove(name); if let EnumProgress::NamedVariantsRemaining = progress { values.pop().unwrap(); }