diff --git a/src/de.rs b/src/de.rs index 8b2a9f9..165bc31 100644 --- a/src/de.rs +++ b/src/de.rs @@ -454,8 +454,7 @@ impl<'de> de::MapAccess<'de> for Map<'de> { K: de::DeserializeSeed<'de>, { if let Some(pair) = self.pairs.pop_front() { - seed.deserialize(&mut Deserializer::from_pair(pair)) - .map(Some) + seed.deserialize(KeyDeserializer { pair }).map(Some) } else { Ok(None) } @@ -550,3 +549,90 @@ impl<'de, 'a> de::VariantAccess<'de> for Variant<'de> { } } } + +struct KeyDeserializer<'de> { + pair: Pair<'de, Rule>, +} + +macro_rules! deserialize_ints { + ($($function:ident => $visit_method:ident,)*) => { $( + fn $function(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let span = self.pair.as_span(); + debug_assert!(matches!(self.pair.as_rule(), Rule::string | Rule::identifier)); + let string = parse_string(self.pair)?; + let mut res = if let Ok(parsed) = string.parse() { + visitor.$visit_method(parsed) + } else { + visitor.visit_string(string) + }; + error::set_location(&mut res, &span); + res + } + )* } +} + +impl<'de> de::Deserializer<'de> for KeyDeserializer<'de> { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let span = self.pair.as_span(); + debug_assert!(matches!( + self.pair.as_rule(), + Rule::string | Rule::identifier + )); + let mut res = visitor.visit_string(parse_string(self.pair)?); + error::set_location(&mut res, &span); + res + } + + deserialize_ints! { + deserialize_i8 => visit_i8, + deserialize_i16 => visit_i16, + deserialize_i32 => visit_i32, + deserialize_i64 => visit_i64, + deserialize_i128 => visit_i128, + deserialize_u8 => visit_u8, + deserialize_u16 => visit_u16, + deserialize_u32 => visit_u32, + deserialize_u64 => visit_u64, + deserialize_u128 => visit_u128, + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + // Keys cannot be null + visitor.visit_some(self) + } + + fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_enum( + self, + name: &'static str, + variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + Deserializer::from_pair(self.pair).deserialize_enum(name, variants, visitor) + } + + forward_to_deserialize_any! { + bool f32 f64 char str string bytes byte_buf unit unit_struct seq + tuple tuple_struct map struct identifier ignored_any + } +} diff --git a/tests/de.rs b/tests/de.rs index 7268c1a..d58660e 100644 --- a/tests/de.rs +++ b/tests/de.rs @@ -608,6 +608,41 @@ fn deserializes_map() { deserializes_to("{ a: 1, 'b': 2, \"c\": 3 }", m); } +#[test] +fn deserializes_map_strange_keys() { + // integer keys + let mut m = HashMap::new(); + m.insert(0, "a".to_owned()); + m.insert(1, "b".to_owned()); + m.insert(2, "c".to_owned()); + deserializes_to(r#"{'0':"a","1":"b","02":"c"}"#, m); + + // option keys + let mut m = HashMap::new(); + m.insert(Some("a".to_owned()), 0); + deserializes_to("{a:0}", m); + + // newtype keys + #[derive(Debug, PartialEq, Eq, Hash, Deserialize)] + struct New(i32); + let mut m = HashMap::new(); + m.insert(New(0), "a".to_owned()); + deserializes_to("{'0':'a'}", m); + + // enum keys + #[derive(Debug, PartialEq, Eq, Hash, Deserialize)] + enum Key { + A, + B, + C, + } + let mut m = HashMap::new(); + m.insert(Key::A, "a".to_owned()); + m.insert(Key::B, "b".to_owned()); + m.insert(Key::C, "c".to_owned()); + deserializes_to("{'A':'a','B':'b','C':'c'}", m); +} + #[test] fn deserializes_map_size_hint() { #[derive(Debug, PartialEq)]