diff --git a/src/ecdh.rs b/src/ecdh.rs index 394abe349..b72d6b75e 100644 --- a/src/ecdh.rs +++ b/src/ecdh.rs @@ -167,7 +167,6 @@ mod tests { use rand::thread_rng; use super::SharedSecret; use super::super::Secp256k1; - use Error; #[test] fn ecdh() { @@ -187,7 +186,7 @@ mod tests { let s = Secp256k1::signing_only(); let (sk1, pk1) = s.generate_keypair(&mut thread_rng()); let (sk2, pk2) = s.generate_keypair(&mut thread_rng()); - + let sec1 = SharedSecret::new_with_hash(&pk1, &sk2, |x,_| x.into()); let sec2 = SharedSecret::new_with_hash(&pk2, &sk1, |x,_| x.into()); let sec_odd = SharedSecret::new_with_hash(&pk1, &sk1, |x,_| x.into()); diff --git a/src/key.rs b/src/key.rs index ba1d22177..ce25da96e 100644 --- a/src/key.rs +++ b/src/key.rs @@ -199,7 +199,34 @@ impl SecretKey { } } -serde_impl!(SecretKey, constants::SECRET_KEY_SIZE); + +#[cfg(feature = "serde")] +impl ::serde::Serialize for SecretKey { + fn serialize(&self, s: S) -> Result { + if s.is_human_readable() { + s.collect_str(self) + } else { + s.serialize_bytes(&self[..]) + } + } +} + +#[cfg(feature = "serde")] +impl<'de> ::serde::Deserialize<'de> for SecretKey { + fn deserialize>(d: D) -> Result { + if d.is_human_readable() { + d.deserialize_str(super::serde_util::HexVisitor::new( + "a hex string representing 32 byte SecretKey" + )) + } else { + d.deserialize_bytes(super::serde_util::BytesVisitor::new( + "raw 32 bytes SecretKey", + SecretKey::from_slice + )) + } + } +} + impl PublicKey { /// Obtains a raw const pointer suitable for use with FFI functions @@ -392,53 +419,14 @@ impl ::serde::Serialize for PublicKey { impl<'de> ::serde::Deserialize<'de> for PublicKey { fn deserialize>(d: D) -> Result { if d.is_human_readable() { - struct HexVisitor; - - impl<'de> ::serde::de::Visitor<'de> for HexVisitor { - type Value = PublicKey; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("an ASCII hex string") - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: ::serde::de::Error, - { - if let Ok(hex) = str::from_utf8(v) { - str::FromStr::from_str(hex).map_err(E::custom) - } else { - Err(E::invalid_value(::serde::de::Unexpected::Bytes(v), &self)) - } - } - - fn visit_str(self, v: &str) -> Result - where - E: ::serde::de::Error, - { - str::FromStr::from_str(v).map_err(E::custom) - } - } - d.deserialize_str(HexVisitor) + d.deserialize_str(super::serde_util::HexVisitor::new( + "an ASCII hex string representing a public key" + )) } else { - struct BytesVisitor; - - impl<'de> ::serde::de::Visitor<'de> for BytesVisitor { - type Value = PublicKey; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a bytestring") - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: ::serde::de::Error, - { - PublicKey::from_slice(v).map_err(E::custom) - } - } - - d.deserialize_bytes(BytesVisitor) + d.deserialize_bytes(super::serde_util::BytesVisitor::new( + "a bytestring representing a public key", + PublicKey::from_slice + )) } } } @@ -848,8 +836,20 @@ mod test { let pk = PublicKey::from_secret_key(&s, &sk); assert_tokens(&sk.compact(), &[Token::BorrowedBytes(&SK_BYTES[..])]); + assert_tokens(&sk.compact(), &[Token::Bytes(&SK_BYTES)]); + assert_tokens(&sk.compact(), &[Token::ByteBuf(&SK_BYTES)]); + assert_tokens(&sk.readable(), &[Token::BorrowedStr(SK_STR)]); + assert_tokens(&sk.readable(), &[Token::Str(SK_STR)]); + assert_tokens(&sk.readable(), &[Token::String(SK_STR)]); + assert_tokens(&pk.compact(), &[Token::BorrowedBytes(&PK_BYTES[..])]); + assert_tokens(&pk.compact(), &[Token::Bytes(&PK_BYTES)]); + assert_tokens(&pk.compact(), &[Token::ByteBuf(&PK_BYTES)]); + assert_tokens(&pk.readable(), &[Token::BorrowedStr(PK_STR)]); + assert_tokens(&pk.readable(), &[Token::Str(PK_STR)]); + assert_tokens(&pk.readable(), &[Token::String(PK_STR)]); + } } diff --git a/src/lib.rs b/src/lib.rs index 92c9a91dc..64c5c5095 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -165,6 +165,8 @@ pub mod ecdh; pub mod key; #[cfg(feature = "recovery")] pub mod recovery; +#[cfg(feature = "serde")] +mod serde_util; pub use key::SecretKey; pub use key::PublicKey; @@ -432,15 +434,16 @@ impl ::serde::Serialize for Signature { #[cfg(feature = "serde")] impl<'de> ::serde::Deserialize<'de> for Signature { - fn deserialize>(d: D) -> Result { - use ::serde::de::Error; - use str::FromStr; + fn deserialize>(d: D) -> Result { if d.is_human_readable() { - let sl: &str = ::serde::Deserialize::deserialize(d)?; - Signature::from_str(sl).map_err(D::Error::custom) + d.deserialize_str(serde_util::HexVisitor::new( + "a hex string representing a DER encoded Signature" + )) } else { - let sl: &[u8] = ::serde::Deserialize::deserialize(d)?; - Signature::from_der(sl).map_err(D::Error::custom) + d.deserialize_bytes(serde_util::BytesVisitor::new( + "raw byte stream, that represents a DER encoded Signature", + Signature::from_der + )) } } } @@ -1081,7 +1084,13 @@ mod tests { "; assert_tokens(&sig.compact(), &[Token::BorrowedBytes(&SIG_BYTES[..])]); + assert_tokens(&sig.compact(), &[Token::Bytes(&SIG_BYTES)]); + assert_tokens(&sig.compact(), &[Token::ByteBuf(&SIG_BYTES)]); + assert_tokens(&sig.readable(), &[Token::BorrowedStr(SIG_STR)]); + assert_tokens(&sig.readable(), &[Token::Str(SIG_STR)]); + assert_tokens(&sig.readable(), &[Token::String(SIG_STR)]); + } // For WASM, just run through our general tests in this file all at once. diff --git a/src/macros.rs b/src/macros.rs index 9cf9ba6d3..bfd41b7bb 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -43,47 +43,3 @@ macro_rules! impl_from_array_len { )+ } } - -#[cfg(feature="serde")] -/// Implements `Serialize` and `Deserialize` for a type `$t` which represents -/// a newtype over a byte-slice over length `$len`. Type `$t` must implement -/// the `FromStr` and `Display` trait. -macro_rules! serde_impl( - ($t:ident, $len:expr) => ( - impl ::serde::Serialize for $t { - fn serialize(&self, s: S) -> Result { - if s.is_human_readable() { - s.collect_str(self) - } else { - s.serialize_bytes(&self[..]) - } - } - } - - impl<'de> ::serde::Deserialize<'de> for $t { - fn deserialize>(d: D) -> Result<$t, D::Error> { - use ::serde::de::Error; - use core::str::FromStr; - - if d.is_human_readable() { - let sl: &str = ::serde::Deserialize::deserialize(d)?; - SecretKey::from_str(sl).map_err(D::Error::custom) - } else { - let sl: &[u8] = ::serde::Deserialize::deserialize(d)?; - if sl.len() != $len { - Err(D::Error::invalid_length(sl.len(), &stringify!($len))) - } else { - let mut ret = [0; $len]; - ret.copy_from_slice(sl); - Ok($t(ret)) - } - } - } - } - ) -); - -#[cfg(not(feature="serde"))] -macro_rules! serde_impl( - ($t:ident, $len:expr) => () -); diff --git a/src/serde_util.rs b/src/serde_util.rs new file mode 100644 index 000000000..50344167f --- /dev/null +++ b/src/serde_util.rs @@ -0,0 +1,76 @@ +use core::fmt; +use core::marker::PhantomData; +use core::str::{self, FromStr}; +use serde::de; + +pub struct HexVisitor { + expectation: &'static str, + _pd: PhantomData, +} + +impl HexVisitor { + pub fn new(expectation: &'static str) -> Self { + HexVisitor { + expectation, + _pd: PhantomData, + } + } +} + +impl<'de, T> de::Visitor<'de> for HexVisitor +where + T: FromStr, + ::Err: fmt::Display, +{ + type Value = T; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str(self.expectation) + } + + fn visit_bytes(self, v: &[u8]) -> Result { + if let Ok(hex) = str::from_utf8(v) { + FromStr::from_str(hex).map_err(E::custom) + } else { + Err(E::invalid_value(de::Unexpected::Bytes(v), &self)) + } + } + + fn visit_str(self, v: &str) -> Result { + FromStr::from_str(v).map_err(E::custom) + } +} + +pub struct BytesVisitor { + expectation: &'static str, + parse_fn: F, +} + +impl BytesVisitor +where + F: FnOnce(&[u8]) -> Result, + Err: fmt::Display, +{ + pub fn new(expectation: &'static str, parse_fn: F) -> Self { + BytesVisitor { + expectation, + parse_fn, + } + } +} + +impl<'de, F, T, Err> de::Visitor<'de> for BytesVisitor +where + F: FnOnce(&[u8]) -> Result, + Err: fmt::Display, +{ + type Value = T; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str(self.expectation) + } + + fn visit_bytes(self, v: &[u8]) -> Result { + (self.parse_fn)(v).map_err(E::custom) + } +}