Skip to content

Commit

Permalink
Merge pull request #1298 from plun1331/forum-channels
Browse files Browse the repository at this point in the history
Add more forum channel/thread features
  • Loading branch information
BobDotCom authored Apr 27, 2022
2 parents 7b96490 + 17ee762 commit 0433dd4
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 17 deletions.
2 changes: 2 additions & 0 deletions discord/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
from .types.channel import PermissionOverwrite as PermissionOverwritePayload
from .ui.view import View
from .user import ClientUser
from .flags import ChannelFlags

PartialMessageableChannel = Union[TextChannel, VoiceChannel, Thread, DMChannel, PartialMessageable]
MessageableChannel = Union[PartialMessageableChannel, GroupChannel]
Expand Down Expand Up @@ -328,6 +329,7 @@ class GuildChannel:
type: ChannelType
position: int
category_id: Optional[int]
flags: ChannelFlags
_state: ConnectionState
_overwrites: List[_Overwrites]

Expand Down
70 changes: 57 additions & 13 deletions discord/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
try_enum,
)
from .errors import ClientException, InvalidArgument
from .flags import ChannelFlags
from .invite import Invite
from .iterators import ArchivedThreadIterator
from .mixins import Hashable
Expand Down Expand Up @@ -154,6 +155,10 @@ class _TextChannel(discord.abc.GuildChannel, Hashable):
default_auto_archive_duration: :class:`int`
The default auto archive duration in minutes for threads created in this channel.
.. versionadded:: 2.0
flags: :class:`ChannelFlags`
Extra features of the channel.
.. versionadded:: 2.0
"""

Expand All @@ -171,6 +176,7 @@ class _TextChannel(discord.abc.GuildChannel, Hashable):
"_type",
"last_message_id",
"default_auto_archive_duration",
"flags",
)

def __init__(self, *, state: ConnectionState, guild: Guild, data: Union[TextChannelPayload, ForumChannelPayload]):
Expand Down Expand Up @@ -200,6 +206,7 @@ def _update(self, guild: Guild, data: Union[TextChannelPayload, ForumChannelPayl
self.default_auto_archive_duration: ThreadArchiveDuration = data.get("default_auto_archive_duration", 1440)
self._type: int = data.get("type", self._type)
self.last_message_id: Optional[int] = utils._get_as_snowflake(data, "last_message_id")
self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0))
self._fill_overwrites(data)

@property
Expand Down Expand Up @@ -801,12 +808,11 @@ def guidelines(self) -> Optional[str]:
"""
return self.topic

async def create_post(
async def create_thread(
self,
name: str, # Could be renamed to title?
name: str,
content=None,
*,
tts=None,
embed=None,
embeds=None,
file=None,
Expand All @@ -817,6 +823,7 @@ async def create_post(
allowed_mentions=None,
view=None,
auto_archive_duration: ThreadArchiveDuration = MISSING,
slowmode_delay: int = MISSING,
reason: Optional[str] = None,
) -> Thread:
"""|coro|
Expand All @@ -832,17 +839,39 @@ async def create_post(
-----------
name: :class:`str`
The name of the thread.
message: Optional[:class:`abc.Snowflake`]
A snowflake representing the message to create the thread with.
If ``None`` is passed then a private thread is created.
Defaults to ``None``.
content: :class:`str`
The content of the message to send.
embed: :class:`~discord.Embed`
The rich embed for the content.
file: :class:`~discord.File`
The file to upload.
files: List[:class:`~discord.File`]
A list of files to upload. Must be a maximum of 10.
nonce: :class:`int`
The nonce to use for sending this message. If the message was successfully sent,
then the message will have a nonce with this value.
allowed_mentions: :class:`~discord.AllowedMentions`
Controls the mentions being processed in this message. If this is
passed, then the object is merged with :attr:`~discord.Client.allowed_mentions`.
The merging behaviour only overrides attributes that have been explicitly passed
to the object, otherwise it uses the attributes set in :attr:`~discord.Client.allowed_mentions`.
If no object is passed at all then the defaults given by :attr:`~discord.Client.allowed_mentions`
are used instead.
view: :class:`discord.ui.View`
A Discord UI View to add to the message.
embeds: List[:class:`~discord.Embed`]
A list of embeds to upload. Must be a maximum of 10.
stickers: Sequence[Union[:class:`~discord.GuildSticker`, :class:`~discord.StickerItem`]]
A list of stickers to upload. Must be a maximum of 3.
auto_archive_duration: :class:`int`
The duration in minutes before a thread is automatically archived for inactivity.
If not provided, the channel's default auto archive duration is used.
type: Optional[:class:`ChannelType`]
The type of thread to create. If a ``message`` is passed then this parameter
is ignored, as a thread created with a message is always a public thread.
By default this creates a private thread if this is ``None``.
slowmode_delay: :class:`int`
The number of seconds a member must wait between sending messages
in the new thread. A value of `0` denotes that it is disabled.
Bots and users with :attr:`~Permissions.manage_channels` or
:attr:`~Permissions.manage_messages` bypass slowmode.
If not provided, the forum channel's default slowmode is used.
reason: :class:`str`
The reason for creating a new thread. Shows up on the audit log.
Expand Down Expand Up @@ -903,7 +932,6 @@ async def create_post(
files=[file],
allowed_mentions=allowed_mentions,
content=message_content,
tts=tts,
embed=embed,
embeds=embeds,
nonce=nonce,
Expand All @@ -924,7 +952,6 @@ async def create_post(
self.id,
files=files,
content=message_content,
tts=tts,
embed=embed,
embeds=embeds,
nonce=nonce,
Expand All @@ -948,6 +975,7 @@ async def create_post(
stickers=stickers,
components=components,
auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration,
rate_limit_per_user=slowmode_delay or self.slowmode_delay,
reason=reason,
)
ret = Thread(guild=self.guild, state=self._state, data=data)
Expand All @@ -974,6 +1002,7 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha
"rtc_region",
"video_quality_mode",
"last_message_id",
"flags",
)

def __init__(
Expand Down Expand Up @@ -1004,6 +1033,7 @@ def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPay
self.position: int = data.get("position")
self.bitrate: int = data.get("bitrate")
self.user_limit: int = data.get("user_limit")
self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0))
self._fill_overwrites(data)

@property
Expand Down Expand Up @@ -1105,6 +1135,10 @@ class VoiceChannel(discord.abc.Messageable, VocalGuildChannel):
.. versionadded:: 2.0
last_message_id: Optional[:class:`int`]
The ID of the last message sent to this channel. It may not always point to an existing or valid message.
.. versionadded:: 2.0
flags: :class:`ChannelFlags`
Extra features of the channel.
.. versionadded:: 2.0
"""

Expand Down Expand Up @@ -1572,6 +1606,10 @@ class StageChannel(VocalGuildChannel):
video_quality_mode: :class:`VideoQualityMode`
The camera video quality for the stage channel's participants.
.. versionadded:: 2.0
flags: :class:`ChannelFlags`
Extra features of the channel.
.. versionadded:: 2.0
"""

Expand Down Expand Up @@ -1845,6 +1883,10 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
.. note::
To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead.
flags: :class:`ChannelFlags`
Extra features of the channel.
.. versionadded:: 2.0
"""

__slots__ = (
Expand All @@ -1856,6 +1898,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
"position",
"_overwrites",
"category_id",
"flags",
)

def __init__(self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload):
Expand All @@ -1872,6 +1915,7 @@ def _update(self, guild: Guild, data: CategoryChannelPayload) -> None:
self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
self.nsfw: bool = data.get("nsfw", False)
self.position: int = data.get("position")
self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0))
self._fill_overwrites(data)

