Skip to content

Added task-creators. #418

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 3 additions & 104 deletions taskiq/abc/broker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import os
import sys
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import wraps
from logging import getLogger
from typing import (
TYPE_CHECKING,
Expand All @@ -18,7 +15,6 @@
Optional,
TypeVar,
Union,
overload,
)
from uuid import uuid4

Expand All @@ -35,7 +31,8 @@
from taskiq.result_backends.dummy import DummyResultBackend
from taskiq.serializers.json_serializer import JSONSerializer
from taskiq.state import TaskiqState
from taskiq.utils import maybe_awaitable, remove_suffix
from taskiq.task_creator import BaseTaskCreator
from taskiq.utils import maybe_awaitable
from taskiq.warnings import TaskiqDeprecationWarning

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -97,6 +94,7 @@ def __init__(
TaskiqDeprecationWarning,
stacklevel=2,
)
self.task: BaseTaskCreator = BaseTaskCreator(self)
self.middlewares: "List[TaskiqMiddleware]" = []
self.result_backend = result_backend
self.decorator_class = AsyncTaskiqDecoratedTask
Expand Down Expand Up @@ -255,105 +253,6 @@ def listen(self) -> AsyncGenerator[Union[bytes, AckableMessage], None]:
:return: nothing.
"""

@overload
def task(
self,
task_name: Callable[_FuncParams, _ReturnType],
**labels: Any,
) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: # pragma: no cover
...

@overload
def task(
self,
task_name: Optional[str] = None,
**labels: Any,
) -> Callable[
[Callable[_FuncParams, _ReturnType]],
AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType],
]: # pragma: no cover
...

def task( # type: ignore[misc]
self,
task_name: Optional[str] = None,
**labels: Any,
) -> Any:
"""
Decorator that turns function into a task.

This decorator converts function to
a `TaskiqDecoratedTask` object.

This object can be called as a usual function,
because it uses decorated function in it's __call__
method.

!! You have to use it with parentheses in order to
get autocompletion. Like this:

>>> @task()
>>> def my_func():
>>> ...

:param task_name: custom name of a task, defaults to decorated function's name.
:param labels: some addition labels for task.

