From 17ee76269a0cc306a949ff8608be78b2efd35ceb Mon Sep 17 00:00:00 2001 From: plun1331 Date: Tue, 26 Apr 2022 16:36:31 -0700 Subject: [PATCH] Add more forum channel/thread features --- discord/abc.py | 2 ++ discord/channel.py | 70 +++++++++++++++++++++++++++++++++++++--------- discord/flags.py | 36 ++++++++++++++++++++++++ discord/http.py | 8 +++--- discord/threads.py | 16 +++++++++++ 5 files changed, 115 insertions(+), 17 deletions(-) diff --git a/discord/abc.py b/discord/abc.py index 50c9ca3b5f..5548c48653 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -97,6 +97,7 @@ from .types.channel import PermissionOverwrite as PermissionOverwritePayload from .ui.view import View from .user import ClientUser + from .flags import ChannelFlags PartialMessageableChannel = Union[TextChannel, VoiceChannel, Thread, DMChannel, PartialMessageable] MessageableChannel = Union[PartialMessageableChannel, GroupChannel] @@ -328,6 +329,7 @@ class GuildChannel: type: ChannelType position: int category_id: Optional[int] + flags: ChannelFlags _state: ConnectionState _overwrites: List[_Overwrites] diff --git a/discord/channel.py b/discord/channel.py index 0473443f10..c3519e8ac3 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -60,6 +60,7 @@ try_enum, ) from .errors import ClientException, InvalidArgument +from .flags import ChannelFlags from .invite import Invite from .iterators import ArchivedThreadIterator from .mixins import Hashable @@ -154,6 +155,10 @@ class _TextChannel(discord.abc.GuildChannel, Hashable): default_auto_archive_duration: :class:`int` The default auto archive duration in minutes for threads created in this channel. + .. versionadded:: 2.0 + flags: :class:`ChannelFlags` + Extra features of the channel. + .. versionadded:: 2.0 """ @@ -171,6 +176,7 @@ class _TextChannel(discord.abc.GuildChannel, Hashable): "_type", "last_message_id", "default_auto_archive_duration", + "flags", ) def __init__(self, *, state: ConnectionState, guild: Guild, data: Union[TextChannelPayload, ForumChannelPayload]): @@ -200,6 +206,7 @@ def _update(self, guild: Guild, data: Union[TextChannelPayload, ForumChannelPayl self.default_auto_archive_duration: ThreadArchiveDuration = data.get("default_auto_archive_duration", 1440) self._type: int = data.get("type", self._type) self.last_message_id: Optional[int] = utils._get_as_snowflake(data, "last_message_id") + self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0)) self._fill_overwrites(data) @property @@ -801,12 +808,11 @@ def guidelines(self) -> Optional[str]: """ return self.topic - async def create_post( + async def create_thread( self, - name: str, # Could be renamed to title? + name: str, content=None, *, - tts=None, embed=None, embeds=None, file=None, @@ -817,6 +823,7 @@ async def create_post( allowed_mentions=None, view=None, auto_archive_duration: ThreadArchiveDuration = MISSING, + slowmode_delay: int = MISSING, reason: Optional[str] = None, ) -> Thread: """|coro| @@ -832,17 +839,39 @@ async def create_post( ----------- name: :class:`str` The name of the thread. - message: Optional[:class:`abc.Snowflake`] - A snowflake representing the message to create the thread with. - If ``None`` is passed then a private thread is created. - Defaults to ``None``. + content: :class:`str` + The content of the message to send. + embed: :class:`~discord.Embed` + The rich embed for the content. + file: :class:`~discord.File` + The file to upload. + files: List[:class:`~discord.File`] + A list of files to upload. Must be a maximum of 10. + nonce: :class:`int` + The nonce to use for sending this message. If the message was successfully sent, + then the message will have a nonce with this value. + allowed_mentions: :class:`~discord.AllowedMentions` + Controls the mentions being processed in this message. If this is + passed, then the object is merged with :attr:`~discord.Client.allowed_mentions`. + The merging behaviour only overrides attributes that have been explicitly passed + to the object, otherwise it uses the attributes set in :attr:`~discord.Client.allowed_mentions`. + If no object is passed at all then the defaults given by :attr:`~discord.Client.allowed_mentions` + are used instead. + view: :class:`discord.ui.View` + A Discord UI View to add to the message. + embeds: List[:class:`~discord.Embed`] + A list of embeds to upload. Must be a maximum of 10. + stickers: Sequence[Union[:class:`~discord.GuildSticker`, :class:`~discord.StickerItem`]] + A list of stickers to upload. Must be a maximum of 3. auto_archive_duration: :class:`int` The duration in minutes before a thread is automatically archived for inactivity. If not provided, the channel's default auto archive duration is used. - type: Optional[:class:`ChannelType`] - The type of thread to create. If a ``message`` is passed then this parameter - is ignored, as a thread created with a message is always a public thread. - By default this creates a private thread if this is ``None``. + slowmode_delay: :class:`int` + The number of seconds a member must wait between sending messages + in the new thread. A value of `0` denotes that it is disabled. + Bots and users with :attr:`~Permissions.manage_channels` or + :attr:`~Permissions.manage_messages` bypass slowmode. + If not provided, the forum channel's default slowmode is used. reason: :class:`str` The reason for creating a new thread. Shows up on the audit log. @@ -903,7 +932,6 @@ async def create_post( files=[file], allowed_mentions=allowed_mentions, content=message_content, - tts=tts, embed=embed, embeds=embeds, nonce=nonce, @@ -924,7 +952,6 @@ async def create_post( self.id, files=files, content=message_content, - tts=tts, embed=embed, embeds=embeds, nonce=nonce, @@ -948,6 +975,7 @@ async def create_post( stickers=stickers, components=components, auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, + rate_limit_per_user=slowmode_delay or self.slowmode_delay, reason=reason, ) ret = Thread(guild=self.guild, state=self._state, data=data) @@ -974,6 +1002,7 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha "rtc_region", "video_quality_mode", "last_message_id", + "flags", ) def __init__( @@ -1004,6 +1033,7 @@ def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPay self.position: int = data.get("position") self.bitrate: int = data.get("bitrate") self.user_limit: int = data.get("user_limit") + self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0)) self._fill_overwrites(data) @property @@ -1105,6 +1135,10 @@ class VoiceChannel(discord.abc.Messageable, VocalGuildChannel): .. versionadded:: 2.0 last_message_id: Optional[:class:`int`] The ID of the last message sent to this channel. It may not always point to an existing or valid message. + .. versionadded:: 2.0 + flags: :class:`ChannelFlags` + Extra features of the channel. + .. versionadded:: 2.0 """ @@ -1572,6 +1606,10 @@ class StageChannel(VocalGuildChannel): video_quality_mode: :class:`VideoQualityMode` The camera video quality for the stage channel's participants. + .. versionadded:: 2.0 + flags: :class:`ChannelFlags` + Extra features of the channel. + .. versionadded:: 2.0 """ @@ -1845,6 +1883,10 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): .. note:: To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead. + flags: :class:`ChannelFlags` + Extra features of the channel. + + .. versionadded:: 2.0 """ __slots__ = ( @@ -1856,6 +1898,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): "position", "_overwrites", "category_id", + "flags", ) def __init__(self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload): @@ -1872,6 +1915,7 @@ def _update(self, guild: Guild, data: CategoryChannelPayload) -> None: self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id") self.nsfw: bool = data.get("nsfw", False) self.position: int = data.get("position") + self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0)) self._fill_overwrites(data) @property diff --git a/discord/flags.py b/discord/flags.py index 0a1d0724e2..c17790e647 100644 --- a/discord/flags.py +++ b/discord/flags.py @@ -1198,3 +1198,39 @@ def gateway_message_content_limited(self): and has hit the guild limit. """ return 1 << 19 + + +@fill_with_flags() +class ChannelFlags(BaseFlags): + r"""Wraps up the Discord Channel flags. + + .. container:: operations + + .. describe:: x == y + + Checks if two ChannelFlags are equal. + .. describe:: x != y + + Checks if two ChannelFlags are not equal. + .. describe:: hash(x) + + Return the flag's hash. + .. describe:: iter(x) + + Returns an iterator of ``(name, value)`` pairs. This allows it + to be, for example, constructed as a dict or a list of pairs. + Note that aliases are not shown. + + .. versionadded:: 2.0 + + Attributes + ----------- + value: :class:`int` + The raw value. You should query flags via the properties + rather than using this raw value. + """ + + @flag_value + def pinned(self): + """:class:`bool`: Returns ``True`` if the thread is pinned to the top of its parent forum channel.""" + return 1 << 1 diff --git a/discord/http.py b/discord/http.py index 3afb78eca1..2ec52ffe81 100644 --- a/discord/http.py +++ b/discord/http.py @@ -1091,9 +1091,9 @@ def start_forum_thread( *, name: str, auto_archive_duration: threads.ThreadArchiveDuration, + rate_limit_per_user: int, invitable: bool = True, reason: Optional[str] = None, - tts: bool = False, embed: Optional[embed.Embed] = None, embeds: Optional[List[embed.Embed]] = None, nonce: Optional[str] = None, @@ -1109,9 +1109,6 @@ def start_forum_thread( if content: payload["content"] = content - if tts: - payload["tts"] = True - if embed: payload["embeds"] = [embed] @@ -1129,6 +1126,9 @@ def start_forum_thread( if stickers: payload["sticker_ids"] = stickers + + if rate_limit_per_user: + payload["rate_limit_per_user"] = rate_limit_per_user # TODO: Once supported by API, remove has_message=true query parameter route = Route("POST", "/channels/{channel_id}/threads?has_message=true", channel_id=channel_id) return self.request(route, json=payload, reason=reason) diff --git a/discord/threads.py b/discord/threads.py index 322f50af4d..cafbc21ea5 100644 --- a/discord/threads.py +++ b/discord/threads.py @@ -32,6 +32,7 @@ from .abc import Messageable, _purge_messages_helper from .enums import ChannelType, try_enum from .errors import ClientException +from .flags import ChannelFlags from .mixins import Hashable from .utils import MISSING, _get_as_snowflake, parse_time @@ -121,6 +122,10 @@ class Thread(Messageable, Hashable): created_at: Optional[:class:`datetime.datetime`] An aware timestamp of when the thread was created. Only available for threads created after 2022-01-09. + flags: :class:`ChannelFlags` + Extra features of the thread. + + .. versionadded:: 2.0 """ __slots__ = ( @@ -143,6 +148,7 @@ class Thread(Messageable, Hashable): "auto_archive_duration", "archive_timestamp", "created_at", + "flags", ) def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload): @@ -173,6 +179,7 @@ def _from_data(self, data: ThreadPayload): self.slowmode_delay = data.get("rate_limit_per_user", 0) self.message_count = data["message_count"] self.member_count = data["member_count"] + self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0)) self._unroll_metadata(data["thread_metadata"]) try: @@ -196,6 +203,7 @@ def _update(self, data): except KeyError: pass + self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0)) self.slowmode_delay = data.get("rate_limit_per_user", 0) try: @@ -503,6 +511,7 @@ async def edit( invitable: bool = MISSING, slowmode_delay: int = MISSING, auto_archive_duration: ThreadArchiveDuration = MISSING, + pinned: bool = MISSING, reason: Optional[str] = None, ) -> Thread: """|coro| @@ -535,6 +544,8 @@ async def edit( A value of ``0`` disables slowmode. The maximum value possible is ``21600``. reason: Optional[:class:`str`] The reason for editing this thread. Shows up on the audit log. + pinned: :class:`bool` + Whether to pin the thread or not. This only works if the thread is part of a forum. Raises ------- @@ -561,6 +572,11 @@ async def edit( payload["invitable"] = invitable if slowmode_delay is not MISSING: payload["rate_limit_per_user"] = slowmode_delay + if pinned is not MISSING: + # copy the ChannelFlags object to avoid mutating the original + flags = ChannelFlags._from_value(self.flags.value) + flags.pinned = pinned + payload['flags'] = flags.value data = await self._state.http.edit_channel(self.id, **payload, reason=reason) # The data payload will always be a Thread payload