diff --git a/src/django_twc_toolbox/decorators.py b/src/django_twc_toolbox/decorators.py new file mode 100644 index 0000000..1117973 --- /dev/null +++ b/src/django_twc_toolbox/decorators.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import logging +from collections.abc import Callable +from functools import wraps +from typing import Any +from typing import TypeVar +from urllib.parse import urlparse + +import httpx +from django.conf import settings + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +class req: + """ + A decorator that sends a HTTP request to the specified endpoint after the + decorated function has been called. This is useful for health checks on + periodic tasks or dispatching webhooks. + + Args: + url: The URL of the endpoint to send the request to. + + Returns: + The decorated function. + """ + + def __init__(self, url: str, *, client: httpx.Client | None = None) -> None: + self.url = url + self.client = client or httpx.Client() + + def __call__(self, func: Callable[..., T]) -> Callable[..., T]: + parsed_url = urlparse(self.url) + if not parsed_url.scheme or not parsed_url.netloc: + msg = f"Invalid URL: {self.url}" + logger.error(msg) + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> T: + # Execute the decorated function + result = func(*args, **kwargs) + + # Send the HTTP request if not in DEBUG mode + if not settings.DEBUG: + self.client.get(self.url) + + return result + + return wrapper diff --git a/tests/test_decorators.py b/tests/test_decorators.py new file mode 100644 index 0000000..ac4364e --- /dev/null +++ b/tests/test_decorators.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from unittest import mock + +import httpx +import pytest + +from django_twc_toolbox.decorators import req + + +def test_req(monkeypatch): + monkeypatch.setattr("django.conf.settings.DEBUG", False) + url = "https://example.com/test-hc" + + @req(url) + def foo(): + pass + + with mock.patch.object(httpx, "get") as mock_httpx_get: + foo() + + mock_httpx_get.assert_called_once_with(url, timeout=5) + + +@pytest.mark.parametrize("url", ["", "invalid"]) +def test_req_invalid_url(url, caplog): + with caplog.at_level("ERROR"): + + @req(url) + def foo(): + pass + + assert "Invalid URL" in caplog.text