diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index 438e9f86..98bba4e9 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -1,9 +1,6 @@ -import os -import sys import warnings from abc import ABC, abstractmethod from collections import defaultdict -from functools import wraps from logging import getLogger from typing import ( TYPE_CHECKING, @@ -18,7 +15,6 @@ Optional, TypeVar, Union, - overload, ) from uuid import uuid4 @@ -35,7 +31,8 @@ from taskiq.result_backends.dummy import DummyResultBackend from taskiq.serializers.json_serializer import JSONSerializer from taskiq.state import TaskiqState -from taskiq.utils import maybe_awaitable, remove_suffix +from taskiq.task_creator import BaseTaskCreator +from taskiq.utils import maybe_awaitable from taskiq.warnings import TaskiqDeprecationWarning if TYPE_CHECKING: # pragma: no cover @@ -97,6 +94,7 @@ def __init__( TaskiqDeprecationWarning, stacklevel=2, ) + self.task: BaseTaskCreator = BaseTaskCreator(self) self.middlewares: "List[TaskiqMiddleware]" = [] self.result_backend = result_backend self.decorator_class = AsyncTaskiqDecoratedTask @@ -255,105 +253,6 @@ def listen(self) -> AsyncGenerator[Union[bytes, AckableMessage], None]: :return: nothing. """ - @overload - def task( - self, - task_name: Callable[_FuncParams, _ReturnType], - **labels: Any, - ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: # pragma: no cover - ... - - @overload - def task( - self, - task_name: Optional[str] = None, - **labels: Any, - ) -> Callable[ - [Callable[_FuncParams, _ReturnType]], - AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType], - ]: # pragma: no cover - ... - - def task( # type: ignore[misc] - self, - task_name: Optional[str] = None, - **labels: Any, - ) -> Any: - """ - Decorator that turns function into a task. - - This decorator converts function to - a `TaskiqDecoratedTask` object. - - This object can be called as a usual function, - because it uses decorated function in it's __call__ - method. - - !! You have to use it with parentheses in order to - get autocompletion. Like this: - - >>> @task() - >>> def my_func(): - >>> ... - - :param task_name: custom name of a task, defaults to decorated function's name. - :param labels: some addition labels for task. - - :returns: decorator function or AsyncTaskiqDecoratedTask. - """ - - def make_decorated_task( - inner_labels: Dict[str, Union[str, int]], - inner_task_name: Optional[str] = None, - ) -> Callable[ - [Callable[_FuncParams, _ReturnType]], - AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType], - ]: - def inner( - func: Callable[_FuncParams, _ReturnType], - ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: - nonlocal inner_task_name - if inner_task_name is None: - fmodule = func.__module__ - if fmodule == "__main__": # pragma: no cover - fmodule = ".".join( - remove_suffix(sys.argv[0], ".py").split( - os.path.sep, - ), - ) - fname = func.__name__ - if fname == "": - fname = f"lambda_{uuid4().hex}" - inner_task_name = f"{fmodule}:{fname}" - wrapper = wraps(func) - - decorated_task = wrapper( - self.decorator_class( - broker=self, - original_func=func, - labels=inner_labels, - task_name=inner_task_name, - ), - ) - - self._register_task(decorated_task.task_name, decorated_task) # type: ignore - - return decorated_task # type: ignore - - return inner - - if callable(task_name): - # This is an edge case, - # when decorator called without parameters. - return make_decorated_task( - inner_labels=labels or {}, - )(task_name) - - return make_decorated_task( - inner_task_name=task_name, - inner_labels=labels or {}, - ) - def register_task( self, func: Callable[_FuncParams, _ReturnType], diff --git a/taskiq/decor.py b/taskiq/decor.py index 4c28c184..eae73c92 100644 --- a/taskiq/decor.py +++ b/taskiq/decor.py @@ -1,3 +1,4 @@ +import copy import sys from collections.abc import Coroutine from datetime import datetime @@ -8,6 +9,7 @@ Callable, Dict, Generic, + Optional, TypeVar, Union, overload, @@ -15,6 +17,7 @@ from typing_extensions import ParamSpec +from taskiq.abc.middleware import TaskiqMiddleware from taskiq.kicker import AsyncKicker from taskiq.scheduler.created_schedule import CreatedSchedule from taskiq.task import AsyncTaskiqTask @@ -51,11 +54,15 @@ def __init__( task_name: str, original_func: Callable[_FuncParams, _ReturnType], labels: Dict[str, Any], + extra_middlewares: Optional[list[TaskiqMiddleware]] = None, ) -> None: self.broker = broker self.task_name = task_name self.original_func = original_func self.labels = labels + self.middlewares = copy.copy(broker.middlewares) + if extra_middlewares: + self.middlewares.extend(extra_middlewares) # This is a hack to make ProcessPoolExecutor work # with decorated functions. diff --git a/taskiq/kicker.py b/taskiq/kicker.py index 45a55d73..0d44c1ab 100644 --- a/taskiq/kicker.py +++ b/taskiq/kicker.py @@ -153,7 +153,7 @@ async def kiq( logger.debug( f"Kicking {self.task_name} with args={args} and kwargs={kwargs}.", ) - message = self._prepare_message(*args, **kwargs) + message = self.get_message(*args, **kwargs) for middleware in self.broker.middlewares: if middleware.__class__.pre_send != TaskiqMiddleware.pre_send: message = await maybe_awaitable(middleware.pre_send(message)) @@ -191,7 +191,7 @@ async def schedule_by_cron( schedule_id = self.custom_schedule_id if schedule_id is None: schedule_id = self.broker.id_generator() - message = self._prepare_message(*args, **kwargs) + message = self.get_message(*args, **kwargs) cron_offset = None if isinstance(cron, CronSpec): cron_str = cron.to_cron() @@ -228,7 +228,7 @@ async def schedule_by_time( schedule_id = self.custom_schedule_id if schedule_id is None: schedule_id = self.broker.id_generator() - message = self._prepare_message(*args, **kwargs) + message = self.get_message(*args, **kwargs) scheduled = ScheduledTask( schedule_id=schedule_id, task_name=message.task_name, @@ -261,10 +261,10 @@ def _prepare_arg(cls, arg: Any) -> Any: arg = asdict(arg) return arg - def _prepare_message( + def get_message( self, - *args: Any, - **kwargs: Any, + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, ) -> TaskiqMessage: """ Create a message from args and kwargs. diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index c15fb935..aecd0117 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -125,7 +125,7 @@ async def callback( # noqa: C901, PLR0912 "Function for task %s is resolved. Executing...", taskiq_msg.task_name, ) - for middleware in self.broker.middlewares: + for middleware in task.middlewares: if middleware.__class__.pre_execute != TaskiqMiddleware.pre_execute: taskiq_msg = await maybe_awaitable( middleware.pre_execute( @@ -150,13 +150,24 @@ async def callback( # noqa: C901, PLR0912 message=taskiq_msg, ) + if result.is_err is not None: + for middleware in task.middlewares: + if middleware.__class__.on_error != TaskiqMiddleware.on_error: + await maybe_awaitable( + middleware.on_error( + taskiq_msg, + result, + result.error, # type: ignore + ), + ) + if self.ack_time == AcknowledgeType.WHEN_EXECUTED and isinstance( message, AckableMessage, ): await maybe_awaitable(message.ack()) - for middleware in self.broker.middlewares: + for middleware in task.middlewares: if middleware.__class__.post_execute != TaskiqMiddleware.post_execute: await maybe_awaitable(middleware.post_execute(taskiq_msg, result)) @@ -164,7 +175,7 @@ async def callback( # noqa: C901, PLR0912 if not isinstance(result.error, NoResultError): await self.broker.result_backend.set_result(taskiq_msg.task_id, result) - for middleware in self.broker.middlewares: + for middleware in task.middlewares: if middleware.__class__.post_save != TaskiqMiddleware.post_save: await maybe_awaitable(middleware.post_save(taskiq_msg, result)) @@ -183,7 +194,7 @@ async def callback( # noqa: C901, PLR0912 ): await maybe_awaitable(message.ack()) - async def run_task( # noqa: C901, PLR0912, PLR0915 + async def run_task( # noqa: C901 self, target: Callable[..., Any], message: TaskiqMessage, @@ -304,17 +315,6 @@ async def run_task( # noqa: C901, PLR0912, PLR0915 error=found_exception, labels=message.labels, ) - # If exception is found we execute middlewares. - if found_exception is not None: - for middleware in self.broker.middlewares: - if middleware.__class__.on_error != TaskiqMiddleware.on_error: - await maybe_awaitable( - middleware.on_error( - message, - result, - found_exception, - ), - ) return result diff --git a/taskiq/task_creator.py b/taskiq/task_creator.py new file mode 100644 index 00000000..edc407d3 --- /dev/null +++ b/taskiq/task_creator.py @@ -0,0 +1,182 @@ +import os +import sys +import warnings +from functools import wraps +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Optional, + TypeVar, + overload, +) +from uuid import uuid4 + +from typing_extensions import ParamSpec, Self + +from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.decor import AsyncTaskiqDecoratedTask +from taskiq.utils import remove_suffix + +if TYPE_CHECKING: + from taskiq.abc.broker import AsyncBroker + + +_FuncParams = ParamSpec("_FuncParams") +_ReturnType = TypeVar("_ReturnType") + + +class BaseTaskCreator: + """ + Base class for task creator. + + Instances of this class are used to make tasks out of the given functions. + """ + + def __init__(self, broker: "AsyncBroker") -> None: + self._broker = broker + self._task_name: Optional[str] = None + self._labels: Dict[str, Any] = {} + self._middlewares: list[TaskiqMiddleware] = [] + + def name(self, name: str) -> Self: + """Assign custom name to the task.""" + self._task_name = name + return self + + def labels(self, **labels: Any) -> Self: + """Assign custom labels to the task.""" + self._labels = labels + return self + + def middlewares(self, *middlewares: TaskiqMiddleware) -> Self: + """Assign custom middlewares to the task.""" + self._middlewares = list(middlewares) + return self + + def make_task( + self, + task_name: str, + func: Callable[_FuncParams, _ReturnType], + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: + """Make a task from the given function.""" + return AsyncTaskiqDecoratedTask( + broker=self._broker, + original_func=func, + labels=self._labels, + task_name=task_name, + extra_middlewares=self._middlewares, + ) + + def __resolve_name(self, func: Callable[..., Any]) -> str: + """Resolve name of the function.""" + fmodule = func.__module__ + if fmodule == "__main__": # pragma: no cover + fmodule = ".".join( + remove_suffix(sys.argv[0], ".py").split(os.path.sep), + ) + fname = func.__name__ + if fname == "": + fname = f"lambda_{uuid4().hex}" + return f"{fmodule}:{fname}" + + @overload + def __call__( + self, + task_name: Callable[_FuncParams, _ReturnType], + **labels: Any, + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: # pragma: no cover + ... + + @overload + def __call__( + self, + task_name: Optional[str] = None, + **labels: Any, + ) -> Callable[ + [Callable[_FuncParams, _ReturnType]], + AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType], + ]: # pragma: no cover + ... + + def __call__( # type: ignore[misc] + self, + task_name: Optional[str] = None, + **labels: Any, + ) -> Any: + """ + Decorator that turns function into a task. + + This decorator converts function to + a `TaskiqDecoratedTask` object. + + This object can be called as a usual function, + because it uses decorated function in it's __call__ + method. + + !! You have to use it with parentheses in order to + get autocompletion. Like this: + + >>> @task() + >>> def my_func(): + >>> ... + + :param task_name: custom name of a task, defaults to decorated function's name. + :param labels: some addition labels for task. + + :returns: decorator function or AsyncTaskiqDecoratedTask. + """ + if task_name is not None and isinstance(task_name, str): + warnings.warn( + "Using task_name is deprecated, @broker.task.name('name') instead", + DeprecationWarning, + stacklevel=2, + ) + self._task_name = task_name + if labels: + warnings.warn( + "Using labels is deprecated, @broker.task.labels(**labels) instead", + DeprecationWarning, + stacklevel=2, + ) + self._labels.update(labels) + + def make_decorated_task() -> Callable[ + [Callable[_FuncParams, _ReturnType]], + AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType], + ]: + def inner( + func: Callable[_FuncParams, _ReturnType], + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: + inner_task_name = self._task_name + if inner_task_name is None: + inner_task_name = self.__resolve_name(func) + + wrapper = wraps(func) + decorated_task = wrapper( + self.make_task( + task_name=inner_task_name, + func=func, + ), + ) + + # We need these ignored lines because + # `wrap` function copies __annotations__, + # therefore mypy thinks that decorated_task + # is still an instance of the original function. + self._broker._register_task( # noqa: SLF001 + decorated_task.task_name, # type: ignore + decorated_task, # type: ignore + ) + + return decorated_task # type: ignore + + return inner + + if callable(task_name): + # This is an edge case, + # when decorator called without parameters. + return make_decorated_task()(task_name) + + return make_decorated_task() diff --git a/tests/abc/test_broker.py b/tests/abc/test_broker.py index 2896cd66..57cfad52 100644 --- a/tests/abc/test_broker.py +++ b/tests/abc/test_broker.py @@ -40,7 +40,7 @@ def test_decorator_with_name_success() -> None: """Test that task_name is successfully set.""" tbrok = _TestBroker() - @tbrok.task(task_name="my_task") + @tbrok.task.name("my_task") async def test_func() -> None: """Some test function.""" @@ -52,7 +52,7 @@ def test_decorator_with_labels_success() -> None: """Tests that labels are assigned for task as is.""" tbrok = _TestBroker() - @tbrok.task(label1=1, label2=2) + @tbrok.task.labels(label1=1, label2=2) async def test_func() -> None: """Some test function.""" diff --git a/tests/api/test_scheduler.py b/tests/api/test_scheduler.py index ce09bab6..4380f75a 100644 --- a/tests/api/test_scheduler.py +++ b/tests/api/test_scheduler.py @@ -1,6 +1,6 @@ import asyncio import contextlib -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import pytest @@ -16,7 +16,9 @@ async def test_successful() -> None: scheduler = TaskiqScheduler(broker, sources=[LabelScheduleSource(broker)]) scheduler_task = asyncio.create_task(run_scheduler_task(scheduler)) - @broker.task(schedule=[{"time": datetime.utcnow() - timedelta(seconds=1)}]) + @broker.task.labels( + schedule=[{"time": datetime.now(timezone.utc) - timedelta(seconds=1)}], + ) def _() -> None: ... @@ -31,7 +33,7 @@ async def test_cancelation() -> None: broker = AsyncQueueBroker() scheduler = TaskiqScheduler(broker, sources=[LabelScheduleSource(broker)]) - @broker.task(schedule=[{"time": datetime.utcnow()}]) + @broker.task.labels(schedule=[{"time": datetime.now(timezone.utc)}]) def _() -> None: ... diff --git a/tests/middlewares/test_task_retry.py b/tests/middlewares/test_task_retry.py index 7544ab66..61ffb407 100644 --- a/tests/middlewares/test_task_retry.py +++ b/tests/middlewares/test_task_retry.py @@ -14,7 +14,7 @@ async def test_wait_result() -> None: ) runs = 0 - @broker.task(retry_on_error=True) + @broker.task.labels(retry_on_error=True) async def run_task() -> str: nonlocal runs @@ -40,7 +40,7 @@ async def test_wait_result_error() -> None: runs = 0 lock = asyncio.Lock() - @broker.task(retry_on_error=True) + @broker.task.labels(retry_on_error=True) async def run_task() -> str: nonlocal runs, lock @@ -74,7 +74,7 @@ async def test_wait_result_no_result() -> None: runs = 0 lock = asyncio.Lock() - @broker.task(retry_on_error=True) + @broker.task.labels(retry_on_error=True) async def run_task() -> str: nonlocal runs, done, lock @@ -111,7 +111,7 @@ async def test_max_retries() -> None: ) runs = 0 - @broker.task(max_retries=10) + @broker.task.labels(max_retries=10) def run_task() -> str: nonlocal runs @@ -137,7 +137,7 @@ async def test_no_retry() -> None: ) runs = 0 - @broker.task(retry_on_error=False, max_retries=10) + @broker.task.labels(retry_on_error=False, max_retries=10) def run_task() -> str: nonlocal runs diff --git a/tests/receiver/test_receiver.py b/tests/receiver/test_receiver.py index 57637e95..95d536c7 100644 --- a/tests/receiver/test_receiver.py +++ b/tests/receiver/test_receiver.py @@ -2,18 +2,16 @@ import random import time from concurrent.futures import ThreadPoolExecutor -from typing import Any, ClassVar, List, Optional +from typing import Optional import pytest from taskiq_dependencies import Depends from taskiq.abc.broker import AckableMessage, AsyncBroker -from taskiq.abc.middleware import TaskiqMiddleware from taskiq.brokers.inmemory_broker import InMemoryBroker from taskiq.exceptions import NoResultError, TaskiqResultTimeoutError from taskiq.message import TaskiqMessage from taskiq.receiver import Receiver -from taskiq.result import TaskiqResult from tests.utils import AsyncQueueBroker @@ -152,43 +150,6 @@ def test_func() -> None: assert result.is_err -@pytest.mark.anyio -async def test_run_task_exception_middlewares() -> None: - """Tests that run_task can run sync tasks.""" - - class _TestMiddleware(TaskiqMiddleware): - found_exceptions: ClassVar[List[BaseException]] = [] - - def on_error( - self, - message: "TaskiqMessage", - result: "TaskiqResult[Any]", - exception: BaseException, - ) -> None: - self.found_exceptions.append(exception) - - def test_func() -> None: - raise ValueError - - broker = InMemoryBroker().with_middlewares(_TestMiddleware()) - receiver = get_receiver(broker) - - result = await receiver.run_task( - test_func, - TaskiqMessage( - task_id="", - task_name="", - labels={}, - args=[], - kwargs={}, - ), - ) - assert result.return_value is None - assert result.is_err - assert len(_TestMiddleware.found_exceptions) == 1 - assert _TestMiddleware.found_exceptions[0].__class__ is ValueError - - @pytest.mark.anyio async def test_callback_success() -> None: """Test that callback function works well.""" diff --git a/tests/receiver/test_receiver_middlewares.py b/tests/receiver/test_receiver_middlewares.py new file mode 100644 index 00000000..f75dc4f8 --- /dev/null +++ b/tests/receiver/test_receiver_middlewares.py @@ -0,0 +1,83 @@ +from typing import Any, ClassVar, List + +import pytest + +from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.brokers.inmemory_broker import InMemoryBroker +from taskiq.message import TaskiqMessage +from taskiq.result import TaskiqResult + + +@pytest.mark.anyio +async def test_run_middleware_on_error() -> None: + """Tests that run_task can run sync tasks.""" + + class _TestMiddleware(TaskiqMiddleware): + found_exceptions: ClassVar[List[BaseException]] = [] + + def on_error( + self, + message: "TaskiqMessage", + result: "TaskiqResult[Any]", + exception: BaseException, + ) -> None: + self.found_exceptions.append(exception) + + broker = InMemoryBroker().with_middlewares(_TestMiddleware()) + + @broker.task + def test_func() -> None: + raise ValueError + + task = await test_func.kiq() + + result = await task.wait_result() + assert result.return_value is None + assert result.is_err + assert len(_TestMiddleware.found_exceptions) == 1 + assert _TestMiddleware.found_exceptions[0].__class__ is ValueError + + +@pytest.mark.anyio +async def test_run_middleware_on_success() -> None: + """Tests that run_task can run sync tasks.""" + ran_pre_send = False + ran_post_send = False + ran_pre_execute = False + ran_post_execute = False + + class _TestMiddleware(TaskiqMiddleware): + async def pre_send(self, message: "TaskiqMessage") -> TaskiqMessage: + nonlocal ran_pre_send + ran_pre_send = True + return message + + async def post_send(self, message: "TaskiqMessage") -> None: + nonlocal ran_post_send + ran_post_send = True + + async def pre_execute(self, message: "TaskiqMessage") -> TaskiqMessage: + nonlocal ran_pre_execute + ran_pre_execute = True + return message + + async def post_execute( + self, + message: "TaskiqMessage", + result: "TaskiqResult[Any]", + ) -> None: + nonlocal ran_post_execute + ran_post_execute = True + + broker = InMemoryBroker().with_middlewares(_TestMiddleware()) + + @broker.task + def test_func() -> None: + return None + + task = await test_func.kiq() + + result = await task.wait_result() + assert result.return_value is None + assert not result.is_err + assert ran_pre_send and ran_post_send and ran_pre_execute and ran_post_execute diff --git a/tests/schedule_sources/test_label_based.py b/tests/schedule_sources/test_label_based.py index 9e683917..82c3da12 100644 --- a/tests/schedule_sources/test_label_based.py +++ b/tests/schedule_sources/test_label_based.py @@ -19,10 +19,7 @@ async def test_label_discovery(schedule_label: List[Dict[str, Any]]) -> None: broker = InMemoryBroker() - @broker.task( - task_name="test_task", - schedule=schedule_label, - ) + @broker.task.name("test_task").labels(schedule=schedule_label) def task() -> None: pass @@ -45,8 +42,7 @@ def task() -> None: async def test_label_discovery_no_cron() -> None: broker = InMemoryBroker() - @broker.task( - task_name="test_task", + @broker.task.name("test_task").labels( schedule=[{"args": ["* * * * *"]}], ) def task() -> None: diff --git a/tests/scheduler/test_label_based_sched.py b/tests/scheduler/test_label_based_sched.py index 156e8498..0201f006 100644 --- a/tests/scheduler/test_label_based_sched.py +++ b/tests/scheduler/test_label_based_sched.py @@ -1,5 +1,5 @@ import asyncio -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Any, Dict, List import pytest @@ -18,16 +18,13 @@ "schedule_label", [ pytest.param([{"cron": "* * * * *"}], id="cron"), - pytest.param([{"time": datetime.utcnow()}], id="time"), + pytest.param([{"time": datetime.now(timezone.utc)}], id="time"), ], ) async def test_label_discovery(schedule_label: List[Dict[str, Any]]) -> None: broker = InMemoryBroker() - @broker.task( - task_name="test_task", - schedule=schedule_label, - ) + @broker.task.name("test_task").labels(schedule=schedule_label) def task() -> None: pass @@ -49,8 +46,7 @@ def task() -> None: async def test_label_discovery_no_cron() -> None: broker = InMemoryBroker() - @broker.task( - task_name="test_task", + @broker.task.name("test_task").labels( schedule=[{"args": ["* * * * *"]}], ) def task() -> None: @@ -74,11 +70,10 @@ async def test_task_scheduled_at_time_runs_only_once(mock_sleep: None) -> None: # freeze time to 00:00, so task won't be scheduled by `cron`, only by `time` with freeze_time("00:00:00", tick=True): - @broker.task( - task_name="test_task", + @broker.task.name("test_task").labels( schedule=[ - {"time": datetime.utcnow(), "args": [1]}, - {"time": datetime.utcnow() + timedelta(days=1), "args": [2]}, + {"time": datetime.now(timezone.utc), "args": [1]}, + {"time": datetime.now(timezone.utc) + timedelta(days=1), "args": [2]}, {"cron": "1 * * * *", "args": [3]}, ], )