Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite options #1251

Merged
merged 12 commits into from
May 5, 2022
69 changes: 34 additions & 35 deletions discord/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,45 +655,49 @@ def __init__(self, func: Callable, *args, **kwargs) -> None:
if self.permissions and self.default_permission:
self.default_permission = False

def _parse_options(self, params) -> List[Option]:
if list(params.items())[0][0] == "self":
temp = list(params.items())
temp.pop(0)
params = dict(temp)
def _check_required_params(self, params):
params = iter(params.items())
required_params = ["self", "context"] if self.attached_to_group or self.cog else ["context"]
for p in required_params:
try:
next(params)
except StopIteration:
raise ClientException(f'Callback for {self.name} command is missing "{p}" parameter.')

# next we have the 'ctx' as the next parameter
try:
next(params)
except StopIteration:
raise ClientException(f'Callback for {self.name} command is missing "ctx" parameter.')
return params

def _parse_options(self, params, *, check_params: bool = True) -> List[Option]:
if check_params:
params = self._check_required_params(params)

final_options = []
for p_name, p_obj in params:

option = p_obj.annotation
if option == inspect.Parameter.empty:
option = str

if self._is_typing_union(option):
if self._is_typing_optional(option):
option = Option(option.__args__[0], "No description provided", required=False)
option = Option(option.__args__[0], "No description provided", required=False) # type: ignore # union type
else:
option = Option(option.__args__, "No description provided")
option = Option(option.__args__, "No description provided") # type: ignore # union type
krittick marked this conversation as resolved.
Show resolved Hide resolved

if not isinstance(option, Option):
option = Option(option, "No description provided")
if isinstance(p_obj.default, Option): # arg: type = Option(...)
p_obj.default.input_type = SlashCommandOptionType.from_datatype(option)
option = p_obj.default
else: # arg: Option(...) = default
option = Option(option, "No description provided")

if option.default is None:
if p_obj.default == inspect.Parameter.empty:
option.default = None
else:
if not p_obj.default == inspect.Parameter.empty and not isinstance(p_obj.default, Option):
option.default = p_obj.default
option.required = False

if option.name is None:
option.name = p_name
option._parameter_name = p_name
if option.name != p_name or option._parameter_name is None:
option._parameter_name = p_name

_validate_names(option)
_validate_descriptions(option)
Expand All @@ -703,25 +707,15 @@ def _parse_options(self, params) -> List[Option]:
return final_options

def _match_option_param_names(self, params, options):
if list(params.items())[0][0] == "self":
temp = list(params.items())
temp.pop(0)
params = dict(temp)
params = iter(params.items())
params = self._check_required_params(params)

# next we have the 'ctx' as the next parameter
try:
next(params)
except StopIteration:
raise ClientException(f'Callback for {self.name} command is missing "ctx" parameter.')

check_annotations = [
check_annotations: List[Callable[[Option, Type], bool]] = [
lambda o, a: o.input_type == SlashCommandOptionType.string
and o.converter is not None, # pass on converters
lambda o, a: isinstance(o.input_type, SlashCommandOptionType), # pass on slash cmd option type enums
lambda o, a: isinstance(o._raw_type, tuple) and a == Union[o._raw_type], # type: ignore # union types
lambda o, a: self._is_typing_optional(a) and not o.required and o._raw_type in a.__args__, # optional
lambda o, a: inspect.isclass(a) and issubclass(a, o._raw_type), # 'normal' types
lambda o, a: isinstance(a, type) and issubclass(a, o._raw_type), # 'normal' types
]
for o in options:
_validate_names(o)
Expand All @@ -732,15 +726,14 @@ def _match_option_param_names(self, params, options):
raise ClientException(f"Too many arguments passed to the options kwarg.")
p_obj = p_obj.annotation

if not any(c(o, p_obj) for c in check_annotations):
if not any(check(o, p_obj) for check in check_annotations):
raise TypeError(f"Parameter {p_name} does not match input type of {o.name}.")
o._parameter_name = p_name

left_out_params = OrderedDict()
left_out_params[""] = "" # bypass first iter (ctx)
for k, v in params:
left_out_params[k] = v
options.extend(self._parse_options(left_out_params))
options.extend(self._parse_options(left_out_params, check_params=False))

return options

Expand All @@ -752,6 +745,12 @@ def _is_typing_union(self, annotation):
def _is_typing_optional(self, annotation):
return self._is_typing_union(annotation) and type(None) in annotation.__args__ # type: ignore

def _set_cog(self, cog):
prev = self.cog
super()._set_cog(cog)
if (prev is None and cog is not None) or (prev is not None and cog is None):
self.options = self._parse_options(self._get_signature_parameters()) # parse again to leave out self

@property
def is_subcommand(self) -> bool:
return self.parent is not None
Expand Down Expand Up @@ -1162,7 +1161,7 @@ def _update_copy(self, kwargs: Dict[str, Any]):
return self.copy()

def _set_cog(self, cog):
self.cog = cog
super()._set_cog(cog)
for subcommand in self.subcommands:
subcommand._set_cog(cog)

Expand Down
14 changes: 9 additions & 5 deletions discord/commands/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ async def hello(
----------
input_type: :class:`Any`
The type of input that is expected for this option.
description: :class:`str`
The description of this option.
Must be 100 characters or fewer.
name: :class:`str`
The name of this option visible in the UI.
Inherits from the variable name if not provided as a parameter.
description: Optional[:class:`str`]
The description of this option.
Must be 100 characters or fewer.
choices: Optional[List[Union[:class:`Any`, :class:`OptionChoice`]]]
The list of available choices for this option.
Can be a list of values or :class:`OptionChoice` objects (which represent a name:value pair).
Expand Down Expand Up @@ -115,10 +115,11 @@ async def hello(
See `here <https://discord.com/developers/docs/reference#locales>`_ for a list of valid locales.
"""

def __init__(self, input_type: Any, /, description: str = None, **kwargs) -> None:
def __init__(self, input_type: Any = str, /, description: Optional[str] = None, **kwargs) -> None:
self.name: Optional[str] = kwargs.pop("name", None)
if self.name is not None:
self.name = str(self.name)
self._parameter_name = self.name # default
self.description = description or "No description provided"
self.converter = None
self._raw_type = input_type
Expand All @@ -140,7 +141,10 @@ def __init__(self, input_type: Any, /, description: str = None, **kwargs) -> Non
else:
if _type == SlashCommandOptionType.channel:
if not isinstance(input_type, tuple):
input_type = (input_type,)
if hasattr(input_type, "__args__"): # Union
input_type = input_type.__args__
else:
input_type = (input_type,)
for i in input_type:
if i.__name__ == "GuildChannel":
continue
Expand Down
5 changes: 3 additions & 2 deletions discord/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,8 +695,9 @@ def from_datatype(cls, datatype):
if issubclass(datatype, float):
return cls.number

# TODO: Improve the error message
raise TypeError(f"Invalid class {datatype} used as an input type for an Option")
from .commands.context import ApplicationContext
if not issubclass(datatype, ApplicationContext): # TODO: prevent ctx being passed here in cog commands
raise TypeError(f"Invalid class {datatype} used as an input type for an Option") # TODO: Improve the error message


class EmbeddedActivity(Enum):
Expand Down