diff --git a/CHANGELOG.md b/CHANGELOG.md index cef12ef841..a7b59d9a26 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,8 @@ These changes are available on the `master` branch, but have not yet been releas ([#2714](https://github.com/Pycord-Development/pycord/pull/2714)) - Added the ability to pass a `datetime.time` object to `format_dt`. ([#2747](https://github.com/Pycord-Development/pycord/pull/2747)) +- Added support for type hinting slash command options with `typing.Annotated`. + ([#2782](https://github.com/Pycord-Development/pycord/pull/2782)) - Added `discord.Interaction.created_at`. ([#2801](https://github.com/Pycord-Development/pycord/pull/2801)) diff --git a/discord/commands/core.py b/discord/commands/core.py index 06996bcaa1..095e20ce1e 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -73,9 +73,9 @@ from .options import Option, OptionChoice if sys.version_info >= (3, 11): - from typing import Annotated, get_args, get_origin + from typing import Annotated, Literal, get_args, get_origin else: - from typing_extensions import Annotated, get_args, get_origin + from typing_extensions import Annotated, Literal, get_args, get_origin __all__ = ( "_BaseCommand", @@ -806,6 +806,24 @@ def _parse_options(self, params, *, check_params: bool = True) -> list[Option]: if option == inspect.Parameter.empty: option = str + if self._is_typing_literal(option): + literal_values = get_args(option) + if not all(isinstance(v, (str, int, float)) for v in literal_values): + raise TypeError( + "Literal values must be str, int, or float for Discord choices." + ) + + value_type = type(literal_values[0]) + if not all(isinstance(v, value_type) for v in literal_values): + raise TypeError("All Literal values must be of the same type.") + + option = Option( + value_type, + choices=[ + OptionChoice(name=str(v), value=v) for v in literal_values + ], + ) + if self._is_typing_annotated(option): type_hint = get_args(option)[0] metadata = option.__metadata__ @@ -908,6 +926,9 @@ def _is_typing_union(self, annotation): def _is_typing_optional(self, annotation): return self._is_typing_union(annotation) and type(None) in annotation.__args__ # type: ignore + def _is_typing_literal(self, annotation): + return get_origin(annotation) is Literal + def _is_typing_annotated(self, annotation): return get_origin(annotation) is Annotated