diff --git a/bitcoin/core/serialize.py b/bitcoin/core/serialize.py index c89b73e4..4fd49feb 100644 --- a/bitcoin/core/serialize.py +++ b/bitcoin/core/serialize.py @@ -67,6 +67,23 @@ def __init__(self, msg, obj, padding): self.obj = obj self.padding = padding + +class DeserializationValueBoundsError(SerializationError): + """Deserialized value out of bounds + + Thrown by deserialize() when a deserialized value turns out to be out + of allowed bounds + """ + + def __init__(self, msg, klass=None, value=None, + upper_bound=None, lower_bound=None): + super().__init__(msg) + self.klass = klass + self.value = value + self.upper_bound = upper_bound + self.lower_bound = lower_bound + + def ser_read(f, n): """Read from a stream safely @@ -211,14 +228,49 @@ def stream_serialize(cls, i, f): @classmethod def stream_deserialize(cls, f): r = _bord(ser_read(f, 1)) + if r < 0xfd: return r - elif r == 0xfd: - return struct.unpack(b' MAX_SIZE: + raise DeserializationValueBoundsError( + "non-canonical compact size for variable integer: " + "value too large", + klass=cls, value=v, lower_bound=lower_bound, + upper_bound=MAX_SIZE) + + return v class BytesSerializer(Serializer): @@ -363,6 +415,7 @@ def uint256_to_shortstr(u): 'SerializationError', 'SerializationTruncationError', 'DeserializationExtraDataError', + 'DeserializationValueBoundsError', 'ser_read', 'Serializable', 'ImmutableSerializable', diff --git a/bitcoin/tests/test_serialize.py b/bitcoin/tests/test_serialize.py index a5e45330..928cfc4f 100644 --- a/bitcoin/tests/test_serialize.py +++ b/bitcoin/tests/test_serialize.py @@ -44,30 +44,45 @@ class Test_VarIntSerializer(unittest.TestCase): def test(self): def T(value, expected): expected = unhexlify(expected) + expected_int = VarIntSerializer.deserialize(expected) + self.assertEqual(value, expected_int) actual = VarIntSerializer.serialize(value) self.assertEqual(actual, expected) roundtrip = VarIntSerializer.deserialize(actual) self.assertEqual(value, roundtrip) + T(0x0, b'00') T(0xfc, b'fc') T(0xfd, b'fdfd00') T(0xffff, b'fdffff') + T(0x1234, b'fd3412') T(0x10000, b'fe00000100') - T(0xffffffff, b'feffffffff') - T(0x100000000, b'ff0000000001000000') - T(0xffffffffffffffff, b'ffffffffffffffffff') + T(0x1234567, b'fe67452301') + T(0x2000000, b'fe00000002') - def test_non_optimal(self): - def T(serialized, expected_value): - serialized = unhexlify(serialized) - actual_value = VarIntSerializer.deserialize(serialized) - self.assertEqual(actual_value, expected_value) - T(b'fd0000', 0) - T(b'fd3412', 0x1234) - T(b'fe00000000', 0) - T(b'fe67452301', 0x1234567) - T(b'ff0000000000000000', 0) - T(b'ffefcdab8967452301', 0x123456789abcdef) + with self.assertRaises(DeserializationValueBoundsError): + T(0x2000001, b'fe01000002') + + with self.assertRaises(DeserializationValueBoundsError): + T(0xffffffff, b'feffffffff') + + with self.assertRaises(DeserializationValueBoundsError): + T(0x100000000, b'ff0000000001000000') + + with self.assertRaises(DeserializationValueBoundsError): + T(0xffffffffffffffff, b'ffffffffffffffffff') + + with self.assertRaises(DeserializationValueBoundsError): + T(0, b'fd0000') + + with self.assertRaises(DeserializationValueBoundsError): + T(0, b'fe00000000') + + with self.assertRaises(DeserializationValueBoundsError): + T(0, b'ff0000000000000000') + + with self.assertRaises(DeserializationValueBoundsError): + T(0x123456789abcdef, b'ffefcdab8967452301') def test_truncated(self): def T(serialized):