:returns: decorator function or AsyncTaskiqDecoratedTask.
"""

def make_decorated_task(
inner_labels: Dict[str, Union[str, int]],
inner_task_name: Optional[str] = None,
) -> Callable[
[Callable[_FuncParams, _ReturnType]],
AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType],
]:
def inner(
func: Callable[_FuncParams, _ReturnType],
) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]:
nonlocal inner_task_name
if inner_task_name is None:
fmodule = func.__module__
if fmodule == "__main__": # pragma: no cover
fmodule = ".".join(
remove_suffix(sys.argv[0], ".py").split(
os.path.sep,
),
)
fname = func.__name__
if fname == "<lambda>":
fname = f"lambda_{uuid4().hex}"
inner_task_name = f"{fmodule}:{fname}"
wrapper = wraps(func)

decorated_task = wrapper(
self.decorator_class(
broker=self,
original_func=func,
labels=inner_labels,
task_name=inner_task_name,
),
)

self._register_task(decorated_task.task_name, decorated_task) # type: ignore

return decorated_task # type: ignore

return inner

if callable(task_name):
# This is an edge case,
# when decorator called without parameters.
return make_decorated_task(
inner_labels=labels or {},
)(task_name)

return make_decorated_task(
inner_task_name=task_name,
inner_labels=labels or {},
)

def register_task(
self,
func: Callable[_FuncParams, _ReturnType],
Expand Down
7 changes: 7 additions & 0 deletions taskiq/decor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import sys
from collections.abc import Coroutine
from datetime import datetime
Expand All @@ -8,13 +9,15 @@
Callable,
Dict,
Generic,
Optional,
TypeVar,
Union,
overload,
)

from typing_extensions import ParamSpec

from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.kicker import AsyncKicker
from taskiq.scheduler.created_schedule import CreatedSchedule
from taskiq.task import AsyncTaskiqTask
Expand Down Expand Up @@ -51,11 +54,15 @@ def __init__(
task_name: str,
original_func: Callable[_FuncParams, _ReturnType],
labels: Dict[str, Any],
extra_middlewares: Optional[list[TaskiqMiddleware]] = None,
) -> None:
self.broker = broker
self.task_name = task_name
self.original_func = original_func
self.labels = labels
self.middlewares = copy.copy(broker.middlewares)
if extra_middlewares:
self.middlewares.extend(extra_middlewares)

# This is a hack to make ProcessPoolExecutor work
# with decorated functions.
Expand Down
12 changes: 6 additions & 6 deletions taskiq/kicker.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ async def kiq(
logger.debug(
f"Kicking {self.task_name} with args={args} and kwargs={kwargs}.",
)
message = self._prepare_message(*args, **kwargs)
message = self.get_message(*args, **kwargs)
for middleware in self.broker.middlewares:
if middleware.__class__.pre_send != TaskiqMiddleware.pre_send:
message = await maybe_awaitable(middleware.pre_send(message))
Expand Down Expand Up @@ -191,7 +191,7 @@ async def schedule_by_cron(
schedule_id = self.custom_schedule_id
if schedule_id is None:
schedule_id = self.broker.id_generator()
message = self._prepare_message(*args, **kwargs)
message = self.get_message(*args, **kwargs)
cron_offset = None
if isinstance(cron, CronSpec):
cron_str = cron.to_cron()
Expand Down Expand Up @@ -228,7 +228,7 @@ async def schedule_by_time(
schedule_id = self.custom_schedule_id
if schedule_id is None:
schedule_id = self.broker.id_generator()
message = self._prepare_message(*args, **kwargs)
message = self.get_message(*args, **kwargs)
scheduled = ScheduledTask(
schedule_id=schedule_id,
task_name=message.task_name,
Expand Down Expand Up @@ -261,10 +261,10 @@ def _prepare_arg(cls, arg: Any) -> Any:
arg = asdict(arg)
return arg

def _prepare_message(
def get_message(
self,
*args: Any,
**kwargs: Any,
*args: _FuncParams.args,
**kwargs: _FuncParams.kwargs,
) -> TaskiqMessage:
"""
Create a message from args and kwargs.
Expand Down
30 changes: 15 additions & 15 deletions taskiq/receiver/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ async def callback( # noqa: C901, PLR0912
"Function for task %s is resolved. Executing...",
taskiq_msg.task_name,
)
for middleware in self.broker.middlewares:
for middleware in task.middlewares:
if middleware.__class__.pre_execute != TaskiqMiddleware.pre_execute:
taskiq_msg = await maybe_awaitable(
middleware.pre_execute(
Expand All @@ -150,21 +150,32 @@ async def callback( # noqa: C901, PLR0912
message=taskiq_msg,
)

if result.is_err is not None:
for middleware in task.middlewares:
if middleware.__class__.on_error != TaskiqMiddleware.on_error:
await maybe_awaitable(
middleware.on_error(
taskiq_msg,
result,
result.error, # type: ignore
),
)

if self.ack_time == AcknowledgeType.WHEN_EXECUTED and isinstance(
message,
AckableMessage,
):
await maybe_awaitable(message.ack())

for middleware in self.broker.middlewares:
for middleware in task.middlewares:
if middleware.__class__.post_execute != TaskiqMiddleware.post_execute:
await maybe_awaitable(middleware.post_execute(taskiq_msg, result))

try:
if not isinstance(result.error, NoResultError):
await self.broker.result_backend.set_result(taskiq_msg.task_id, result)

for middleware in self.broker.middlewares:
for middleware in task.middlewares:
if middleware.__class__.post_save != TaskiqMiddleware.post_save:
await maybe_awaitable(middleware.post_save(taskiq_msg, result))

Expand All @@ -183,7 +194,7 @@ async def callback( # noqa: C901, PLR0912
):
await maybe_awaitable(message.ack())

async def run_task( # noqa: C901, PLR0912, PLR0915
async def run_task( # noqa: C901
self,
target: Callable[..., Any],
message: TaskiqMessage,
Expand Down Expand Up @@ -304,17 +315,6 @@ async def run_task( # noqa: C901, PLR0912, PLR0915
error=found_exception,
labels=message.labels,
)
# If exception is found we execute middlewares.
if found_exception is not None:
for middleware in self.broker.middlewares:
if middleware.__class__.on_error != TaskiqMiddleware.on_error:
await maybe_awaitable(
middleware.on_error(
message,
result,
found_exception,
),
)

return result

Expand Down
Loading