Skip to content

Commit dd1d32e

Browse files
authored
Merge pull request #50 from qaspen-python/feature/support_enum
Supported ENUM PostgreSQL Type
2 parents 27d9ca8 + 998b940 commit dd1d32e

File tree

3 files changed

+147
-38
lines changed

3 files changed

+147
-38
lines changed

Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/tests/test_value_converter.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime
22
import uuid
3+
from enum import Enum
34
from ipaddress import IPv4Address
45
from typing import Any, Dict, List, Union
56

@@ -282,6 +283,10 @@ async def test_deserialization_composite_into_python(
282283
"""Test that it's possible to deserialize custom postgresql type."""
283284
await psql_pool.execute("DROP TABLE IF EXISTS for_test")
284285
await psql_pool.execute("DROP TYPE IF EXISTS all_types")
286+
await psql_pool.execute("DROP TYPE IF EXISTS inner_type")
287+
await psql_pool.execute("DROP TYPE IF EXISTS enum_type")
288+
await psql_pool.execute("CREATE TYPE enum_type AS ENUM ('sad', 'ok', 'happy')")
289+
await psql_pool.execute("CREATE TYPE inner_type AS (inner_value VARCHAR, some_enum enum_type)")
285290
create_type_query = """
286291
CREATE type all_types AS (
287292
bytea_ BYTEA,
@@ -316,7 +321,9 @@ async def test_deserialization_composite_into_python(
316321
uuid_arr UUID ARRAY,
317322
inet_arr INET ARRAY,
318323
jsonb_arr JSONB ARRAY,
319-
json_arr JSON ARRAY
324+
json_arr JSON ARRAY,
325+
test_inner_value inner_type,
326+
test_enum_type enum_type
320327
)
321328
"""
322329
create_table_query = """
@@ -329,8 +336,14 @@ async def test_deserialization_composite_into_python(
329336
await psql_pool.execute(
330337
querystring=create_table_query,
331338
)
339+
340+
class TestEnum(Enum):
341+
OK = "ok"
342+
SAD = "sad"
343+
HAPPY = "happy"
344+
332345
await psql_pool.execute(
333-
querystring="INSERT INTO for_test VALUES (ROW($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32))", # noqa: E501
346+
querystring="INSERT INTO for_test VALUES (ROW($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, ROW($33, $34), $35))", # noqa: E501
334347
parameters=[
335348
b"Bytes",
336349
"Some String",
@@ -394,9 +407,16 @@ async def test_deserialization_composite_into_python(
394407
},
395408
),
396409
],
410+
"inner type value",
411+
"happy",
412+
TestEnum.OK,
397413
],
398414
)
399415

416+
class ValidateModelForInnerValueType(BaseModel):
417+
inner_value: str
418+
some_enum: TestEnum
419+
400420
class ValidateModelForCustomType(BaseModel):
401421
bytea_: List[int]
402422
varchar_: str
@@ -432,6 +452,9 @@ class ValidateModelForCustomType(BaseModel):
432452
jsonb_arr: List[Dict[str, List[Union[str, int, List[str]]]]]
433453
json_arr: List[Dict[str, List[Union[str, int, List[str]]]]]
434454

455+
test_inner_value: ValidateModelForInnerValueType
456+
test_enum_type: TestEnum
457+
435458
class TopLevelModel(BaseModel):
436459
custom_type: ValidateModelForCustomType
437460

@@ -446,6 +469,41 @@ class TopLevelModel(BaseModel):
446469
assert isinstance(model_result[0], TopLevelModel)
447470

448471

472+
async def test_enum_type(psql_pool: ConnectionPool) -> None:
473+
"""Test that we can decode ENUM type from PostgreSQL."""
474+
475+
class TestEnum(Enum):
476+
OK = "ok"
477+
SAD = "sad"
478+
HAPPY = "happy"
479+
480+
class TestStrEnum(str, Enum):
481+
OK = "ok"
482+
SAD = "sad"
483+
HAPPY = "happy"
484+
485+
await psql_pool.execute("DROP TABLE IF EXISTS for_test")
486+
await psql_pool.execute("DROP TYPE IF EXISTS mood")
487+
await psql_pool.execute(
488+
"CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')",
489+
)
490+
await psql_pool.execute(
491+
"CREATE TABLE for_test (test_mood mood, test_mood2 mood)",
492+
)
493+
494+
await psql_pool.execute(
495+
querystring="INSERT INTO for_test VALUES ($1, $2)",
496+
parameters=[TestEnum.HAPPY, TestEnum.OK],
497+
)
498+
499+
qs_result = await psql_pool.execute(
500+
"SELECT * FROM for_test",
501+
)
502+
assert qs_result.result()[0]["test_mood"] == TestEnum.HAPPY.value
503+
assert qs_result.result()[0]["test_mood"] != TestEnum.HAPPY
504+
assert qs_result.result()[0]["test_mood2"] == TestStrEnum.OK
505+
506+
449507
async def test_custom_type_as_parameter(
450508
psql_pool: ConnectionPool,
451509
) -> None:

src/value_converter.rs

Lines changed: 85 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
22
use macaddr::{MacAddr6, MacAddr8};
3-
use postgres_types::{Field, FromSql, Kind};
3+
use postgres_types::{Field, FromSql, Kind, ToSql};
44
use serde_json::{json, Map, Value};
55
use std::{fmt::Debug, net::IpAddr};
66
use uuid::Uuid;
@@ -15,7 +15,7 @@ use pyo3::{
1515
Bound, Py, PyAny, Python, ToPyObject,
1616
};
1717
use tokio_postgres::{
18-
types::{to_sql_checked, ToSql, Type},
18+
types::{to_sql_checked, Type},
1919
Column, Row,
2020
};
2121

@@ -447,6 +447,15 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
447447
return Ok(PythonDTO::PyIpAddress(id_address));
448448
}
449449

450+
// It's used for Enum.
451+
// If StrEnum is used on Python side,
452+
// we simply stop at the `is_instance_of::<PyString>``.
453+
if let Ok(value_attr) = parameter.getattr("value") {
454+
if let Ok(possible_string) = value_attr.extract::<String>() {
455+
return Ok(PythonDTO::PyString(possible_string));
456+
}
457+
}
458+
450459
Err(RustPSQLDriverError::PyToRustValueConversionError(format!(
451460
"Can not covert you type {parameter} into inner one",
452461
)))
@@ -691,15 +700,12 @@ fn postgres_bytes_to_py(
691700
pub fn composite_postgres_to_py(
692701
py: Python<'_>,
693702
fields: &Vec<Field>,
694-
buf: &[u8],
703+
buf: &mut &[u8],
704+
custom_decoders: &Option<Py<PyDict>>,
695705
) -> RustPSQLDriverPyResult<Py<PyAny>> {
696-
let mut vec_buf: Vec<u8> = vec![];
697-
vec_buf.extend_from_slice(buf);
698-
let mut buf: &[u8] = vec_buf.as_slice();
699-
700706
let result_py_dict: Bound<'_, PyDict> = PyDict::new_bound(py);
701707

702-
let num_fields = postgres_types::private::read_be_i32(&mut buf).map_err(|err| {
708+
let num_fields = postgres_types::private::read_be_i32(buf).map_err(|err| {
703709
RustPSQLDriverError::RustToPyValueConversionError(format!(
704710
"Cannot read bytes data from PostgreSQL: {err}"
705711
))
@@ -713,66 +719,111 @@ pub fn composite_postgres_to_py(
713719
}
714720

715721
for field in fields {
716-
let oid = postgres_types::private::read_be_i32(&mut buf).map_err(|err| {
722+
let oid = postgres_types::private::read_be_i32(buf).map_err(|err| {
717723
RustPSQLDriverError::RustToPyValueConversionError(format!(
718724
"Cannot read bytes data from PostgreSQL: {err}"
719725
))
720726
})? as u32;
727+
721728
if oid != field.type_().oid() {
722729
return Err(RustPSQLDriverError::RustToPyValueConversionError(
723730
"unexpected OID".into(),
724731
));
725732
}
726733

727-
result_py_dict.set_item(
728-
field.name(),
729-
postgres_bytes_to_py(py, field.type_(), &mut buf, false)?.to_object(py),
730-
)?;
734+
match field.type_().kind() {
735+
Kind::Simple | Kind::Array(_) => {
736+
result_py_dict.set_item(
737+
field.name(),
738+
postgres_bytes_to_py(py, field.type_(), buf, false)?.to_object(py),
739+
)?;
740+
}
741+
Kind::Enum(_) => {
742+
result_py_dict.set_item(
743+
field.name(),
744+
postgres_bytes_to_py(py, &Type::VARCHAR, buf, false)?.to_object(py),
745+
)?;
746+
}
747+
_ => {
748+
let (_, tail) = buf.split_at(4_usize);
749+
*buf = tail;
750+
result_py_dict.set_item(
751+
field.name(),
752+
raw_bytes_data_process(py, buf, field.name(), field.type_(), custom_decoders)?
753+
.to_object(py),
754+
)?;
755+
}
756+
}
731757
}
732758

733759
Ok(result_py_dict.to_object(py))
734760
}
735761

736-
/// Convert type from postgres to python type.
762+
/// Process raw bytes from `PostgreSQL`.
737763
///
738764
/// # Errors
739765
///
740766
/// May return Err Result if cannot convert postgres
741767
/// type into rust one.
742-
pub fn postgres_to_py(
768+
pub fn raw_bytes_data_process(
743769
py: Python<'_>,
744-
row: &Row,
745-
column: &Column,
746-
column_i: usize,
770+
raw_bytes_data: &mut &[u8],
771+
column_name: &str,
772+
column_type: &Type,
747773
custom_decoders: &Option<Py<PyDict>>,
748774
) -> RustPSQLDriverPyResult<Py<PyAny>> {
749-
let raw_bytes_data = row.col_buffer(column_i);
750-
751775
if let Some(custom_decoders) = custom_decoders {
752776
let py_encoder_func = custom_decoders
753777
.bind(py)
754-
.get_item(column.name().to_lowercase());
778+
.get_item(column_name.to_lowercase());
755779

756780
if let Ok(Some(py_encoder_func)) = py_encoder_func {
757-
return Ok(py_encoder_func.call((raw_bytes_data,), None)?.unbind());
781+
return Ok(py_encoder_func
782+
.call((raw_bytes_data.to_vec(),), None)?
783+
.unbind());
758784
}
759785
}
760786

761-
let column_type = column.type_();
762-
match raw_bytes_data {
763-
Some(mut raw_bytes_data) => match column_type.kind() {
764-
Kind::Simple | Kind::Array(_) => {
765-
postgres_bytes_to_py(py, column.type_(), &mut raw_bytes_data, true)
766-
}
767-
Kind::Composite(fields) => composite_postgres_to_py(py, fields, raw_bytes_data),
768-
_ => Err(RustPSQLDriverError::RustToPyValueConversionError(
769-
column.type_().to_string(),
770-
)),
771-
},
772-
None => Ok(py.None()),
787+
match column_type.kind() {
788+
Kind::Simple | Kind::Array(_) => {
789+
postgres_bytes_to_py(py, column_type, raw_bytes_data, true)
790+
}
791+
Kind::Composite(fields) => {
792+
composite_postgres_to_py(py, fields, raw_bytes_data, custom_decoders)
793+
}
794+
Kind::Enum(_) => postgres_bytes_to_py(py, &Type::VARCHAR, raw_bytes_data, true),
795+
_ => Err(RustPSQLDriverError::RustToPyValueConversionError(
796+
column_type.to_string(),
797+
)),
773798
}
774799
}
775800

801+
/// Convert type from postgres to python type.
802+
///
803+
/// # Errors
804+
///
805+
/// May return Err Result if cannot convert postgres
806+
/// type into rust one.
807+
pub fn postgres_to_py(
808+
py: Python<'_>,
809+
row: &Row,
810+
column: &Column,
811+
column_i: usize,
812+
custom_decoders: &Option<Py<PyDict>>,
813+
) -> RustPSQLDriverPyResult<Py<PyAny>> {
814+
let raw_bytes_data = row.col_buffer(column_i);
815+
if let Some(mut raw_bytes_data) = raw_bytes_data {
816+
return raw_bytes_data_process(
817+
py,
818+
&mut raw_bytes_data,
819+
column.name(),
820+
column.type_(),
821+
custom_decoders,
822+
);
823+
}
824+
Ok(py.None())
825+
}
826+
776827
/// Convert python List of Dict type or just Dict into serde `Value`.
777828
///
778829
/// # Errors

0 commit comments

Comments
 (0)