From aa56721052e41fef9708a1d72a8fe0b7250e87e7 Mon Sep 17 00:00:00 2001 From: bevzd Date: Sun, 29 Dec 2024 17:09:59 +0200 Subject: [PATCH] added asyncio support for sqlalchemy --- test/test_core.py | 91 +++++++++++++++------- ydb_sqlalchemy/sqlalchemy/__init__.py | 104 +++++++++++++++++--------- 2 files changed, 134 insertions(+), 61 deletions(-) diff --git a/test/test_core.py b/test/test_core.py index 3f7d808..e54d900 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -2,11 +2,13 @@ import datetime from decimal import Decimal from typing import NamedTuple - +from sqlalchemy.schema import CreateTable, DropTable import pytest import sqlalchemy as sa import ydb -from sqlalchemy import Column, Integer, String, Table, Unicode +from sqlalchemy import Column, Integer, String, Table, Unicode, insert, select +from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.testing import async_test from sqlalchemy.testing.fixtures import TablesTest, TestBase, config from ydb._grpc.v4.protos import ydb_common_pb2 @@ -380,11 +382,11 @@ def test_auto_partitioning_partition_size_mb(self, connection, auto_partitioning ], ) def test_auto_partitioning_min_partitions_count( - self, - connection, - auto_partitioning_min_partitions_count, - res, - metadata, + self, + connection, + auto_partitioning_min_partitions_count, + res, + metadata, ): desc = self._create_table_and_get_desc( connection, @@ -401,11 +403,11 @@ def test_auto_partitioning_min_partitions_count( ], ) def test_auto_partitioning_max_partitions_count( - self, - connection, - auto_partitioning_max_partitions_count, - res, - metadata, + self, + connection, + auto_partitioning_max_partitions_count, + res, + metadata, ): desc = self._create_table_and_get_desc( connection, @@ -422,11 +424,11 @@ def test_auto_partitioning_max_partitions_count( ], ) def test_uniform_partitions( - self, - connection, - uniform_partitions, - res, - metadata, + self, + connection, + uniform_partitions, + res, + metadata, ): desc = self._create_table_and_get_desc( connection, @@ -444,11 +446,11 @@ def test_uniform_partitions( ], ) def test_partition_at_keys( - self, - connection, - partition_at_keys, - res, - metadata, + self, + connection, + partition_at_keys, + res, + metadata, ): desc = self._create_table_and_get_desc( connection, @@ -535,10 +537,10 @@ def test_interactive_transaction(self, connection_no_trans, connection, isolatio @pytest.mark.parametrize( "isolation_level", ( - IsolationLevel.ONLINE_READONLY, - IsolationLevel.ONLINE_READONLY_INCONSISTENT, - IsolationLevel.STALE_READONLY, - IsolationLevel.AUTOCOMMIT, + IsolationLevel.ONLINE_READONLY, + IsolationLevel.ONLINE_READONLY_INCONSISTENT, + IsolationLevel.STALE_READONLY, + IsolationLevel.AUTOCOMMIT, ), ) def test_not_interactive_transaction(self, connection_no_trans, connection, isolation_level): @@ -673,6 +675,13 @@ def ydb_pool(self, ydb_driver): finally: loop.run_until_complete(session_pool.stop()) + @pytest.mark.asyncio + async def test_crud_commands_on_session(self, async_testing_engine): + engine = async_testing_engine() + maker = async_sessionmaker(engine) + async with maker() as session: + await session.execute(sa.text("SELECT 1 as value")) + class TestCredentials(TestBase): __backend__ = True @@ -725,6 +734,36 @@ def test_ydb_credentials_bad(self, query_client_settings, driver_config_for_cred assert "Invalid password" in str(excinfo.value) +class TestAsyncCRUD(TestBase): + __only_on__ = "yql+ydb_async" + + @async_test + async def test_crud(self, async_testing_engine, metadata): + engine = async_testing_engine() + maker = async_sessionmaker(engine) + async with maker() as session: + res = await session.scalar(sa.text("SELECT 1 as value")) + + assert res == 1 + table = Table( + 'test', + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + + async with maker() as session: + await session.execute(CreateTable(table)) + + # t1 = TestModel(id=1, name="test") + stmt = insert(table).values(id=1, name="test") + await session.execute(stmt) + stmt = select(table).where(table.c.id == 1) + t2 = await session.scalar(stmt) + assert t2 == 1 + await session.execute(DropTable(table)) + + class TestUpsert(TablesTest): __backend__ = True diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 7bdf8f8..60e91c2 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -5,11 +5,11 @@ import collections import collections.abc -from typing import Any, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Mapping, Optional, Sequence, Tuple, Union, Type import sqlalchemy as sa import ydb -from sqlalchemy import util +from sqlalchemy import util, AsyncAdaptedQueuePool, URL, Pool from sqlalchemy.engine import characteristics, reflection from sqlalchemy.engine.default import DefaultExecutionContext, StrCompileDialect from sqlalchemy.exc import NoSuchTableError @@ -22,10 +22,10 @@ from ydb_sqlalchemy.sqlalchemy.dml import Upsert from ydb_sqlalchemy.sqlalchemy.compiler import YqlCompiler, YqlDDLCompiler, YqlIdentifierPreparer, YqlTypeCompiler - +from ydb_sqlalchemy.sqlalchemy.dbapi_adapter import AdaptedAsyncCursor +from ydb_dbapi.utils import CursorStatus from . import types - OLD_SA = sa.__version__ < "2." @@ -86,12 +86,12 @@ def reset_characteristic(self, dialect: "YqlDialect", dbapi_connection: ydb_dbap dialect.reset_ydb_request_settings(dbapi_connection) def set_characteristic( - self, dialect: "YqlDialect", dbapi_connection: ydb_dbapi.Connection, value: ydb.BaseRequestSettings + self, dialect: "YqlDialect", dbapi_connection: ydb_dbapi.Connection, value: ydb.BaseRequestSettings ) -> None: dialect.set_ydb_request_settings(dbapi_connection, value) def get_characteristic( - self, dialect: "YqlDialect", dbapi_connection: ydb_dbapi.Connection + self, dialect: "YqlDialect", dbapi_connection: ydb_dbapi.Connection ) -> ydb.BaseRequestSettings: return dialect.get_ydb_request_settings(dbapi_connection) @@ -179,11 +179,11 @@ def dbapi(cls): return cls.import_dbapi() def __init__( - self, - json_serializer=None, - json_deserializer=None, - _add_declare_for_yql_stmt_vars=False, - **kwargs, + self, + json_serializer=None, + json_deserializer=None, + _add_declare_for_yql_stmt_vars=False, + **kwargs, ): super().__init__(**kwargs) @@ -295,9 +295,9 @@ def get_isolation_level(self, dbapi_connection: ydb_dbapi.Connection) -> str: return dbapi_connection.get_isolation_level() def set_ydb_request_settings( - self, - dbapi_connection: ydb_dbapi.Connection, - value: ydb.BaseRequestSettings, + self, + dbapi_connection: ydb_dbapi.Connection, + value: ydb.BaseRequestSettings, ) -> None: dbapi_connection.set_ydb_request_settings(value) @@ -332,10 +332,10 @@ def _handle_column_name(self, variable): return "`" + variable + "`" def _format_variables( - self, - statement: str, - parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]], - execute_many: bool, + self, + statement: str, + parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]], + execute_many: bool, ) -> Tuple[str, Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]]]: formatted_statement = statement formatted_parameters = None @@ -370,7 +370,7 @@ def _add_declare_for_yql_stmt_vars_impl(self, statement, parameters_types): return f"{declarations}\n{statement}" def __merge_parameters_values_and_types( - self, values: Mapping[str, Any], types: Mapping[str, Any], execute_many: bool + self, values: Mapping[str, Any], types: Mapping[str, Any], execute_many: bool ) -> Sequence[Mapping[str, ydb.TypedValue]]: if isinstance(values, collections.abc.Mapping): values = [values] @@ -387,11 +387,11 @@ def __merge_parameters_values_and_types( return result_list if execute_many else result_list[0] def _prepare_ydb_query( - self, - statement: str, - context: Optional[DefaultExecutionContext] = None, - parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]] = None, - execute_many: bool = False, + self, + statement: str, + context: Optional[DefaultExecutionContext] = None, + parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]] = None, + execute_many: bool = False, ) -> Tuple[Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]]]: is_ddl = context.isddl if context is not None else False @@ -417,21 +417,21 @@ def do_ping(self, dbapi_connection: ydb_dbapi.Connection) -> bool: return True def do_executemany( - self, - cursor: ydb_dbapi.Cursor, - statement: str, - parameters: Optional[Sequence[Mapping[str, Any]]], - context: Optional[DefaultExecutionContext] = None, + self, + cursor: ydb_dbapi.Cursor, + statement: str, + parameters: Optional[Sequence[Mapping[str, Any]]], + context: Optional[DefaultExecutionContext] = None, ) -> None: operation, parameters = self._prepare_ydb_query(statement, context, parameters, execute_many=True) cursor.executemany(operation, parameters) def do_execute( - self, - cursor: ydb_dbapi.Cursor, - statement: str, - parameters: Optional[Mapping[str, Any]] = None, - context: Optional[DefaultExecutionContext] = None, + self, + cursor: ydb_dbapi.Cursor, + statement: str, + parameters: Optional[Mapping[str, Any]] = None, + context: Optional[DefaultExecutionContext] = None, ) -> None: operation, parameters = self._prepare_ydb_query(statement, context, parameters, execute_many=False) is_ddl = context.isddl if context is not None else False @@ -441,10 +441,44 @@ def do_execute( cursor.execute(operation, parameters) + + + +class AsyncCursor(AdaptedAsyncCursor): + def fetchone(self): + return self._cursor._fetchone_from_buffer() + + def fetchmany(self, size=None): + size = size or self.arraysize + return self._cursor._fetchmany_from_buffer(size) + + def fetchall(self): + return self._cursor._fetchall_from_buffer() + + def close(self): + self._cursor._state = CursorStatus.closed + + +class AsyncConnection(AdaptedAsyncConnection): + def cursor(self): + return AsyncCursor(self._connection.cursor()) + + class AsyncYqlDialect(YqlDialect): driver = "ydb_async" is_async = True supports_statement_cache = True + def __init__(self, json_serializer=None, + json_deserializer=None, + _add_declare_for_yql_stmt_vars=True, + **kwargs): + super().__init__(json_serializer=json_serializer, json_deserializer=json_deserializer, + _add_declare_for_yql_stmt_vars=_add_declare_for_yql_stmt_vars, + **kwargs) + def connect(self, *cargs, **cparams): - return AdaptedAsyncConnection(util.await_only(self.dbapi.async_connect(*cargs, **cparams))) + return AsyncConnection(util.await_only(self.dbapi.async_connect(*cargs, **cparams))) + + def get_dialect_pool_class(self, url: URL) -> Type[Pool]: + return AsyncAdaptedQueuePool \ No newline at end of file