diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index 4e6bb504f..d358682ff 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -231,11 +231,9 @@ pub(crate) fn infer_to_python_known( PyList::new_bound(py, items).into_py(py) } ObType::Complex => { - let dict = value.downcast::()?; - let new_dict = PyDict::new_bound(py); - let _ = new_dict.set_item("real", dict.get_item("real")?); - let _ = new_dict.set_item("imag", dict.get_item("imag")?); - new_dict.into_py(py) + let v = value.downcast::()?; + let complex_str = type_serializers::complex::complex_to_str(v); + complex_str.into_py(py) } ObType::Path => value.str()?.into_py(py), ObType::Pattern => value.getattr(intern!(py, "pattern"))?.into_py(py), @@ -286,11 +284,9 @@ pub(crate) fn infer_to_python_known( iter.into_py(py) } ObType::Complex => { - let dict = value.downcast::()?; - let new_dict = PyDict::new_bound(py); - let _ = new_dict.set_item("real", dict.get_item("real")?); - let _ = new_dict.set_item("imag", dict.get_item("imag")?); - new_dict.into_py(py) + let v = value.downcast::()?; + let complex_str = type_serializers::complex::complex_to_str(v); + complex_str.into_py(py) } ObType::Unknown => { if let Some(fallback) = extra.fallback { @@ -422,10 +418,8 @@ pub(crate) fn infer_serialize_known( ObType::Bool => serialize!(bool), ObType::Complex => { let v = value.downcast::().map_err(py_err_se_err)?; - let mut map = serializer.serialize_map(Some(2))?; - map.serialize_entry(&"real", &v.real())?; - map.serialize_entry(&"imag", &v.imag())?; - map.end() + let complex_str = type_serializers::complex::complex_to_str(v); + Ok(serializer.collect_str::(&complex_str)?) } ObType::Float | ObType::FloatSubclass => { let v = value.extract::().map_err(py_err_se_err)?; @@ -672,7 +666,7 @@ pub(crate) fn infer_json_key_known<'a>( } Ok(Cow::Owned(key_build.finish())) } - ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator | ObType::Complex => { + ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator => { py_err!(PyTypeError; "`{}` not valid as object key", ob_type) } ObType::Dataclass | ObType::PydanticSerializable => { @@ -689,6 +683,10 @@ pub(crate) fn infer_json_key_known<'a>( // FIXME it would be nice to have a "PyCow" which carries ownership of the Python type too Ok(Cow::Owned(key.str()?.to_string_lossy().into_owned())) } + ObType::Complex => { + let v = key.downcast::()?; + Ok(type_serializers::complex::complex_to_str(v).into()) + } ObType::Pattern => Ok(Cow::Owned( key.getattr(intern!(key.py(), "pattern"))? .str()? diff --git a/src/serializers/ob_type.rs b/src/serializers/ob_type.rs index 8d20efaa9..4fddc8790 100644 --- a/src/serializers/ob_type.rs +++ b/src/serializers/ob_type.rs @@ -252,6 +252,8 @@ impl ObTypeLookup { ObType::Url } else if ob_type == self.multi_host_url { ObType::MultiHostUrl + } else if ob_type == self.complex { + ObType::Complex } else if ob_type == self.uuid_object.as_ptr() as usize { ObType::Uuid } else if is_pydantic_serializable(op_value) { diff --git a/src/serializers/type_serializers/complex.rs b/src/serializers/type_serializers/complex.rs index 5a525476e..57c5ba069 100644 --- a/src/serializers/type_serializers/complex.rs +++ b/src/serializers/type_serializers/complex.rs @@ -33,22 +33,10 @@ impl TypeSerializer for ComplexSerializer { ) -> PyResult { let py = value.py(); match value.downcast::() { - Ok(py_complex) => match extra.mode { - SerMode::Json => { - let re = py_complex.real(); - let im = py_complex.imag(); - let mut s = format!("{im}j"); - if re != 0.0 { - let mut sign = ""; - if im >= 0.0 { - sign = "+"; - } - s = format!("{re}{sign}{s}"); - } - Ok(s.into_py(py)) - } - _ => Ok(value.into_py(py)), - }, + Ok(py_complex) => Ok(match extra.mode { + SerMode::Json => complex_to_str(py_complex).into_py(py), + _ => value.into_py(py), + }), Err(_) => { extra.warnings.on_fallback_py(self.get_name(), value, extra)?; infer_to_python(value, include, exclude, extra) @@ -70,16 +58,7 @@ impl TypeSerializer for ComplexSerializer { ) -> Result { match value.downcast::() { Ok(py_complex) => { - let re = py_complex.real(); - let im = py_complex.imag(); - let mut s = format!("{im}j"); - if re != 0.0 { - let mut sign = ""; - if im >= 0.0 { - sign = "+"; - } - s = format!("{re}{sign}{s}"); - } + let s = complex_to_str(py_complex); Ok(serializer.collect_str::(&s)?) } Err(_) => { @@ -93,3 +72,17 @@ impl TypeSerializer for ComplexSerializer { "complex" } } + +pub fn complex_to_str(py_complex: &Bound<'_, PyComplex>) -> String { + let re = py_complex.real(); + let im = py_complex.imag(); + let mut s = format!("{im}j"); + if re != 0.0 { + let mut sign = ""; + if im >= 0.0 { + sign = "+"; + } + s = format!("{re}{sign}{s}"); + } + s +} diff --git a/tests/serializers/test_infer.py b/tests/serializers/test_infer.py new file mode 100644 index 000000000..7762f68bc --- /dev/null +++ b/tests/serializers/test_infer.py @@ -0,0 +1,28 @@ +from enum import Enum + +from pydantic_core import SchemaSerializer, core_schema + + +# serializing enum calls methods in serializers::infer +def test_infer_to_python(): + class MyEnum(Enum): + complex_ = complex(1, 2) + + v = SchemaSerializer(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) + assert v.to_python(MyEnum.complex_, mode='json') == '1+2j' + + +def test_infer_serialize(): + class MyEnum(Enum): + complex_ = complex(1, 2) + + v = SchemaSerializer(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) + assert v.to_json(MyEnum.complex_) == b'"1+2j"' + + +def test_infer_json_key(): + class MyEnum(Enum): + complex_ = {complex(1, 2): 1} + + v = SchemaSerializer(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) + assert v.to_json(MyEnum.complex_) == b'{"1+2j":1}'