diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2c4c1cad..e6bda247 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -19,7 +19,8 @@ jobs: fail-fast: false matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - surrealdb-version: ["v2.1.0", "v2.1.1", "v2.1.2", "v2.1.3", "v2.1.4"] # v2.0.0 has different UPSERT behaviour + # surrealdb-version: ["v2.1.0", "v2.1.1", "v2.1.2", "v2.1.3", "v2.1.4"] # v2.0.0 has different UPSERT behaviour + surrealdb-version: ["v2.1.1", "v2.1.2", "v2.1.3", "v2.1.4"] # v2.0.0 has different UPSERT behaviour and v2.1.0 does not support async batching name: Python ${{ matrix.python-version }} - SurrealDB ${{ matrix.surrealdb-version }} steps: - name: Checkout repository @@ -42,17 +43,18 @@ jobs: - name: Install dependencies run: pip install -r requirements.txt - - name: Run unit tests (HTTP) + - name: Run unit tests run: python -m unittest discover -s tests env: PYTHONPATH: ./src SURREALDB_URL: http://localhost:8000 + SURREALDB_VERSION: ${{ matrix.surrealdb-version }} - - name: Run unit tests (WebSocket) - run: python -m unittest discover -s tests - env: - PYTHONPATH: ./src - SURREALDB_URL: ws://localhost:8000 +# - name: Run unit tests (WebSocket) +# run: python -m unittest discover -s tests +# env: +# PYTHONPATH: ./src +# SURREALDB_URL: ws://localhost:8000 diff --git a/src/surrealdb/connections/async_http.py b/src/surrealdb/connections/async_http.py index 6e5fa89c..0d027752 100644 --- a/src/surrealdb/connections/async_http.py +++ b/src/surrealdb/connections/async_http.py @@ -98,16 +98,19 @@ def set_token(self, token: str) -> None: self.token = token async def authenticate(self) -> None: - message = RequestMessage(self.id, RequestMethod.AUTHENTICATE, token=self.token) + message = RequestMessage(RequestMethod.AUTHENTICATE, token=self.token) + self.id = message.id return await self._send(message, "authenticating") async def invalidate(self) -> None: - message = RequestMessage(self.id, RequestMethod.INVALIDATE) + message = RequestMessage(RequestMethod.INVALIDATE) + self.id = message.id await self._send(message, "invalidating") self.token = None async def signup(self, vars: Dict) -> str: - message = RequestMessage(self.id, RequestMethod.SIGN_UP, data=vars) + message = RequestMessage(RequestMethod.SIGN_UP, data=vars) + self.id = message.id response = await self._send(message, "signup") self.check_response_for_result(response, "signup") self.token = response["result"] @@ -115,7 +118,6 @@ async def signup(self, vars: Dict) -> str: async def signin(self, vars: dict) -> dict: message = RequestMessage( - self.id, RequestMethod.SIGN_IN, username=vars.get("username"), password=vars.get("password"), @@ -124,24 +126,26 @@ async def signin(self, vars: dict) -> dict: namespace=vars.get("namespace"), variables=vars.get("variables"), ) + self.id = message.id response = await self._send(message, "signing in") self.check_response_for_result(response, "signing in") self.token = response["result"] return response["result"] async def info(self) -> dict: - message = RequestMessage(self.id, RequestMethod.INFO) + message = RequestMessage(RequestMethod.INFO) + self.id = message.id response = await self._send(message, "getting database information") self.check_response_for_result(response, "getting database information") return response["result"] async def use(self, namespace: str, database: str) -> None: message = RequestMessage( - self.token, RequestMethod.USE, namespace=namespace, database=database, ) + self.id = message.id _ = await self._send(message, "use") self.namespace = namespace self.database = database @@ -152,11 +156,11 @@ async def query(self, query: str, params: Optional[dict] = None) -> dict: for key, value in self.vars.items(): params[key] = value message = RequestMessage( - self.id, RequestMethod.QUERY, query=query, params=params, ) + self.id = message.id response = await self._send(message, "query") self.check_response_for_result(response, "query") return response["result"][0]["result"] @@ -167,11 +171,11 @@ async def query_raw(self, query: str, params: Optional[dict] = None) -> dict: for key, value in self.vars.items(): params[key] = value message = RequestMessage( - self.id, RequestMethod.QUERY, query=query, params=params, ) + self.id = message.id response = await self._send(message, "query", bypass=True) return response @@ -184,9 +188,8 @@ async def create( if ":" in thing: buffer = thing.split(":") thing = RecordID(table_name=buffer[0], identifier=buffer[1]) - message = RequestMessage( - self.id, RequestMethod.CREATE, collection=thing, data=data - ) + message = RequestMessage(RequestMethod.CREATE, collection=thing, data=data) + self.id = message.id response = await self._send(message, "create") self.check_response_for_result(response, "create") return response["result"] @@ -194,7 +197,8 @@ async def create( async def delete( self, thing: Union[str, RecordID, Table] ) -> Union[List[dict], dict]: - message = RequestMessage(self.id, RequestMethod.DELETE, record_id=thing) + message = RequestMessage(RequestMethod.DELETE, record_id=thing) + self.id = message.id response = await self._send(message, "delete") self.check_response_for_result(response, "delete") return response["result"] @@ -202,9 +206,8 @@ async def delete( async def insert( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.INSERT, collection=table, params=data - ) + message = RequestMessage(RequestMethod.INSERT, collection=table, params=data) + self.id = message.id response = await self._send(message, "insert") self.check_response_for_result(response, "insert") return response["result"] @@ -213,8 +216,9 @@ async def insert_relation( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.INSERT_RELATION, table=table, params=data + RequestMethod.INSERT_RELATION, table=table, params=data ) + self.id = message.id response = await self._send(message, "insert_relation") self.check_response_for_result(response, "insert_relation") return response["result"] @@ -228,9 +232,8 @@ async def unset(self, key: str) -> None: async def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.MERGE, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.MERGE, record_id=thing, data=data) + self.id = message.id response = await self._send(message, "merge") self.check_response_for_result(response, "merge") return response["result"] @@ -238,15 +241,15 @@ async def merge( async def patch( self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.PATCH, collection=thing, params=data - ) + message = RequestMessage(RequestMethod.PATCH, collection=thing, params=data) + self.id = message.id response = await self._send(message, "patch") self.check_response_for_result(response, "patch") return response["result"] async def select(self, thing: str) -> Union[List[dict], dict]: - message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) + message = RequestMessage(RequestMethod.SELECT, params=[thing]) + self.id = message.id response = await self._send(message, "select") self.check_response_for_result(response, "select") return response["result"] @@ -254,15 +257,15 @@ async def select(self, thing: str) -> Union[List[dict], dict]: async def update( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.UPDATE, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.UPDATE, record_id=thing, data=data) + self.id = message.id response = await self._send(message, "update") self.check_response_for_result(response, "update") return response["result"] async def version(self) -> str: - message = RequestMessage(self.id, RequestMethod.VERSION) + message = RequestMessage(RequestMethod.VERSION) + self.id = message.id response = await self._send(message, "getting database version") self.check_response_for_result(response, "getting database version") return response["result"] @@ -270,9 +273,8 @@ async def version(self) -> str: async def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.UPSERT, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.UPSERT, record_id=thing, data=data) + self.id = message.id response = await self._send(message, "upsert") self.check_response_for_result(response, "upsert") return response["result"] diff --git a/src/surrealdb/connections/async_template.py b/src/surrealdb/connections/async_template.py index 70e6053c..340b809d 100644 --- a/src/surrealdb/connections/async_template.py +++ b/src/surrealdb/connections/async_template.py @@ -7,7 +7,7 @@ class AsyncTemplate: - async def connect(self, url: str) -> Coroutine[Any, Any, None]: + async def connect(self, url: str) -> None: """Connects to a local or remote database endpoint. Args: @@ -18,17 +18,17 @@ async def connect(self, url: str) -> Coroutine[Any, Any, None]: # Connect to a remote endpoint await db.connect('https://cloud.surrealdb.com/rpc'); """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"connect not implemented for: {self}") - async def close(self) -> Coroutine[Any, Any, None]: + async def close(self) -> None: """Closes the persistent connection to the database. Example: await db.close() """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"close not implemented for: {self}") - async def use(self, namespace: str, database: str) -> Coroutine[Any, Any, None]: + async def use(self, namespace: str, database: str) -> None: """Switch to a specific namespace and database. Args: @@ -38,9 +38,9 @@ async def use(self, namespace: str, database: str) -> Coroutine[Any, Any, None]: Example: await db.use('test', 'test') """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"use not implemented for: {self}") - async def authenticate(self, token: str) -> Coroutine[Any, Any, None]: + async def authenticate(self, token: str) -> None: """Authenticate the current connection with a JWT token. Args: @@ -51,7 +51,7 @@ async def authenticate(self, token: str) -> Coroutine[Any, Any, None]: """ raise NotImplementedError(f"authenticate not implemented for: {self}") - async def invalidate(self) -> Coroutine[Any, Any, None]: + async def invalidate(self) -> None: """Invalidate the authentication for the current connection. Example: @@ -59,7 +59,7 @@ async def invalidate(self) -> Coroutine[Any, Any, None]: """ raise NotImplementedError(f"invalidate not implemented for: {self}") - async def signup(self, vars: Dict) -> Coroutine[Any, Any, str]: + async def signup(self, vars: Dict) -> str: """Sign this connection up to a specific authentication scope. [See the docs](https://surrealdb.com/docs/sdk/python/methods/signup) @@ -81,7 +81,7 @@ async def signup(self, vars: Dict) -> Coroutine[Any, Any, str]: """ raise NotImplementedError(f"signup not implemented for: {self}") - async def signin(self, vars: Dict) -> Coroutine[Any, Any, str]: + async def signin(self, vars: Dict) -> str: """Sign this connection in to a specific authentication scope. [See the docs](https://surrealdb.com/docs/sdk/python/methods/signin) @@ -94,9 +94,9 @@ async def signin(self, vars: Dict) -> Coroutine[Any, Any, str]: password: 'surrealdb', }) """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"signin not implemented for: {self}") - async def let(self, key: str, value: Any) -> Coroutine[Any, Any, None]: + async def let(self, key: str, value: Any) -> None: """Assign a value as a variable for this connection. Args: @@ -115,7 +115,7 @@ async def let(self, key: str, value: Any) -> Coroutine[Any, Any, None]: """ raise NotImplementedError(f"let not implemented for: {self}") - async def unset(self, key: str) -> Coroutine[Any, Any, None]: + async def unset(self, key: str) -> None: """Removes a variable for this connection. Args: @@ -124,11 +124,11 @@ async def unset(self, key: str) -> Coroutine[Any, Any, None]: Example: await db.unset('name') """ - raise NotImplementedError(f"let not implemented for: {self}") + raise NotImplementedError(f"unset not implemented for: {self}") async def query( self, query: str, vars: Optional[Dict] = None - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """Run a unset of SurrealQL statements against the database. Args: @@ -145,7 +145,7 @@ async def query( async def select( self, thing: Union[str, RecordID, Table] - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """Select all records in a table (or other entity), or a specific record, in the database. @@ -158,13 +158,13 @@ async def select( Example: db.select('person') """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"select not implemented for: {self}") async def create( self, thing: Union[str, RecordID, Table], data: Optional[Union[List[dict], dict]] = None, - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """Create a record in the database. This function will run the following query in the database: @@ -181,7 +181,7 @@ async def create( async def update( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """Update all records in a table, or a specific record, in the database. This function replaces the current document / record data with the @@ -207,11 +207,11 @@ async def update( }, }) """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"update not implemented for: {self}") async def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """Insert records into the database, or to update them if they exist. @@ -239,7 +239,7 @@ async def upsert( async def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """Modify by deep merging all records in a table, or a specific record, in the database. This function merges the current document / record data with the @@ -267,11 +267,11 @@ async def merge( }) """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"merge not implemented for: {self}") async def patch( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """Apply JSON Patch changes to all records, or a specific record, in the database. This function patches the current document / record data with @@ -296,11 +296,11 @@ async def patch( { 'op': "remove", "path": "/temp" }, ]) """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"patch not implemented for: {self}") async def delete( self, thing: Union[str, RecordID, Table] - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """Delete all records in a table, or a specific record, from the database. This function will run the following query in the database: @@ -324,11 +324,11 @@ async def info(self) -> Coroutine[Any, Any, dict]: Example: await db.info() """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"info not implemented for: {self}") async def insert( self, table: Union[str, Table], data: Union[List[dict], dict] - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """ Inserts one or multiple records in the database. @@ -343,11 +343,11 @@ async def insert( await db.insert('person', [{ name: 'Tobie'}, { name: 'Jaime'}]) """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"insert not implemented for: {self}") async def insert_relation( self, table: Union[str, Table], data: Union[List[dict], dict] - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """ Inserts one or multiple relations in the database. @@ -362,11 +362,9 @@ async def insert_relation( await db.insert_relation('likes', { in: person:1, id: 'object', out: person:2}) """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"insert_relation not implemented for: {self}") - async def live( - self, table: Union[str, Table], diff: bool = False - ) -> Coroutine[Any, Any, UUID]: + async def live(self, table: Union[str, Table], diff: bool = False) -> UUID: """Initiates a live query for a specified table name. Args: @@ -381,11 +379,9 @@ async def live( Example: await db.live('person') """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"live not implemented for: {self}") - async def subscribe_live( - self, query_uuid: Union[str, UUID] - ) -> Coroutine[Any, Any, Queue]: + async def subscribe_live(self, query_uuid: Union[str, UUID]) -> Queue: """Returns a queue that receives notification messages from a running live query. Args: @@ -397,9 +393,9 @@ async def subscribe_live( Example: await db.subscribe_live(UUID) """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"subscribe_live not implemented for: {self}") - async def kill(self, query_uuid: Union[str, UUID]) -> Coroutine[Any, Any, None]: + async def kill(self, query_uuid: Union[str, UUID]) -> None: """Kills a running live query by it's UUID. Args: @@ -409,4 +405,4 @@ async def kill(self, query_uuid: Union[str, UUID]) -> Coroutine[Any, Any, None]: await db.kill(UUID) """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"kill not implemented for: {self}") diff --git a/src/surrealdb/connections/async_ws.py b/src/surrealdb/connections/async_ws.py index 4cfe4e0b..bf362278 100644 --- a/src/surrealdb/connections/async_ws.py +++ b/src/surrealdb/connections/async_ws.py @@ -3,9 +3,8 @@ """ import asyncio -import uuid -from asyncio import Queue -from typing import Optional, Any, Dict, Union, List, AsyncGenerator +from asyncio import Queue, Task, Future, AbstractEventLoop +from typing import Optional, Any, Dict, Union, List from uuid import UUID import websockets @@ -46,55 +45,87 @@ def __init__( self.raw_url: str = f"{self.url.raw_url}/rpc" self.host: Optional[str] = self.url.hostname self.port: Optional[int] = self.url.port - self.id: str = str(uuid.uuid4()) self.token: Optional[str] = None self.socket = None + self.loop: AbstractEventLoop | None = None + self.qry: dict[str, Future] = {} + self.recv_task: Task[None] | None = None + self.live_queues: dict[str, list] = {} + + async def _recv_task(self): + assert self.socket + async for data in self.socket: + response = decode(data) + if response_id := response.get("id"): + if fut := self.qry.get(response_id): + fut.set_result(response) + else: + live_id = str(response["result"]["id"]) + for queue in self.live_queues.get(live_id, []): + queue.put_nowait(response["result"]) async def _send( self, message: RequestMessage, process: str, bypass: bool = False ) -> dict: await self.connect() assert ( - self.socket is not None + self.socket is not None and self.loop is not None ) # will always not be None as the self.connect ensures there's a connection - await self.socket.send(message.WS_CBOR_DESCRIPTOR) - response = decode(await self.socket.recv()) + + # setup future to wait for response + fut = self.loop.create_future() + query_id = message.id + self.qry[query_id] = fut + try: + # correlate mesage to query, send and forget it + await self.socket.send(message.WS_CBOR_DESCRIPTOR) + del message + + # wait for response + response = await fut + finally: + del self.qry[query_id] + if bypass is False: self.check_response_for_error(response, process) return response async def connect(self, url: Optional[str] = None) -> None: + if self.socket: + return + # overwrite params if passed in if url is not None: self.url = Url(url) self.raw_url = f"{self.url.raw_url}/rpc" self.host = self.url.hostname self.port = self.url.port - if self.socket is None: - self.socket = await websockets.connect( - self.raw_url, - max_size=None, - subprotocols=[websockets.Subprotocol("cbor")], - ) + + self.socket = await websockets.connect( + self.raw_url, + max_size=None, + subprotocols=[websockets.Subprotocol("cbor")], + ) + self.loop = asyncio.get_running_loop() + self.recv_task = asyncio.create_task(self._recv_task()) async def authenticate(self, token: str) -> dict: - message = RequestMessage(self.id, RequestMethod.AUTHENTICATE, token=token) + message = RequestMessage(RequestMethod.AUTHENTICATE, token=token) return await self._send(message, "authenticating") async def invalidate(self) -> None: - message = RequestMessage(self.id, RequestMethod.INVALIDATE) + message = RequestMessage(RequestMethod.INVALIDATE) await self._send(message, "invalidating") self.token = None async def signup(self, vars: Dict) -> str: - message = RequestMessage(self.id, RequestMethod.SIGN_UP, data=vars) + message = RequestMessage(RequestMethod.SIGN_UP, data=vars) response = await self._send(message, "signup") self.check_response_for_result(response, "signup") return response["result"] async def signin(self, vars: Dict[str, Any]) -> str: message = RequestMessage( - self.id, RequestMethod.SIGN_IN, username=vars.get("username"), password=vars.get("password"), @@ -109,14 +140,13 @@ async def signin(self, vars: Dict[str, Any]) -> str: return response["result"] async def info(self) -> Optional[dict]: - message = RequestMessage(self.id, RequestMethod.INFO) + message = RequestMessage(RequestMethod.INFO) outcome = await self._send(message, "getting database information") self.check_response_for_result(outcome, "getting database information") return outcome["result"] async def use(self, namespace: str, database: str) -> None: message = RequestMessage( - self.id, RequestMethod.USE, namespace=namespace, database=database, @@ -127,7 +157,6 @@ async def query(self, query: str, params: Optional[dict] = None) -> dict: if params is None: params = {} message = RequestMessage( - self.id, RequestMethod.QUERY, query=query, params=params, @@ -140,7 +169,6 @@ async def query_raw(self, query: str, params: Optional[dict] = None) -> dict: if params is None: params = {} message = RequestMessage( - self.id, RequestMethod.QUERY, query=query, params=params, @@ -149,23 +177,23 @@ async def query_raw(self, query: str, params: Optional[dict] = None) -> dict: return response async def version(self) -> str: - message = RequestMessage(self.id, RequestMethod.VERSION) + message = RequestMessage(RequestMethod.VERSION) response = await self._send(message, "getting database version") self.check_response_for_result(response, "getting database version") return response["result"] async def let(self, key: str, value: Any) -> None: - message = RequestMessage(self.id, RequestMethod.LET, key=key, value=value) + message = RequestMessage(RequestMethod.LET, key=key, value=value) await self._send(message, "letting") async def unset(self, key: str) -> None: - message = RequestMessage(self.id, RequestMethod.UNSET, params=[key]) + message = RequestMessage(RequestMethod.UNSET, params=[key]) await self._send(message, "unsetting") async def select( self, thing: Union[str, RecordID, Table] ) -> Union[List[dict], dict]: - message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) + message = RequestMessage(RequestMethod.SELECT, params=[thing]) response = await self._send(message, "select") self.check_response_for_result(response, "select") return response["result"] @@ -179,9 +207,7 @@ async def create( if ":" in thing: buffer = thing.split(":") thing = RecordID(table_name=buffer[0], identifier=buffer[1]) - message = RequestMessage( - self.id, RequestMethod.CREATE, collection=thing, data=data - ) + message = RequestMessage(RequestMethod.CREATE, collection=thing, data=data) response = await self._send(message, "create") self.check_response_for_result(response, "create") return response["result"] @@ -189,9 +215,7 @@ async def create( async def update( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.UPDATE, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.UPDATE, record_id=thing, data=data) response = await self._send(message, "update") self.check_response_for_result(response, "update") return response["result"] @@ -199,9 +223,7 @@ async def update( async def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.MERGE, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.MERGE, record_id=thing, data=data) response = await self._send(message, "merge") self.check_response_for_result(response, "merge") return response["result"] @@ -209,9 +231,7 @@ async def merge( async def patch( self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.PATCH, collection=thing, params=data - ) + message = RequestMessage(RequestMethod.PATCH, collection=thing, params=data) response = await self._send(message, "patch") self.check_response_for_result(response, "patch") return response["result"] @@ -219,7 +239,7 @@ async def patch( async def delete( self, thing: Union[str, RecordID, Table] ) -> Union[List[dict], dict]: - message = RequestMessage(self.id, RequestMethod.DELETE, record_id=thing) + message = RequestMessage(RequestMethod.DELETE, record_id=thing) response = await self._send(message, "delete") self.check_response_for_result(response, "delete") return response["result"] @@ -227,9 +247,7 @@ async def delete( async def insert( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.INSERT, collection=table, params=data - ) + message = RequestMessage(RequestMethod.INSERT, collection=table, params=data) response = await self._send(message, "insert") self.check_response_for_result(response, "insert") return response["result"] @@ -238,7 +256,7 @@ async def insert_relation( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.INSERT_RELATION, table=table, params=data + RequestMethod.INSERT_RELATION, table=table, params=data ) response = await self._send(message, "insert_relation") self.check_response_for_result(response, "insert_relation") @@ -246,65 +264,58 @@ async def insert_relation( async def live(self, table: Union[str, Table], diff: bool = False) -> UUID: message = RequestMessage( - self.id, RequestMethod.LIVE, table=table, ) response = await self._send(message, "live") self.check_response_for_result(response, "live") - return response["result"] + uuid = response["result"] + assert uuid not in self.live_queues + self.live_queues[str(uuid)] = [] + return uuid - async def subscribe_live( - self, query_uuid: Union[str, UUID] - ) -> AsyncGenerator[dict, None]: + def subscribe_live(self, query_uuid: Union[str, UUID]) -> Queue: result_queue = Queue() + suid = str(query_uuid) + self.live_queues[suid].append(result_queue) - async def listen_live(): - """ - Listen for live updates from the WebSocket and put them into the queue. - """ - try: - while True: - response = decode(await self.socket.recv()) - if response.get("result", {}).get("id") == query_uuid: - await result_queue.put(response["result"]["result"]) - except Exception as e: - print("Error in live subscription:", e) - await result_queue.put({"error": str(e)}) - - asyncio.create_task(listen_live()) - - while True: - result = await result_queue.get() - if "error" in result: - raise Exception(f"Error in live subscription: {result['error']}") - yield result + async def _iter(): + while True: + ret = await result_queue.get() + yield ret["result"] + + return _iter() async def kill(self, query_uuid: Union[str, UUID]) -> None: - message = RequestMessage(self.id, RequestMethod.KILL, uuid=query_uuid) + message = RequestMessage(RequestMethod.KILL, uuid=query_uuid) await self._send(message, "kill") + self.live_queues.pop(str(query_uuid), None) async def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.UPSERT, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.UPSERT, record_id=thing, data=data) response = await self._send(message, "upsert") self.check_response_for_result(response, "upsert") return response["result"] async def close(self): - await self.socket.close() + if self.recv_task: + self.recv_task.cancel() + try: + await self.recv_task + except asyncio.CancelledError: + pass + + if self.socket is not None: + await self.socket.close() async def __aenter__(self) -> "AsyncWsSurrealConnection": """ Asynchronous context manager entry. Initializes a websocket connection and returns the connection instance. """ - self.socket = await websockets.connect( - self.raw_url, max_size=None, subprotocols=[websockets.Subprotocol("cbor")] - ) + await self.connect() return self async def __aexit__(self, exc_type, exc_value, traceback) -> None: @@ -312,5 +323,4 @@ async def __aexit__(self, exc_type, exc_value, traceback) -> None: Asynchronous context manager exit. Closes the websocket connection upon exiting the context. """ - if self.socket is not None: - await self.socket.close() + await self.close() diff --git a/src/surrealdb/connections/blocking_http.py b/src/surrealdb/connections/blocking_http.py index 01380b48..f7cbc356 100644 --- a/src/surrealdb/connections/blocking_http.py +++ b/src/surrealdb/connections/blocking_http.py @@ -57,16 +57,19 @@ def set_token(self, token: str) -> None: self.token = token def authenticate(self, token: str) -> dict: - message = RequestMessage(self.id, RequestMethod.AUTHENTICATE, token=token) + message = RequestMessage(RequestMethod.AUTHENTICATE, token=token) + self.id = message.id return self._send(message, "authenticating") def invalidate(self) -> None: - message = RequestMessage(self.id, RequestMethod.INVALIDATE) + message = RequestMessage(RequestMethod.INVALIDATE) + self.id = message.id self._send(message, "invalidating") self.token = None def signup(self, vars: Dict) -> str: - message = RequestMessage(self.id, RequestMethod.SIGN_UP, data=vars) + message = RequestMessage(RequestMethod.SIGN_UP, data=vars) + self.id = message.id response = self._send(message, "signup") self.check_response_for_result(response, "signup") self.token = response["result"] @@ -74,7 +77,6 @@ def signup(self, vars: Dict) -> str: def signin(self, vars: dict) -> str: message = RequestMessage( - self.id, RequestMethod.SIGN_IN, username=vars.get("username"), password=vars.get("password"), @@ -83,24 +85,26 @@ def signin(self, vars: dict) -> str: namespace=vars.get("namespace"), variables=vars.get("variables"), ) + self.id = message.id response = self._send(message, "signing in") self.check_response_for_result(response, "signing in") self.token = response["result"] return str(response["result"]) def info(self): - message = RequestMessage(self.id, RequestMethod.INFO) + message = RequestMessage(RequestMethod.INFO) + self.id = message.id response = self._send(message, "getting database information") self.check_response_for_result(response, "getting database information") return response["result"] def use(self, namespace: str, database: str) -> None: message = RequestMessage( - self.token, RequestMethod.USE, namespace=namespace, database=database, ) + self.id = message.id _ = self._send(message, "use") self.namespace = namespace self.database = database @@ -111,11 +115,11 @@ def query(self, query: str, params: Optional[dict] = None) -> dict: for key, value in self.vars.items(): params[key] = value message = RequestMessage( - self.id, RequestMethod.QUERY, query=query, params=params, ) + self.id = message.id response = self._send(message, "query") self.check_response_for_result(response, "query") return response["result"][0]["result"] @@ -126,11 +130,11 @@ def query_raw(self, query: str, params: Optional[dict] = None) -> dict: for key, value in self.vars.items(): params[key] = value message = RequestMessage( - self.id, RequestMethod.QUERY, query=query, params=params, ) + self.id = message.id response = self._send(message, "query", bypass=True) return response @@ -143,15 +147,15 @@ def create( if ":" in thing: buffer = thing.split(":") thing = RecordID(table_name=buffer[0], identifier=buffer[1]) - message = RequestMessage( - self.id, RequestMethod.CREATE, collection=thing, data=data - ) + message = RequestMessage(RequestMethod.CREATE, collection=thing, data=data) + self.id = message.id response = self._send(message, "create") self.check_response_for_result(response, "create") return response["result"] def delete(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: - message = RequestMessage(self.id, RequestMethod.DELETE, record_id=thing) + message = RequestMessage(RequestMethod.DELETE, record_id=thing) + self.id = message.id response = self._send(message, "delete") self.check_response_for_result(response, "delete") return response["result"] @@ -159,9 +163,8 @@ def delete(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: def insert( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.INSERT, collection=table, params=data - ) + message = RequestMessage(RequestMethod.INSERT, collection=table, params=data) + self.id = message.id response = self._send(message, "insert") self.check_response_for_result(response, "insert") return response["result"] @@ -170,8 +173,9 @@ def insert_relation( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.INSERT_RELATION, table=table, params=data + RequestMethod.INSERT_RELATION, table=table, params=data ) + self.id = message.id response = self._send(message, "insert_relation") self.check_response_for_result(response, "insert_relation") return response["result"] @@ -185,9 +189,8 @@ def unset(self, key: str) -> None: def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.MERGE, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.MERGE, record_id=thing, data=data) + self.id = message.id response = self._send(message, "merge") self.check_response_for_result(response, "merge") return response["result"] @@ -195,15 +198,15 @@ def merge( def patch( self, thing: Union[str, RecordID, Table], data: Optional[Dict[Any, Any]] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.PATCH, collection=thing, params=data - ) + message = RequestMessage(RequestMethod.PATCH, collection=thing, params=data) + self.id = message.id response = self._send(message, "patch") self.check_response_for_result(response, "patch") return response["result"] def select(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: - message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) + message = RequestMessage(RequestMethod.SELECT, params=[thing]) + self.id = message.id response = self._send(message, "select") self.check_response_for_result(response, "select") return response["result"] @@ -211,15 +214,15 @@ def select(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: def update( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.UPDATE, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.UPDATE, record_id=thing, data=data) + self.id = message.id response = self._send(message, "update") self.check_response_for_result(response, "update") return response["result"] def version(self) -> str: - message = RequestMessage(self.id, RequestMethod.VERSION) + message = RequestMessage(RequestMethod.VERSION) + self.id = message.id response = self._send(message, "getting database version") self.check_response_for_result(response, "getting database version") return response["result"] @@ -227,9 +230,8 @@ def version(self) -> str: def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.UPSERT, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.UPSERT, record_id=thing, data=data) + self.id = message.id response = self._send(message, "upsert") self.check_response_for_result(response, "upsert") return response["result"] diff --git a/src/surrealdb/connections/blocking_ws.py b/src/surrealdb/connections/blocking_ws.py index 97d1c99f..0ba4548d 100644 --- a/src/surrealdb/connections/blocking_ws.py +++ b/src/surrealdb/connections/blocking_ws.py @@ -62,23 +62,25 @@ def _send( return response def authenticate(self, token: str) -> dict: - message = RequestMessage(self.id, RequestMethod.AUTHENTICATE, token=token) + message = RequestMessage(RequestMethod.AUTHENTICATE, token=token) + self.id = message.id return self._send(message, "authenticating") def invalidate(self) -> None: - message = RequestMessage(self.id, RequestMethod.INVALIDATE) + message = RequestMessage(RequestMethod.INVALIDATE) + self.id = message.id self._send(message, "invalidating") self.token = None def signup(self, vars: Dict) -> str: - message = RequestMessage(self.id, RequestMethod.SIGN_UP, data=vars) + message = RequestMessage(RequestMethod.SIGN_UP, data=vars) + self.id = message.id response = self._send(message, "signup") self.check_response_for_result(response, "signup") return response["result"] def signin(self, vars: Dict[str, Any]) -> str: message = RequestMessage( - self.id, RequestMethod.SIGN_IN, username=vars.get("username"), password=vars.get("password"), @@ -87,35 +89,37 @@ def signin(self, vars: Dict[str, Any]) -> str: namespace=vars.get("namespace"), variables=vars.get("variables"), ) + self.id = message.id response = self._send(message, "signing in") self.check_response_for_result(response, "signing in") self.token = response["result"] return response["result"] def info(self) -> dict: - message = RequestMessage(self.id, RequestMethod.INFO) + message = RequestMessage(RequestMethod.INFO) + self.id = message.id response = self._send(message, "getting database information") self.check_response_for_result(response, "getting database information") return response["result"] def use(self, namespace: str, database: str) -> None: message = RequestMessage( - self.id, RequestMethod.USE, namespace=namespace, database=database, ) + self.id = message.id self._send(message, "use") def query(self, query: str, params: Optional[dict] = None) -> dict: if params is None: params = {} message = RequestMessage( - self.id, RequestMethod.QUERY, query=query, params=params, ) + self.id = message.id response = self._send(message, "query") self.check_response_for_result(response, "query") return response["result"][0]["result"] @@ -124,30 +128,34 @@ def query_raw(self, query: str, params: Optional[dict] = None) -> dict: if params is None: params = {} message = RequestMessage( - self.id, RequestMethod.QUERY, query=query, params=params, ) + self.id = message.id response = self._send(message, "query", bypass=True) return response def version(self) -> str: - message = RequestMessage(self.id, RequestMethod.VERSION) + message = RequestMessage(RequestMethod.VERSION) + self.id = message.id response = self._send(message, "getting database version") self.check_response_for_result(response, "getting database version") return response["result"] def let(self, key: str, value: Any) -> None: - message = RequestMessage(self.id, RequestMethod.LET, key=key, value=value) + message = RequestMessage(RequestMethod.LET, key=key, value=value) + self.id = message.id self._send(message, "letting") def unset(self, key: str) -> None: - message = RequestMessage(self.id, RequestMethod.UNSET, params=[key]) + message = RequestMessage(RequestMethod.UNSET, params=[key]) + self.id = message.id self._send(message, "unsetting") def select(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: - message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) + message = RequestMessage(RequestMethod.SELECT, params=[thing]) + self.id = message.id response = self._send(message, "select") self.check_response_for_result(response, "select") return response["result"] @@ -161,29 +169,30 @@ def create( if ":" in thing: buffer = thing.split(":") thing = RecordID(table_name=buffer[0], identifier=buffer[1]) - message = RequestMessage( - self.id, RequestMethod.CREATE, collection=thing, data=data - ) + message = RequestMessage(RequestMethod.CREATE, collection=thing, data=data) + self.id = message.id response = self._send(message, "create") self.check_response_for_result(response, "create") return response["result"] def live(self, table: Union[str, Table], diff: bool = False) -> UUID: message = RequestMessage( - self.id, RequestMethod.LIVE, table=table, ) + self.id = message.id response = self._send(message, "live") self.check_response_for_result(response, "live") return response["result"] def kill(self, query_uuid: Union[str, UUID]) -> None: - message = RequestMessage(self.id, RequestMethod.KILL, uuid=query_uuid) + message = RequestMessage(RequestMethod.KILL, uuid=query_uuid) + self.id = message.id self._send(message, "kill") def delete(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: - message = RequestMessage(self.id, RequestMethod.DELETE, record_id=thing) + message = RequestMessage(RequestMethod.DELETE, record_id=thing) + self.id = message.id response = self._send(message, "delete") self.check_response_for_result(response, "delete") return response["result"] @@ -191,9 +200,8 @@ def delete(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: def insert( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.INSERT, collection=table, params=data - ) + message = RequestMessage(RequestMethod.INSERT, collection=table, params=data) + self.id = message.id response = self._send(message, "insert") self.check_response_for_result(response, "insert") return response["result"] @@ -202,8 +210,9 @@ def insert_relation( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.INSERT_RELATION, table=table, params=data + RequestMethod.INSERT_RELATION, table=table, params=data ) + self.id = message.id response = self._send(message, "insert_relation") self.check_response_for_result(response, "insert_relation") return response["result"] @@ -211,9 +220,8 @@ def insert_relation( def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.MERGE, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.MERGE, record_id=thing, data=data) + self.id = message.id response = self._send(message, "merge") self.check_response_for_result(response, "merge") return response["result"] @@ -221,9 +229,8 @@ def merge( def patch( self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.PATCH, collection=thing, params=data - ) + message = RequestMessage(RequestMethod.PATCH, collection=thing, params=data) + self.id = message.id response = self._send(message, "patch") self.check_response_for_result(response, "patch") return response["result"] @@ -260,9 +267,8 @@ def subscribe_live( def update( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.UPDATE, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.UPDATE, record_id=thing, data=data) + self.id = message.id response = self._send(message, "update") self.check_response_for_result(response, "update") return response["result"] @@ -270,9 +276,8 @@ def update( def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, RequestMethod.UPSERT, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.UPSERT, record_id=thing, data=data) + self.id = message.id response = self._send(message, "upsert") self.check_response_for_result(response, "upsert") return response["result"] diff --git a/src/surrealdb/request_message/message.py b/src/surrealdb/request_message/message.py index 509dbb93..59ce5aac 100644 --- a/src/surrealdb/request_message/message.py +++ b/src/surrealdb/request_message/message.py @@ -1,3 +1,5 @@ +import uuid + from surrealdb.request_message.descriptors.cbor_ws import WsCborDescriptor from surrealdb.request_message.methods import RequestMethod @@ -6,7 +8,7 @@ class RequestMessage: WS_CBOR_DESCRIPTOR = WsCborDescriptor() - def __init__(self, id_for_request, method: RequestMethod, **kwargs) -> None: - self.id = id_for_request + def __init__(self, method: RequestMethod, **kwargs) -> None: + self.id = str(uuid.uuid4()) self.method = method self.kwargs = kwargs diff --git a/tests/unit_tests/connections/authenticate/test_async_ws.py b/tests/unit_tests/connections/authenticate/test_async_ws.py index 53707a45..2304191e 100644 --- a/tests/unit_tests/connections/authenticate/test_async_ws.py +++ b/tests/unit_tests/connections/authenticate/test_async_ws.py @@ -21,7 +21,6 @@ async def asyncSetUp(self): async def test_authenticate(self): outcome = await self.connection.authenticate(token=self.connection.token) - await self.connection.socket.close() diff --git a/tests/unit_tests/connections/batch_async/__init__.py b/tests/unit_tests/connections/batch_async/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/connections/batch_async/test_async_ws.py b/tests/unit_tests/connections/batch_async/test_async_ws.py new file mode 100644 index 00000000..6ed0e00a --- /dev/null +++ b/tests/unit_tests/connections/batch_async/test_async_ws.py @@ -0,0 +1,46 @@ +import asyncio +import os +import sys +from unittest import main, IsolatedAsyncioTestCase + +from surrealdb.connections.async_ws import AsyncWsSurrealConnection + + +class TestAsyncWsSurrealConnection(IsolatedAsyncioTestCase): + + async def asyncSetUp(self): + self.url = "ws://localhost:8000" + self.password = "root" + self.username = "root" + self.vars_params = { + "username": self.username, + "password": self.password, + } + self.database_name = "test_db" + self.namespace = "test_ns" + self.data = { + "username": self.username, + "password": self.password, + } + self.connection = AsyncWsSurrealConnection(self.url) + _ = await self.connection.signin(self.vars_params) + _ = await self.connection.use(namespace=self.namespace, database=self.database_name) + + async def test_batch(self): + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + # async batching doesn't work for surrealDB v2.1.0" or lower + if os.environ.get("SURREALDB_VERSION") == "v2.1.0": + pass + elif python_version == "3.9" or python_version == "3.10": + print("async batching is being bypassed due to python versions 3.9 and 3.10 not supporting async task group") + else: + async with asyncio.TaskGroup() as tg: + tasks = [tg.create_task(self.connection.query("RETURN sleep(duration::from::millis($d)) or $p**2", dict(d=10 if num%2 else 0, p=num))) for num in range(5)] + + outcome = [t.result() for t in tasks] + self.assertEqual([0, 1, 4, 9, 16], outcome) + await self.connection.close() + + +if __name__ == "__main__": + main() diff --git a/tests/unit_tests/connections/create/test_async_ws.py b/tests/unit_tests/connections/create/test_async_ws.py index 6db2b7fe..ff545baa 100644 --- a/tests/unit_tests/connections/create/test_async_ws.py +++ b/tests/unit_tests/connections/create/test_async_ws.py @@ -35,7 +35,6 @@ async def test_create_string(self): 1 ) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_create_string_with_data(self): outcome = await self.connection.create("user", self.data) @@ -53,7 +52,6 @@ async def test_create_string_with_data(self): self.assertEqual(self.username, outcome[0]["username"]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_create_string_with_data_and_id(self): first_outcome = await self.connection.create("user:tobie", self.data) @@ -73,7 +71,6 @@ async def test_create_string_with_data_and_id(self): self.assertEqual(self.username, outcome[0]["username"]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_create_record_id(self): record_id = RecordID("user",1) @@ -87,7 +84,6 @@ async def test_create_record_id(self): ) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_create_record_id_with_data(self): record_id = RecordID("user", 1) @@ -107,7 +103,6 @@ async def test_create_record_id_with_data(self): self.assertEqual(self.username, outcome[0]["username"]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_create_table(self): table = Table("user") @@ -120,7 +115,6 @@ async def test_create_table(self): ) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_create_table_with_data(self): table = Table("user") @@ -139,7 +133,6 @@ async def test_create_table_with_data(self): self.assertEqual(self.username, outcome[0]["username"]) await self.connection.query("DELETE user;") - await self.connection.socket.close() diff --git a/tests/unit_tests/connections/delete/test_async_ws.py b/tests/unit_tests/connections/delete/test_async_ws.py index 4a6f73f7..cc6ad40c 100644 --- a/tests/unit_tests/connections/delete/test_async_ws.py +++ b/tests/unit_tests/connections/delete/test_async_ws.py @@ -43,14 +43,12 @@ async def test_delete_string(self): self.check_no_change(outcome) outcome = await self.connection.query("SELECT * FROM user;") self.assertEqual(outcome, []) - await self.connection.socket.close() async def test_delete_record_id(self): first_outcome = await self.connection.delete(self.record_id) self.check_no_change(first_outcome) outcome = await self.connection.query("SELECT * FROM user;") self.assertEqual(outcome, []) - await self.connection.socket.close() async def test_delete_table(self): await self.connection.query("CREATE user:jaime SET name = 'Jaime';") @@ -59,7 +57,6 @@ async def test_delete_table(self): self.assertEqual(2, len(first_outcome)) outcome = await self.connection.query("SELECT * FROM user;") self.assertEqual(outcome, []) - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/info/test_async_ws.py b/tests/unit_tests/connections/info/test_async_ws.py index 7c77fe1e..85816de2 100644 --- a/tests/unit_tests/connections/info/test_async_ws.py +++ b/tests/unit_tests/connections/info/test_async_ws.py @@ -21,7 +21,6 @@ async def asyncSetUp(self): async def test_info(self): outcome = await self.connection.info() - await self.connection.socket.close() # TODO => confirm that the info is what we expect diff --git a/tests/unit_tests/connections/insert/test_async_ws.py b/tests/unit_tests/connections/insert/test_async_ws.py index 88f51dc8..9f6ee9d2 100644 --- a/tests/unit_tests/connections/insert/test_async_ws.py +++ b/tests/unit_tests/connections/insert/test_async_ws.py @@ -46,7 +46,6 @@ async def test_insert_string_with_data(self): 2 ) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_insert_record_id_result_error(self): record_id = RecordID("user","tobie") @@ -59,7 +58,6 @@ async def test_insert_record_id_result_error(self): True ) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/insert_relation/test_async_ws.py b/tests/unit_tests/connections/insert_relation/test_async_ws.py index 9ae31560..3f484575 100644 --- a/tests/unit_tests/connections/insert_relation/test_async_ws.py +++ b/tests/unit_tests/connections/insert_relation/test_async_ws.py @@ -76,7 +76,6 @@ async def test_insert_relation_record_ids(self): ) await self.connection.query("DELETE user;") await self.connection.query("DELETE likes;") - await self.connection.socket.close() async def test_insert_relation_record_id(self): data = { @@ -94,7 +93,6 @@ async def test_insert_relation_record_id(self): ) await self.connection.query("DELETE user;") await self.connection.query("DELETE likes;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/let/test_async_ws.py b/tests/unit_tests/connections/let/test_async_ws.py index a4faeda5..fb6a98f7 100644 --- a/tests/unit_tests/connections/let/test_async_ws.py +++ b/tests/unit_tests/connections/let/test_async_ws.py @@ -30,7 +30,6 @@ async def test_let(self): outcome = await self.connection.query('SELECT * FROM person WHERE name.first = $name.first') self.assertEqual({'first': 'Tobie', 'last': 'Morgan Hitchcock'}, outcome[0]["name"]) await self.connection.query("DELETE person;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/live/test_async_http.py b/tests/unit_tests/connections/live/test_async_http.py index 80b054e7..1253dea4 100644 --- a/tests/unit_tests/connections/live/test_async_http.py +++ b/tests/unit_tests/connections/live/test_async_http.py @@ -26,7 +26,6 @@ async def test_query(self): outcome = await self.connection.live("user") self.assertEqual(UUID, type(outcome)) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/live/test_async_ws.py b/tests/unit_tests/connections/live/test_async_ws.py index 80b054e7..1253dea4 100644 --- a/tests/unit_tests/connections/live/test_async_ws.py +++ b/tests/unit_tests/connections/live/test_async_ws.py @@ -26,7 +26,6 @@ async def test_query(self): outcome = await self.connection.live("user") self.assertEqual(UUID, type(outcome)) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/merge/test_async_ws.py b/tests/unit_tests/connections/merge/test_async_ws.py index 1f151842..92a5f09b 100644 --- a/tests/unit_tests/connections/merge/test_async_ws.py +++ b/tests/unit_tests/connections/merge/test_async_ws.py @@ -51,7 +51,6 @@ async def test_merge_string(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_no_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_merge_string_with_data(self): first_outcome = await self.connection.merge("user:tobie", self.data) @@ -59,7 +58,6 @@ async def test_merge_string_with_data(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_merge_record_id(self): first_outcome = await self.connection.merge(self.record_id) @@ -67,7 +65,6 @@ async def test_merge_record_id(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_no_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_merge_record_id_with_data(self): outcome = await self.connection.merge(self.record_id, self.data) @@ -77,7 +74,6 @@ async def test_merge_record_id_with_data(self): outcome[0] ) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_merge_table(self): table = Table("user") @@ -87,7 +83,6 @@ async def test_merge_table(self): self.check_no_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_merge_table_with_data(self): table = Table("user") @@ -96,7 +91,6 @@ async def test_merge_table_with_data(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/patch/test_async_ws.py b/tests/unit_tests/connections/patch/test_async_ws.py index f4c54327..d6c408bd 100644 --- a/tests/unit_tests/connections/patch/test_async_ws.py +++ b/tests/unit_tests/connections/patch/test_async_ws.py @@ -43,7 +43,6 @@ async def test_patch_string_with_data(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_patch_record_id_with_data(self): outcome = await self.connection.patch(self.record_id, self.data) @@ -51,7 +50,6 @@ async def test_patch_record_id_with_data(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_patch_table_with_data(self): table = Table("user") @@ -60,7 +58,6 @@ async def test_patch_table_with_data(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/query/test_async_http.py b/tests/unit_tests/connections/query/test_async_http.py index 444f6fb1..4044d34b 100644 --- a/tests/unit_tests/connections/query/test_async_http.py +++ b/tests/unit_tests/connections/query/test_async_http.py @@ -55,7 +55,6 @@ async def test_query(self): ] ) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/query/test_async_ws.py b/tests/unit_tests/connections/query/test_async_ws.py index 7ffcf483..541bb116 100644 --- a/tests/unit_tests/connections/query/test_async_ws.py +++ b/tests/unit_tests/connections/query/test_async_ws.py @@ -55,7 +55,6 @@ async def test_query(self): ] ) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/select/test_async_ws.py b/tests/unit_tests/connections/select/test_async_ws.py index 71328260..11f830db 100644 --- a/tests/unit_tests/connections/select/test_async_ws.py +++ b/tests/unit_tests/connections/select/test_async_ws.py @@ -42,7 +42,6 @@ async def test_select(self): await self.connection.query("DELETE user;") await self.connection.query("DELETE users;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/signin/test_async_ws.py b/tests/unit_tests/connections/signin/test_async_ws.py index 23245bdc..c49618c1 100644 --- a/tests/unit_tests/connections/signin/test_async_ws.py +++ b/tests/unit_tests/connections/signin/test_async_ws.py @@ -47,8 +47,6 @@ async def test_signin_root(self): self.assertIsNotNone(response) _ = await self.connection.query("DELETE user;") _ = await self.connection.query("REMOVE TABLE user;") - await self.connection.socket.close() - await connection.socket.close() async def test_signin_namespace(self): connection = AsyncWsSurrealConnection(self.url) @@ -61,8 +59,6 @@ async def test_signin_namespace(self): self.assertIsNotNone(response) _ = await self.connection.query("DELETE user;") _ = await self.connection.query("REMOVE TABLE user;") - await self.connection.socket.close() - await connection.socket.close() async def test_signin_database(self): connection = AsyncWsSurrealConnection(self.url) @@ -76,8 +72,6 @@ async def test_signin_database(self): self.assertIsNotNone(response) _ = await self.connection.query("DELETE user;") _ = await self.connection.query("REMOVE TABLE user;") - await self.connection.socket.close() - await connection.socket.close() async def test_signin_record(self): vars = { @@ -99,8 +93,6 @@ async def test_signin_record(self): await self.connection.query("DELETE user;") await self.connection.query("REMOVE TABLE user;") - await self.connection.socket.close() - await connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/subscribe_live/test_async_ws.py b/tests/unit_tests/connections/subscribe_live/test_async_ws.py index 3afdb882..2254382c 100644 --- a/tests/unit_tests/connections/subscribe_live/test_async_ws.py +++ b/tests/unit_tests/connections/subscribe_live/test_async_ws.py @@ -50,8 +50,6 @@ async def test_live_subscription(self): # Cleanup the subscription await self.pub_connection.query("DELETE user;") - await self.pub_connection.socket.close() - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/unset/test_async_ws.py b/tests/unit_tests/connections/unset/test_async_ws.py index bb4633d1..113c7ec6 100644 --- a/tests/unit_tests/connections/unset/test_async_ws.py +++ b/tests/unit_tests/connections/unset/test_async_ws.py @@ -38,7 +38,6 @@ async def test_unset(self): self.assertEqual([], outcome) await self.connection.query("DELETE person;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/update/test_async_ws.py b/tests/unit_tests/connections/update/test_async_ws.py index 0452d59b..3e8c2e93 100644 --- a/tests/unit_tests/connections/update/test_async_ws.py +++ b/tests/unit_tests/connections/update/test_async_ws.py @@ -51,7 +51,6 @@ async def test_update_string(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_no_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_update_string_with_data(self): first_outcome = await self.connection.update("user:tobie", self.data) @@ -59,7 +58,6 @@ async def test_update_string_with_data(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_update_record_id(self): first_outcome = await self.connection.update(self.record_id) @@ -67,7 +65,6 @@ async def test_update_record_id(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_no_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_update_record_id_with_data(self): outcome = await self.connection.update(self.record_id, self.data) @@ -77,7 +74,6 @@ async def test_update_record_id_with_data(self): outcome[0] ) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_update_table(self): table = Table("user") @@ -87,7 +83,6 @@ async def test_update_table(self): self.check_no_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_update_table_with_data(self): table = Table("user") @@ -96,7 +91,6 @@ async def test_update_table_with_data(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/upsert/test_async_ws.py b/tests/unit_tests/connections/upsert/test_async_ws.py index 5bcd949b..e237d29a 100644 --- a/tests/unit_tests/connections/upsert/test_async_ws.py +++ b/tests/unit_tests/connections/upsert/test_async_ws.py @@ -52,7 +52,6 @@ async def test_upsert_string(self): outcome = await self.connection.query("SELECT * FROM user;") # self.check_no_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_upsert_string_with_data(self): first_outcome = await self.connection.upsert("user:tobie", self.data) @@ -60,7 +59,6 @@ async def test_upsert_string_with_data(self): outcome = await self.connection.query("SELECT * FROM user;") # self.check_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_upsert_record_id(self): first_outcome = await self.connection.upsert(self.record_id) @@ -68,7 +66,6 @@ async def test_upsert_record_id(self): outcome = await self.connection.query("SELECT * FROM user;") # self.check_no_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_upsert_record_id_with_data(self): outcome = await self.connection.upsert(self.record_id, self.data) @@ -78,7 +75,6 @@ async def test_upsert_record_id_with_data(self): # outcome[0] # ) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_upsert_table(self): table = Table("user") @@ -89,7 +85,6 @@ async def test_upsert_table(self): # self.check_no_change(outcome[1], random_id=True) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_upsert_table_with_data(self): table = Table("user") @@ -99,7 +94,6 @@ async def test_upsert_table_with_data(self): self.assertEqual(2, len(outcome)) # self.check_change(outcome[0], random_id=True) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/version/test_async_ws.py b/tests/unit_tests/connections/version/test_async_ws.py index 5e602d60..fa29f2b8 100644 --- a/tests/unit_tests/connections/version/test_async_ws.py +++ b/tests/unit_tests/connections/version/test_async_ws.py @@ -21,7 +21,6 @@ async def asyncSetUp(self): async def test_version(self): self.assertEqual(str, type(await self.connection.version())) - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/data_types/test_datetimes.py b/tests/unit_tests/data_types/test_datetimes.py index 311a6f6a..a221710b 100644 --- a/tests/unit_tests/data_types/test_datetimes.py +++ b/tests/unit_tests/data_types/test_datetimes.py @@ -76,7 +76,6 @@ async def test_datetime_iso_format(self): # Cleanup await self.connection.query("DELETE datetime_tests;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/request_message/descriptors/test_cbor_ws.py b/tests/unit_tests/request_message/descriptors/test_cbor_ws.py index 2db08555..ed3dece6 100644 --- a/tests/unit_tests/request_message/descriptors/test_cbor_ws.py +++ b/tests/unit_tests/request_message/descriptors/test_cbor_ws.py @@ -7,19 +7,19 @@ class TestWsCborAdapter(TestCase): def test_use_pass(self): - message = RequestMessage(1, RequestMethod.USE, namespace="ns", database="db") + message = RequestMessage(RequestMethod.USE, namespace="ns", database="db") outcome = message.WS_CBOR_DESCRIPTOR self.assertIsInstance(outcome, bytes) def test_use_fail(self): - message = RequestMessage(1, RequestMethod.USE, namespace="ns", database=1) + message = RequestMessage(RequestMethod.USE, namespace="ns", database=1) with self.assertRaises(ValueError) as context: message.WS_CBOR_DESCRIPTOR self.assertEqual( "Invalid schema for Cbor WS encoding for use: {'params': [{1: ['must be of string type']}]}", str(context.exception) ) - message = RequestMessage(1, RequestMethod.USE, namespace="ns") + message = RequestMessage(RequestMethod.USE, namespace="ns") with self.assertRaises(ValueError) as context: message.WS_CBOR_DESCRIPTOR self.assertEqual( @@ -28,18 +28,17 @@ def test_use_fail(self): ) def test_info_pass(self): - message = RequestMessage(1, RequestMethod.INFO) + message = RequestMessage(RequestMethod.INFO) outcome = message.WS_CBOR_DESCRIPTOR self.assertIsInstance(outcome, bytes) def test_version_pass(self): - message = RequestMessage(1, RequestMethod.VERSION) + message = RequestMessage(RequestMethod.VERSION) outcome = message.WS_CBOR_DESCRIPTOR self.assertIsInstance(outcome, bytes) def test_signin_pass_root(self): message = RequestMessage( - 1, RequestMethod.SIGN_IN, username="user", password="pass" @@ -49,7 +48,6 @@ def test_signin_pass_root(self): def test_signin_pass_root_with_none(self): message = RequestMessage( - 1, RequestMethod.SIGN_IN, username="username", password="pass", @@ -62,7 +60,6 @@ def test_signin_pass_root_with_none(self): def test_signin_pass_account(self): message = RequestMessage( - 1, RequestMethod.SIGN_IN, username="username", password="pass", @@ -75,7 +72,6 @@ def test_signin_pass_account(self): def test_authenticate_pass(self): message = RequestMessage( - 1, RequestMethod.AUTHENTICATE, token="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJTdXJyZWFsREIiLCJpYXQiOjE1MTYyMzkwMjIsIm5iZiI6MTUxNjIzOTAyMiwiZXhwIjoxODM2NDM5MDIyLCJOUyI6InRlc3QiLCJEQiI6InRlc3QiLCJTQyI6InVzZXIiLCJJRCI6InVzZXI6dG9iaWUifQ.N22Gp9ze0rdR06McGj1G-h2vu6a6n9IVqUbMFJlOxxA" ) @@ -84,7 +80,6 @@ def test_authenticate_pass(self): def test_invalidate_pass(self): message = RequestMessage( - 1, RequestMethod.INVALIDATE ) outcome = message.WS_CBOR_DESCRIPTOR @@ -92,7 +87,6 @@ def test_invalidate_pass(self): def test_let_pass(self): message = RequestMessage( - 1, RequestMethod.LET, key="key", value="value" @@ -102,7 +96,6 @@ def test_let_pass(self): def test_unset_pass(self): message = RequestMessage( - 1, RequestMethod.UNSET, params=["one", "two", "three"] ) @@ -111,7 +104,6 @@ def test_unset_pass(self): def test_live_pass(self): message = RequestMessage( - 1, RequestMethod.LIVE, table="person" ) @@ -120,7 +112,6 @@ def test_live_pass(self): def test_kill_pass(self): message = RequestMessage( - 1, RequestMethod.KILL, uuid="0189d6e3-8eac-703a-9a48-d9faa78b44b9" ) @@ -129,7 +120,6 @@ def test_kill_pass(self): def test_query_pass(self): message = RequestMessage( - 1, RequestMethod.QUERY, query="query" ) @@ -138,7 +128,6 @@ def test_query_pass(self): def test_create_pass_params(self): message = RequestMessage( - 1, RequestMethod.CREATE, collection="person", data={"table": "table"} @@ -148,7 +137,6 @@ def test_create_pass_params(self): def test_insert_pass_dict(self): message = RequestMessage( - 1, RequestMethod.INSERT, collection="table", params={"key": "value"} @@ -158,7 +146,6 @@ def test_insert_pass_dict(self): def test_insert_pass_list(self): message = RequestMessage( - 1, RequestMethod.INSERT, collection="table", params=[{"key": "value"}, {"key": "value"}] @@ -168,7 +155,6 @@ def test_insert_pass_list(self): def test_patch_pass(self): message = RequestMessage( - 1, RequestMethod.PATCH, collection="table", params=[{"key": "value"}, {"key": "value"}] @@ -178,7 +164,6 @@ def test_patch_pass(self): def test_select_pass(self): message = RequestMessage( - 1, RequestMethod.SELECT, params=["table", "user"], ) @@ -187,7 +172,6 @@ def test_select_pass(self): def test_update_pass(self): message = RequestMessage( - 1, RequestMethod.UPDATE, record_id="test", data={"table": "table"} @@ -197,7 +181,6 @@ def test_update_pass(self): def test_upsert_pass(self): message = RequestMessage( - 1, RequestMethod.UPSERT, record_id="test", data={"table": "table"} @@ -207,7 +190,6 @@ def test_upsert_pass(self): def test_merge_pass(self): message = RequestMessage( - 1, RequestMethod.MERGE, record_id="test", data={"table": "table"} @@ -217,7 +199,6 @@ def test_merge_pass(self): def test_delete_pass(self): message = RequestMessage( - 1, RequestMethod.DELETE, record_id="test", ) diff --git a/tests/unit_tests/request_message/test_request_message.py b/tests/unit_tests/request_message/test_request_message.py index 90c87541..7b1afd0b 100644 --- a/tests/unit_tests/request_message/test_request_message.py +++ b/tests/unit_tests/request_message/test_request_message.py @@ -9,7 +9,7 @@ def setUp(self): self.method = RequestMethod.USE def test_init(self): - request_message = RequestMessage(1, self.method, one="two", three="four") + request_message = RequestMessage(self.method, one="two", three="four") self.assertEqual(request_message.method, self.method) self.assertEqual(request_message.kwargs, {"one": "two", "three": "four"})