From b78818014918b1a9189ec117ede1a5c016c0f347 Mon Sep 17 00:00:00 2001 From: Faster Speeding Date: Fri, 19 Aug 2022 19:36:34 +0100 Subject: [PATCH] Add app commands declared client callback --- tanjun/abc.py | 28 ++++++++++++++++++++++++++++ tanjun/clients.py | 44 ++++++++++++++++++++++++++++++++++++++------ 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/tanjun/abc.py b/tanjun/abc.py index 0ab38bfdb..95a79c5a1 100644 --- a/tanjun/abc.py +++ b/tanjun/abc.py @@ -3862,12 +3862,40 @@ async def open(self) -> None: """ +class DeclaredCommands(abc.ABC): + __slots__ = () + + @property + @abc.abstractmethod + def builders(self) -> collections.Sequence[hikari.api.CommandBuilder]: + """The declared command builders.""" + + @property + @abc.abstractmethod + def commands(self) -> collections.Sequence[hikari.PartialCommand]: + """The declared command objects.""" + + @property + @abc.abstractmethod + def guild_id(self) -> typing.Optional[hikari.Snowflake]: + """Id of the guild these commands were declared for. + + This will be [None][] if they were declared globally. + """ + + class ClientCallbackNames(str, enum.Enum): """Enum of the standard client callback names. These should be dispatched by all [tanjun.abc.Client][] implementations. """ + APP_COMMANDS_DECLARED = "app_commands_delcared" + """Called when the application commands are declared through the client. + + One positional argument of type [DeclaredCommands][]. + """ + CLOSED = "closed" """Called when the client has finished closing. diff --git a/tanjun/clients.py b/tanjun/clients.py index 14e0dcb54..ea239bc4b 100644 --- a/tanjun/clients.py +++ b/tanjun/clients.py @@ -497,6 +497,33 @@ async def __call__(self) -> None: ) +class _DeclaredCommands(tanjun.DeclaredCommands): + __slots__ = ("_builders", "_commands", "_guild_id") + + def __init__( + self, + builders: collections.Sequence[hikari.api.CommandBuilder], + commands: collections.Sequence[hikari.PartialCommand], + guild_id: typing.Optional[hikari.Snowflake], + /, + ) -> None: + self._builders = builders + self._commands = commands + self._guild_id = guild_id + + @property + def builders(self) -> collections.Sequence[hikari.api.CommandBuilder]: + return self._builders + + @property + def commands(self) -> collections.Sequence[hikari.PartialCommand]: + return self._commands + + @property + def guild_id(self) -> typing.Optional[hikari.Snowflake]: + return self._guild_id + + def _log_clients( cache: typing.Optional[hikari.api.Cache], events: typing.Optional[hikari.api.EventManager], @@ -1358,14 +1385,14 @@ async def declare_application_commands( user_ids = user_ids or {} names_to_commands: dict[tuple[hikari.CommandType, str], tanjun.AppCommand[typing.Any]] = {} conflicts: set[tuple[hikari.CommandType, str]] = set() - builders: dict[tuple[hikari.CommandType, str], hikari.api.CommandBuilder] = {} + builders_dict: dict[tuple[hikari.CommandType, str], hikari.api.CommandBuilder] = {} message_count = 0 slash_count = 0 user_count = 0 for command in commands: key = (command.type, command.name) - if key in builders: + if key in builders_dict: conflicts.add(key) if isinstance(command, tanjun.AppCommand): @@ -1399,7 +1426,7 @@ async def declare_application_commands( if localiser: localisation.localise_command(builder, localiser) - builders[key] = builder + builders_dict[key] = builder if conflicts: raise ValueError( @@ -1421,14 +1448,15 @@ async def declare_application_commands( if not force: registered_commands = await self._rest.fetch_application_commands(application, guild=guild) - if _internal.cmp_all_commands(registered_commands, builders): + if _internal.cmp_all_commands(registered_commands, builders_dict): _LOGGER.info( "Skipping bulk declare for %s application commands since they're already declared", target_type ) return registered_commands - _LOGGER.info("Bulk declaring %s %s application commands", len(builders), target_type) - responses = await self._rest.set_application_commands(application, list(builders.values()), guild=guild) + _LOGGER.info("Bulk declaring %s %s application commands", len(builders_dict), target_type) + builders = list(builders_dict.values()) + responses = await self._rest.set_application_commands(application, builders, guild=guild) for response in responses: # different command_ name used here for MyPy compat if not guild and (command_ := names_to_commands.get((response.type, response.name))): @@ -1442,6 +1470,10 @@ async def declare_application_commands( ", ".join(f"{response.type}-{response.name}: {response.id}" for response in responses), ) + await self.dispatch_client_callback( + tanjun.ClientCallbackNames.APP_COMMANDS_DECLARED, + _DeclaredCommands(builders, responses, None if guild is hikari.UNDEFINED else hikari.Snowflake(guild)), + ) return responses def set_auto_defer_after(self, time: typing.Optional[float], /) -> Self: