Skip to content
This repository has been archived by the owner on Dec 26, 2022. It is now read-only.

✨ Added UserMessage.from_id #315

Merged
merged 10 commits into from
Dec 21, 2021
156 changes: 91 additions & 65 deletions pincer/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,38 @@
from importlib import import_module
from inspect import isasyncgenfunction
from typing import (
Any, Dict, Iterable, List, Optional, Tuple, Union, overload,
AsyncIterator, TYPE_CHECKING
Any,
Dict,
List,
Optional,
Tuple,
Union,
overload,
AsyncIterator,
TYPE_CHECKING
Lunarmagpie marked this conversation as resolved.
Show resolved Hide resolved
)
from . import __package__
from .commands import ChatCommandHandler
from .core import HTTPClient
from .core.gateway import Dispatcher
from .exceptions import (
InvalidEventName, TooManySetupArguments, NoValidSetupMethod,
NoCogManagerReturnFound, CogAlreadyExists, CogNotFound
InvalidEventName,
TooManySetupArguments,
NoValidSetupMethod,
NoCogManagerReturnFound,
CogAlreadyExists,
CogNotFound,
Lunarmagpie marked this conversation as resolved.
Show resolved Hide resolved
)
from .middleware import middleware
from .objects import (
Role, Channel, DefaultThrottleHandler, User, Guild, Intents,
GuildTemplate, StickerPack
Role,
Channel,
DefaultThrottleHandler,
User,
Guild,
Intents,
GuildTemplate,
StickerPack, UserMessage,
Lunarmagpie marked this conversation as resolved.
Show resolved Hide resolved
)
Enderchief marked this conversation as resolved.
Show resolved Hide resolved
from .utils.conversion import construct_client_dict
from .utils.event_mgr import EventMgr
Expand Down Expand Up @@ -104,7 +121,8 @@ def decorator(func: Coro):
if override:
_log.warning(
"Middleware overriding has been enabled for `%s`."
" This might cause unexpected behavior.", call
" This might cause unexpected behavior.",
call,
)

if not override and callable(_events.get(call)):
Expand All @@ -117,9 +135,7 @@ async def wrapper(cls, payload: GatewayDispatch):
_log.debug("`%s` middleware has been invoked", call)

return await (
func(cls, payload)
if should_pass_cls(func)
else func(payload)
func(cls, payload) if should_pass_cls(func) else func(payload)
)

_events[call] = wrapper
Expand Down Expand Up @@ -166,12 +182,13 @@ class Client(Dispatcher):
""" # noqa: E501

def __init__(
self,
token: str, *,
received: str = None,
intents: Union[Iterable, Intents] = None,
throttler: ThrottleInterface = DefaultThrottleHandler,
reconnect: bool = True,
self,
token: str,
*,
received: str = None,
intents: Intents = None,
throttler: ThrottleInterface = DefaultThrottleHandler,
reconnect: bool = True,
):

if isinstance(intents, Iterable):
Expand All @@ -183,7 +200,7 @@ def __init__(
# Gets triggered on all events
-1: self.payload_event_handler,
# Use this event handler for opcode 0.
0: self.event_handler
0: self.event_handler,
},
intents=intents or Intents.all(),
reconnect=reconnect,
Expand All @@ -209,10 +226,9 @@ def chat_commands(self) -> List[str]:
Get a list of chat command calls which have been registered in
the :class:`~pincer.commands.ChatCommandHandler`\\.
"""
return list(map(
lambda cmd: cmd.app.name,
ChatCommandHandler.register.values()
))
return list(
map(lambda cmd: cmd.app.name, ChatCommandHandler.register.values())
)
Enderchief marked this conversation as resolved.
Show resolved Hide resolved

@property
def guild_ids(self) -> List[Snowflake]:
Expand Down Expand Up @@ -274,8 +290,9 @@ async def on_ready(self):
InvalidEventName
If the function name is not a valid event (on_x)
"""
if not iscoroutinefunction(coroutine) \
and not isasyncgenfunction(coroutine):
if not iscoroutinefunction(coroutine) and not isasyncgenfunction(
coroutine
):
Enderchief marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError(
"Any event which is registered must be a coroutine function"
)
Expand Down Expand Up @@ -307,10 +324,17 @@ def get_event_coro(name: str) -> List[Optional[Coro]]:
"""
calls = _events.get(name.strip().lower())

return [] if not calls else list(filter(
lambda call: iscoroutinefunction(call) or isasyncgenfunction(call),
calls
))
return (
[]
if not calls
else list(
filter(
lambda call: iscoroutinefunction(call)
or isasyncgenfunction(call),
calls,
)
)
)
Sigmanificient marked this conversation as resolved.
Show resolved Hide resolved

