diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 13edefa081..00435f626b 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -26,7 +26,7 @@ class Reducer: NAME = None - def __init__(self, *args: List[str]) -> None: + def __init__(self, *args: str) -> None: self._args = args self._field = None self._alias = None @@ -116,7 +116,7 @@ def __init__(self, query: str = "*") -> None: self._add_scores = False self._scorer = "TFIDF" - def load(self, *fields: List[str]) -> "AggregateRequest": + def load(self, *fields: str) -> "AggregateRequest": """ Indicate the fields to be returned in the response. These fields are returned in addition to any others implicitly specified. @@ -223,7 +223,7 @@ def limit(self, offset: int, num: int) -> "AggregateRequest": self._aggregateplan.extend(_limit.build_args()) return self - def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest": + def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest": """ Indicate how the results should be sorted. This can also be used for *top-N* style queries diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index bc48fa9aa8..80d9b35728 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -542,7 +542,7 @@ def explain_cli(self, query: Union[str, Query]): # noqa def aggregate( self, - query: Union[str, Query], + query: Union[AggregateRequest, Cursor], query_params: Dict[str, Union[str, int, float]] = None, ): """ @@ -573,7 +573,7 @@ def aggregate( ) def _get_aggregate_result( - self, raw: List, query: Union[str, Query, AggregateRequest], has_cursor: bool + self, raw: List, query: Union[AggregateRequest, Cursor], has_cursor: bool ): if has_cursor: if isinstance(query, Cursor): @@ -967,7 +967,7 @@ async def search( async def aggregate( self, - query: Union[str, Query], + query: Union[AggregateResult, Cursor], query_params: Dict[str, Union[str, int, float]] = None, ): """