Skip to content

Fix busy loop in server_receive #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 57 additions & 48 deletions asgi_testclient/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import asyncio
import inspect
import json as _json
from asyncio import Queue, ensure_future, sleep
from functools import partial
from http.cookies import SimpleCookie
from contextlib import asynccontextmanager
from http import HTTPStatus
from urllib.parse import urlsplit, urlencode
from wsgiref.headers import Headers as _Headers
Expand Down Expand Up @@ -38,7 +41,7 @@ def is_asgi2(app: Union[ASGI2App, ASGI3App]) -> bool:
if inspect.isclass(app):
return True

if hasattr(app, "__call__") and inspect.iscoroutinefunction(app.__call__): #type: ignore
if hasattr(app, "__call__") and inspect.iscoroutinefunction(app.__call__): # type: ignore
return False

return not inspect.iscoroutinefunction(app)
Expand Down Expand Up @@ -86,7 +89,7 @@ def ok(self) -> bool:

This attribute checks if the status code of the response is between
400 and 600 to see if there was a client error or a server error. If
the status code is between 200 and 400, this will return True.
the status code is between 200 and 400, this will return True.

This is **not** a check to see if the response code is ``200 OK``. """
try:
Expand Down Expand Up @@ -128,13 +131,12 @@ def json(self, **kwargs):


class WsSession:
def __init__(self, app: ASGI3App, scope: Scope) -> None:
self._client: Queue = Queue() # For ASGI app to send messages
self._server: Queue = Queue() # For client session to send message to ASGI app
def __init__(self) -> None:
self._client = asyncio.Queue() # For ASGI app to send messages
self._server = asyncio.Queue() # For client session to send message to ASGI app

self._server_task = ensure_future(
app(scope, self._server_receive, self._server_send)
)
async def serve(self, app: ASGI3App, scope):
await app(scope, self._server_receive, self._server_send)

async def _start(self) -> None:
""" Start conmunication between client and ASGI app. """
Expand Down Expand Up @@ -185,20 +187,6 @@ async def receive_json(self):
async def close(self):
""" Finish session with server, wait until handler is done. """
await self.send({"type": "websocket.disconnect", "code": 1000})
while not self._server_task.done():
await sleep(0.1)


class WsContextManager:
def __init__(self, ws_session):
self.ws_session = ws_session

async def __aenter__(self):
self.ws_session = await self.ws_session
return self.ws_session

async def __aexit__(self, *args):
await self.ws_session.close()


class TestClient:
Expand All @@ -223,6 +211,7 @@ def __init__(
app: Union[ASGI2App, ASGI3App],
raise_server_exceptions: bool = True,
base_url: str = "http://testserver",
cookies: dict[str, str] = None,
) -> None:

if is_asgi2(app):
Expand All @@ -233,6 +222,7 @@ def __init__(
self.app = cast(ASGI3App, app)
self.base_url = base_url
self.raise_server_exceptions = raise_server_exceptions
self.cookies = cookies

async def send(
self,
Expand All @@ -243,11 +233,9 @@ async def send(
headers: Headers = {},
json: dict = {},
subprotocols: Optional[List[str]] = None,
ws: bool = False,
) -> Union[Response, WsSession]:
) -> Response:
""" Handle request/response cycle seting up request, creating scope dict,
calling the app and awaiting in the handler to return the response. """
self.url = url
scheme, host, port, path, query = self.prepare_url(url, params=params)
req_headers: ReqHeaders = self.prepare_headers(host, headers)

Expand All @@ -263,23 +251,19 @@ async def send(
"server": [host, port],
}

if ws:
scope["type"] = "websocket"
scope["scheme"] = "ws"
scope["subprotocols"] = subprotocols or []
session = WsSession(self.app, scope)
await session._start()
return session

scope["type"] = "http"
self.prepare_body(req_headers, data=data, json=json)
try:
self.__response_started = False
self.__response_complete = False
await self.app(scope, self._receive, self._send)
await self.app(scope, self._receive, partial(self._send, url=url))
except Exception as ex:
if self.raise_server_exceptions:
raise ex from None
if cookie_header := self._response.headers.get('set-cookie'):
cookie = SimpleCookie()
cookie.load(cookie_header)
self.cookies.update({k: v.value for k, v in cookie.items()})
return self._response

def prepare_url(self, url: str, params: Params) -> Url:
Expand Down Expand Up @@ -329,6 +313,9 @@ def prepare_headers(self, host: str, headers: Headers = []) -> ReqHeaders:
_headers: list = [(b"host", host.encode())]
_headers += self.default_headers

if self.cookies:
_headers += [(b"cookie", ";".join(f"{key}={value}" for key, value in self.cookies.items()).encode())]

if headers:
if isinstance(headers, dict):
_headers += [
Expand Down Expand Up @@ -357,14 +344,14 @@ def prepare_body(
)
headers.append((b"content-length", str(len(self._body)).encode()))

async def _send(self, message: Message) -> None:
async def _send(self, message: Message, url: str) -> None:
""" Mimic ASGI send awaitable, create and set response object. """
if message["type"] == "http.response.start":
assert (
not self.__response_started
), 'Received multiple "http.response.start" messages.'
self._response = Response(
self.url,
url,
status_code=message["status"],
headers=[
(k.decode(), v.decode()) for k, v in message["headers"]
Expand All @@ -383,9 +370,6 @@ async def _send(self, message: Message) -> None:
self.__response_complete = True

async def _receive(self) -> Message:
""" Mimic ASGI receive awaitable.
TODO: Mimic Stream requests
"""
return {"type": "http.request", "body": self._body, "more_body": False}

async def get(self, url, **kwargs):
Expand All @@ -409,12 +393,37 @@ async def delete(self, url, **kwargs):
async def patch(self, url, **kwargs):
return await self.send("PATCH", url, **kwargs)

async def ws_connect(self, url, subprotocols=None, **kwargs):
return await self.send(
"GET", url, subprotocols=subprotocols, ws=True, **kwargs
)
@asynccontextmanager
async def ws_session(self, url, subprotocols=None, params=None, headers=None):
scheme, host, port, path, query = self.prepare_url(url, params=params)
req_headers: ReqHeaders = self.prepare_headers(host, headers)

def ws_session(self, url, subprotocols=None, **kwargs):
return WsContextManager(
self.send("GET", url, subprotocols=subprotocols, ws=True, **kwargs)
)
scope = {
"http_version": "1.1",
"method": "GET",
"path": path,
"root_path": "",
"scheme": scheme,
"query_string": query,
"headers": req_headers,
"client": ("testclient", 5000),
"server": [host, port],
"type": "websocket",
"subprotocols": subprotocols or [],
}
session = WsSession()
async with run_coro(session.serve(self.app, scope)):
await session._start()
try:
yield session
finally:
await session.close()


@asynccontextmanager
async def run_coro(coro):
task = asyncio.create_task(coro)
try:
yield
finally:
await task
Loading