def load_cog(self, path: str, package: Optional[str] = None):
"""Load a cog from a string path, setup method in COG may
Expand Down Expand Up @@ -451,7 +475,7 @@ def execute_event(calls: List[Coro], *args, **kwargs):
if should_pass_cls(call):
call_args = (
ChatCommandHandler.managers[call.__module__],
*(arg for arg in args if arg is not None)
*(arg for arg in args if arg is not None),
Enderchief marked this conversation as resolved.
Show resolved Hide resolved
)

ensure_future(call(*call_args, **kwargs))
Expand All @@ -462,15 +486,11 @@ def run(self):

def __del__(self):
"""Ensure close of the http client."""
if hasattr(self, 'http'):
if hasattr(self, "http"):
run(self.http.close())

async def handle_middleware(
self,
payload: GatewayDispatch,
key: str,
*args,
**kwargs
self, payload: GatewayDispatch, key: str, *args, **kwargs
) -> Tuple[Optional[Coro], List[Any], Dict[str, Any]]:
"""|coro|

Expand Down Expand Up @@ -522,11 +542,7 @@ async def handle_middleware(
)

async def execute_error(
self,
error: Exception,
name: str = "on_error",
*args,
**kwargs
self, error: Exception, name: str = "on_error", *args, **kwargs
):
"""|coro|

Expand Down Expand Up @@ -623,7 +639,7 @@ async def create_guild(
afk_channel_id: Optional[Snowflake] = None,
afk_timeout: Optional[int] = None,
system_channel_id: Optional[Snowflake] = None,
system_channel_flags: Optional[int] = None
system_channel_flags: Optional[int] = None,
) -> Guild:
"""Creates a guild.

Expand Down Expand Up @@ -664,7 +680,7 @@ async def create_guild(

async def create_guild(self, name: str, **kwargs) -> Guild:
g = await self.http.post("guilds", data={"name": name, **kwargs})
return await self.get_guild(g['id'])
return await self.get_guild(g["id"])

async def get_guild_template(self, code: str) -> GuildTemplate:
"""|coro|
Expand All @@ -682,16 +698,12 @@ async def get_guild_template(self, code: str) -> GuildTemplate:
"""
return GuildTemplate.from_dict(
construct_client_dict(
self,
await self.http.get(f"guilds/templates/{code}")
self, await self.http.get(f"guilds/templates/{code}")
)
)

async def create_guild_from_template(
self,
template: GuildTemplate,
name: str,
icon: Optional[str] = None
self, template: GuildTemplate, name: str, icon: Optional[str] = None
) -> Guild:
"""|coro|
Creates a guild from a template.
Expand All @@ -715,16 +727,16 @@ async def create_guild_from_template(
self,
await self.http.post(
f"guilds/templates/{template.code}",
data={"name": name, "icon": icon}
)
data={"name": name, "icon": icon},
),
)
)

async def wait_for(
self,
event_name: str,
check: CheckFunction = None,
timeout: Optional[float] = None
self,
event_name: str,
check: CheckFunction = None,
timeout: Optional[float] = None,
):
"""
Parameters
Expand All @@ -745,11 +757,11 @@ async def wait_for(
return await self.event_mgr.wait_for(event_name, check, timeout)

def loop_for(
self,
event_name: str,
check: CheckFunction = None,
iteration_timeout: Optional[float] = None,
loop_timeout: Optional[float] = None
self,
event_name: str,
check: CheckFunction = None,
iteration_timeout: Optional[float] = None,
loop_timeout: Optional[float] = None,
):
"""
Parameters
Expand All @@ -771,10 +783,7 @@ def loop_for(
What the Discord API returns for this event.
"""
return self.event_mgr.loop_for(
event_name,
check,
iteration_timeout,
loop_timeout
event_name, check, iteration_timeout, loop_timeout
)

async def get_guild(self, guild_id: int) -> Guild:
Expand Down Expand Up @@ -852,10 +861,27 @@ async def get_channel(self, _id: int) -> Channel:
"""
return await Channel.from_id(self, _id)

async def get_message(self, _id: Snowflake, channel_id: Snowflake) -> UserMessage:
Enderchief marked this conversation as resolved.
Show resolved Hide resolved
"""|coro|
Creates a UserMessage object

Parameters
----------
_id: :class:`~pincer.utils.snowflake.Snowflake`
ID of the message that is wanted.
channel_id : int
ID of the channel the message is in.

Returns
-------
:class:`~pincer.objects.message.user_message.UserMessage`
The message object.
"""

return await UserMessage.from_id(self, _id, channel_id)

async def get_webhook(
self,
id: Snowflake,
token: Optional[str] = None
self, id: Snowflake, token: Optional[str] = None
) -> Webhook:
"""|coro|
Fetch a Webhook from its identifier.
Expand Down
Loading