@property
Expand Down
36 changes: 36 additions & 0 deletions discord/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,3 +1198,39 @@ def gateway_message_content_limited(self):
and has hit the guild limit.
"""
return 1 << 19


@fill_with_flags()
class ChannelFlags(BaseFlags):
r"""Wraps up the Discord Channel flags.
.. container:: operations
.. describe:: x == y
Checks if two ChannelFlags are equal.
.. describe:: x != y
Checks if two ChannelFlags are not equal.
.. describe:: hash(x)
Return the flag's hash.
.. describe:: iter(x)
Returns an iterator of ``(name, value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. versionadded:: 2.0
Attributes
-----------
value: :class:`int`
The raw value. You should query flags via the properties
rather than using this raw value.
"""

@flag_value
def pinned(self):
""":class:`bool`: Returns ``True`` if the thread is pinned to the top of its parent forum channel."""
return 1 << 1
8 changes: 4 additions & 4 deletions discord/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,9 +1091,9 @@ def start_forum_thread(
*,
name: str,
auto_archive_duration: threads.ThreadArchiveDuration,
rate_limit_per_user: int,
invitable: bool = True,
reason: Optional[str] = None,
tts: bool = False,
embed: Optional[embed.Embed] = None,
embeds: Optional[List[embed.Embed]] = None,
nonce: Optional[str] = None,
Expand All @@ -1109,9 +1109,6 @@ def start_forum_thread(
if content:
payload["content"] = content

if tts:
payload["tts"] = True

if embed:
payload["embeds"] = [embed]

Expand All @@ -1129,6 +1126,9 @@ def start_forum_thread(

if stickers:
payload["sticker_ids"] = stickers

if rate_limit_per_user:
payload["rate_limit_per_user"] = rate_limit_per_user
# TODO: Once supported by API, remove has_message=true query parameter
route = Route("POST", "/channels/{channel_id}/threads?has_message=true", channel_id=channel_id)
return self.request(route, json=payload, reason=reason)
Expand Down
16 changes: 16 additions & 0 deletions discord/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .abc import Messageable, _purge_messages_helper
from .enums import ChannelType, try_enum
from .errors import ClientException
from .flags import ChannelFlags
from .mixins import Hashable
from .utils import MISSING, _get_as_snowflake, parse_time

Expand Down Expand Up @@ -121,6 +122,10 @@ class Thread(Messageable, Hashable):
created_at: Optional[:class:`datetime.datetime`]
An aware timestamp of when the thread was created.
Only available for threads created after 2022-01-09.
flags: :class:`ChannelFlags`
Extra features of the thread.
.. versionadded:: 2.0
"""

__slots__ = (
Expand All @@ -143,6 +148,7 @@ class Thread(Messageable, Hashable):
"auto_archive_duration",
"archive_timestamp",
"created_at",
"flags",
)

def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload):
Expand Down Expand Up @@ -173,6 +179,7 @@ def _from_data(self, data: ThreadPayload):
self.slowmode_delay = data.get("rate_limit_per_user", 0)
self.message_count = data["message_count"]
self.member_count = data["member_count"]
self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0))
self._unroll_metadata(data["thread_metadata"])

try:
Expand All @@ -196,6 +203,7 @@ def _update(self, data):
except KeyError:
pass

self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0))
self.slowmode_delay = data.get("rate_limit_per_user", 0)

try:
Expand Down Expand Up @@ -503,6 +511,7 @@ async def edit(
invitable: bool = MISSING,
slowmode_delay: int = MISSING,
auto_archive_duration: ThreadArchiveDuration = MISSING,
pinned: bool = MISSING,
reason: Optional[str] = None,
) -> Thread:
"""|coro|
Expand Down Expand Up @@ -535,6 +544,8 @@ async def edit(
A value of ``0`` disables slowmode. The maximum value possible is ``21600``.
reason: Optional[:class:`str`]
The reason for editing this thread. Shows up on the audit log.
pinned: :class:`bool`
Whether to pin the thread or not. This only works if the thread is part of a forum.
Raises
-------
Expand All @@ -561,6 +572,11 @@ async def edit(
payload["invitable"] = invitable
if slowmode_delay is not MISSING:
payload["rate_limit_per_user"] = slowmode_delay
if pinned is not MISSING:
# copy the ChannelFlags object to avoid mutating the original
flags = ChannelFlags._from_value(self.flags.value)
flags.pinned = pinned
payload['flags'] = flags.value

data = await self._state.http.edit_channel(self.id, **payload, reason=reason)
# The data payload will always be a Thread payload
Expand Down

0 comments on commit 0433dd4

Please sign in to comment.