diff --git a/lavalink/__init__.py b/lavalink/__init__.py index d6c8569e..ddeffb4c 100644 --- a/lavalink/__init__.py +++ b/lavalink/__init__.py @@ -26,10 +26,7 @@ "Node", "NodeStats", "Stats", - "user_id", - "channel_finder_func", "Player", - "PlayerManager", "initialize", "connect", "get_player", diff --git a/lavalink/enums.py b/lavalink/enums.py index cfdd8332..1ed971e4 100644 --- a/lavalink/enums.py +++ b/lavalink/enums.py @@ -136,3 +136,4 @@ class ExceptionSeverity(enum.Enum): COMMON = "COMMON" SUSPICIOUS = "SUSPICIOUS" FATAL = "FATAL" + FAULT = "FAULT" diff --git a/lavalink/lavalink.py b/lavalink/lavalink.py index e4cb4023..3e07b981 100644 --- a/lavalink/lavalink.py +++ b/lavalink/lavalink.py @@ -69,20 +69,17 @@ async def initialize( global _loop _loop = bot.loop - player_manager.user_id = bot.user.id - player_manager.channel_finder_func = bot.get_channel register_event_listener(_handle_event) register_update_listener(_handle_update) lavalink_node = node.Node( - _loop, - dispatch, - bot._connection._get_websocket, - host, - password, + _loop=_loop, + event_handler=dispatch, + host=host, + password=password, port=ws_port, - user_id=player_manager.user_id, - num_shards=bot.shard_count if bot.shard_count is not None else 1, + user_id=bot.user.id, + num_shards=bot.shard_count or 1, resume_key=resume_key, resume_timeout=resume_timeout, bot=bot, @@ -91,7 +88,6 @@ async def initialize( await lavalink_node.connect(timeout=timeout, secured=secured) lavalink_node._retries = 0 - bot.add_listener(node.on_socket_response) bot.add_listener(_on_guild_remove, name="on_guild_remove") return lavalink_node @@ -100,32 +96,28 @@ async def initialize( async def connect(channel: discord.VoiceChannel, deafen: bool = False): """ Connects to a discord voice channel. - This is the publicly exposed way to connect to a discord voice channel. The :py:func:`initialize` function must be called first! - Parameters ---------- channel - Returns ------- Player The created Player object. - Raises ------ IndexError If there are no available lavalink nodes ready to connect to discord. """ node_ = node.get_node(channel.guild.id) - p = await node_.player_manager.create_player(channel, deafen=deafen) + p = await node_.create_player(channel, deafen=deafen) return p def get_player(guild_id: int) -> player_manager.Player: node_ = node.get_node(guild_id) - return node_.player_manager.get_player(guild_id) + return node_.get_player(guild_id) async def _on_guild_remove(guild): @@ -185,7 +177,7 @@ def _get_event_args(data: enums.LavalinkEvents, raw_data: dict): try: node_ = node.get_node(guild_id, ignore_ready_status=True) - player = node_.player_manager.get_player(guild_id) + player = node_.get_player(guild_id) except (IndexError, KeyError): if data != enums.LavalinkEvents.TRACK_END: log.debug( @@ -353,7 +345,6 @@ async def close(bot): """ unregister_event_listener(_handle_event) unregister_update_listener(_handle_update) - bot.remove_listener(node.on_socket_response) bot.remove_listener(_on_guild_remove, name="on_guild_remove") await node.disconnect() @@ -363,13 +354,13 @@ async def close(bot): def all_players() -> Tuple[player_manager.Player]: nodes = node._nodes - ret = tuple(p for n in nodes for p in n.player_manager.players) + ret = tuple(p for n in nodes for p in n.players) return ret def all_connected_players() -> Tuple[player_manager.Player]: nodes = node._nodes - ret = tuple(p for n in nodes for p in n.player_manager.players if p.connected) + ret = tuple(p for n in nodes for p in n.players if p.connected) return ret diff --git a/lavalink/node.py b/lavalink/node.py index aaf8d679..37f0926f 100644 --- a/lavalink/node.py +++ b/lavalink/node.py @@ -1,26 +1,29 @@ from __future__ import annotations + import asyncio import contextlib import secrets import string +import typing from collections import namedtuple -from typing import Awaitable, List, Optional, cast +from typing import Awaitable, KeysView, List, Optional, ValuesView, cast import aiohttp -import typing +from typing import Awaitable, KeysView, List, Optional, ValuesView, cast from discord.backoff import ExponentialBackoff from discord.ext.commands import Bot -from . import __version__, ws_discord_log, ws_ll_log -from .enums import * -from .player_manager import PlayerManager +from . import log, ws_ll_log, ws_rll_log, __version__ +from .enums import LavalinkEvents, LavalinkIncomingOp, LavalinkOutgoingOp, NodeState, PlayerState +from .player_manager import Player from .rest_api import Track +from .utils import VoiceChannel -__all__ = ["Stats", "Node", "NodeStats", "get_node", "get_nodes_stats", "join_voice"] +__all__ = ["Stats", "Node", "NodeStats", "get_node", "get_nodes_stats"] _nodes: List[Node] = [] -PlayerState = namedtuple("PlayerState", "position time connected") +PositionTime = namedtuple("PositionTime", "position time connected") MemoryInfo = namedtuple("MemoryInfo", "reservable used free allocated") CPUInfo = namedtuple("CPUInfo", "cores systemLoad lavalinkLoad") @@ -103,7 +106,6 @@ def __init__( self, _loop: asyncio.BaseEventLoop, event_handler: typing.Callable, - voice_ws_func: typing.Callable, host: str, password: str, port: int, @@ -122,8 +124,6 @@ def __init__( The event loop of the bot. event_handler Function to dispatch events to. - voice_ws_func : typing.Callable - Function that takes one argument, guild ID, and returns a websocket. host : str Lavalink player host. password : str @@ -146,7 +146,6 @@ def __init__( self.loop = _loop self.bot = bot self.event_handler = event_handler - self.get_voice_ws = voice_ws_func self.host = host self.port = port self.password = password @@ -165,13 +164,12 @@ def __init__( self.session = aiohttp.ClientSession() self._queue: List = [] + self._players_dict = {} self.state = NodeState.CONNECTING self._state_handlers: List = [] self._retries = 0 - self.player_manager = PlayerManager(self) - self.stats = None if self not in _nodes: @@ -183,6 +181,8 @@ def __init__( aiohttp.WSMsgType.CLOSED, ) + self.register_state_handler(self.node_state_handler) + def __repr__(self): return ( " dict: return self._get_connect_headers() + @property + def players(self) -> ValuesView[Player]: + return self._players_dict.values() + + @property + def guild_ids(self) -> KeysView[int]: + return self._players_dict.keys() + def _gen_key(self): if self._resume_key is None: return _Key() @@ -374,7 +382,7 @@ async def _handle_op(self, op: LavalinkIncomingOp, data): self.event_handler(op, event, data) elif op == LavalinkIncomingOp.PLAYER_UPDATE: state = data.get("state", {}) - state = PlayerState( + state = PositionTime( position=state.get("position", 0), time=state.get("time", 0), connected=state.get("connected", False), @@ -424,7 +432,7 @@ async def _reconnect(self): self._retries = 0 def dispatch_reconnect(self): - for guild_id in self.player_manager.guild_ids: + for guild_id in self.guild_ids: self.event_handler( LavalinkIncomingOp.EVENT, LavalinkEvents.WEBSOCKET_CLOSED, @@ -460,12 +468,85 @@ def register_state_handler(self, func): def unregister_state_handler(self, func): self._state_handlers.remove(func) - async def join_voice_channel(self, guild_id, channel_id, deafen: bool = False): + async def create_player(self, channel: VoiceChannel, deafen: bool = False) -> Player: """ - Alternative way to join a voice channel if node is known. + Connects to a discord voice channel. + This function is safe to repeatedly call as it will return an existing + player if there is one. + Parameters + ---------- + channel + Returns + ------- + Player + The created Player object. + """ + if self._already_in_guild(channel): + p = self.get_player(channel.guild.id) + await p.move_to(channel, deafen=deafen) + else: + p = await channel.connect(cls=Player) + if deafen: + await p.guild.change_voice_state(channel=p.channel, self_deaf=True) + self._players_dict[channel.guild.id] = p + await self.refresh_player_state(p) + return p + + def _already_in_guild(self, channel: VoiceChannel) -> bool: + return channel.guild.id in self._players_dict + + def get_player(self, guild_id: int) -> Player: + """ + Gets a Player object from a guild ID. + Parameters + ---------- + guild_id : int + Discord guild ID. + Returns + ------- + Player + Raises + ------ + KeyError + If that guild does not have a Player, e.g. is not connected to any + voice channel. """ - voice_ws = self.get_voice_ws(guild_id) - await voice_ws.voice_state(guild_id, channel_id, self_deaf=deafen) + if guild_id in self._players_dict: + return self._players_dict[guild_id] + raise KeyError("No such player for that guild.") + + async def node_state_handler(self, next_state: NodeState, old_state: NodeState): + ws_rll_log.debug("Received node state update: %s -> %s", old_state.name, next_state.name) + if next_state == NodeState.READY: + await self.update_player_states(PlayerState.READY) + elif next_state == NodeState.DISCONNECTING: + await self.update_player_states(PlayerState.DISCONNECTING) + elif next_state in (NodeState.CONNECTING, NodeState.RECONNECTING): + await self.update_player_states(PlayerState.NODE_BUSY) + + async def update_player_states(self, state: PlayerState): + for p in self.players: + await p.update_state(state) + + async def refresh_player_state(self, player: Player): + if self.ready: + await player.update_state(PlayerState.READY) + elif self.state == NodeState.DISCONNECTING: + await player.update_state(PlayerState.DISCONNECTING) + else: + await player.update_state(PlayerState.NODE_BUSY) + + def remove_player(self, player: Player): + if player.state != PlayerState.DISCONNECTING: + log.error( + "Attempting to remove a player (%r) from player list with state: %s", + player, + player.state.name, + ) + return + guild_id = player.channel.guild.id + if guild_id in self._players_dict: + del self._players_dict[guild_id] async def disconnect(self): """ @@ -479,7 +560,10 @@ async def disconnect(self): if self._resuming_configured: await self.send(dict(op="configureResuming", key=None)) self._resuming_configured = False - await self.player_manager.disconnect() + + for p in tuple(self.players): + await p.disconnect(force=True) + log.debug("Disconnected all players.") if self._ws is not None and not self._ws.closed: await self._ws.close() @@ -572,7 +656,7 @@ async def seek(self, guild_id: int, position: int): ) -def get_node(guild_id: int, ignore_ready_status: bool = False) -> Node: +def get_node(guild_id: int = None, ignore_ready_status: bool = False) -> Node: """ Gets a node based on a guild ID, useful for noding separation. If the guild ID does not already have a node association, the least used @@ -591,7 +675,7 @@ def get_node(guild_id: int, ignore_ready_status: bool = False) -> Node: least_used = None for node in _nodes: - guild_ids = node.player_manager.guild_ids + guild_ids = node.guild_ids if ignore_ready_status is False and not node.ready: continue @@ -612,41 +696,6 @@ def get_nodes_stats(): return [node.stats for node in _nodes] -async def join_voice(guild_id: int, channel_id: int, deafen: bool = False): - """ - Joins a voice channel by ID's. - - Parameters - ---------- - guild_id : int - channel_id : int - """ - node = get_node(guild_id) - await node.join_voice_channel(guild_id, channel_id, deafen) - - async def disconnect(): for node in _nodes.copy(): await node.disconnect() - - -async def on_socket_response(data): - raw_event = data.get("t") - try: - event = DiscordVoiceSocketResponses(raw_event) - except ValueError: - return - - guild_id = data["d"]["guild_id"] - - try: - node = get_node(guild_id, ignore_ready_status=True) - except IndexError: - ws_discord_log.info( - f"Received unhandled Discord WS voice response for guild: %d, %s", int(guild_id), data - ) - else: - ws_ll_log.debug( - f"Received Discord WS voice response for guild: %d, %s", int(guild_id), data - ) - await node.player_manager.on_socket_response(data) diff --git a/lavalink/player_manager.py b/lavalink/player_manager.py index 7222f218..f680c4e5 100644 --- a/lavalink/player_manager.py +++ b/lavalink/player_manager.py @@ -1,25 +1,30 @@ import asyncio import datetime from random import shuffle -from typing import KeysView, Optional, TYPE_CHECKING, ValuesView +from typing import TYPE_CHECKING, Optional import discord from discord.backoff import ExponentialBackoff +from discord.voice_client import VoiceProtocol from . import log, ws_rll_log -from .enums import * +from .enums import ( + LavalinkEvents, + LavalinkIncomingOp, + LavalinkOutgoingOp, + PlayerState, + TrackEndReason, +) from .rest_api import RESTClient, Track +from .utils import VoiceChannel if TYPE_CHECKING: from . import node -__all__ = ["user_id", "channel_finder_func", "Player", "PlayerManager"] +__all__ = ["Player"] -user_id = None -channel_finder_func = lambda channel_id: None - -class Player(RESTClient): +class Player(RESTClient, VoiceProtocol): """ The Player class represents the current state of playback. It also is used to control playback and queue tracks. @@ -39,9 +44,16 @@ class Player(RESTClient): shuffle : bool """ - def __init__(self, manager: "PlayerManager", channel: discord.VoiceChannel): - super().__init__(manager.node) - self.bot = manager.bot + def __call__(self, client: discord.Client, channel: VoiceChannel): + self.client: discord.Client = client + self.channel: VoiceChannel = channel + + return self + + def __init__( + self, client: discord.Client = None, channel: VoiceChannel = None, node: "node.Node" = None + ): + self.client = client self.channel = channel self.guild = channel.guild self._last_channel_id = channel.id @@ -56,15 +68,23 @@ def __init__(self, manager: "PlayerManager", channel: discord.VoiceChannel): self._auto_play_sent = False self._volume = 100 self.state = PlayerState.CREATED + self._voice_state = {} self.connected_at = None self._connected = False self._is_playing = False self._metadata = {} - self.manager = manager + if node is None: + from .node import get_node + + node = get_node() + self.node = node + self._con_delay = None self._last_resume = None + super().__init__(self) + def __repr__(self): return ( " bool: """ return self._connected + async def on_voice_server_update(self, data: dict) -> None: + self._voice_state.update({"event": data}) + await self._send_lavalink_voice_update(self._voice_state) + + async def on_voice_state_update(self, data: dict) -> None: + self._voice_state.update({"sessionId": data["session_id"]}) + if (channel_id := data["channel_id"]) is None: + ws_rll_log.info("Received voice disconnect from discord, removing player.") + self._voice_state.clear() + await self.disconnect(force=True) + else: + channel = self.guild.get_channel(int(channel_id)) + if channel != self.channel: + if self.channel: + self._last_channel_id = self.channel.id + self.channel = channel + await self._send_lavalink_voice_update({**self._voice_state, "event": data}) + + async def _send_lavalink_voice_update(self, voice_state: dict): + if voice_state.keys() != {"sessionId", "event"}: + return + + if voice_state["event"].keys() == {"token", "guild_id", "endpoint"}: + await self.node.send( + { + "op": LavalinkOutgoingOp.VOICE_UPDATE.value, + "guildId": str(self.guild.id), + "sessionId": voice_state["sessionId"], + "event": voice_state["event"], + } + ) + async def wait_until_ready( self, timeout: Optional[float] = None, no_raise: bool = False ) -> bool: @@ -139,17 +191,15 @@ async def wait_until_ready( else: raise - async def connect(self, deafen: bool = False, channel: Optional[discord.VoiceChannel] = None): + async def connect(self, timeout: float = 2.0, reconnect: bool = False, deafen: bool = False): """ Connects to the voice channel associated with this Player. """ - self._last_resume = datetime.datetime.now(tz=datetime.timezone.utc) + self._last_resume = datetime.datetime.now(datetime.timezone.utc) self.connected_at = datetime.datetime.now(datetime.timezone.utc) self._connected = True - if channel: - if self.channel: - self._last_channel_id = self.channel.id - self.channel = channel + self.node._players_dict[self.guild.id] = self + await self.node.refresh_player_state(self) await self.guild.change_voice_state( channel=self.channel, self_mute=False, self_deaf=deafen ) @@ -173,7 +223,7 @@ async def move_to(self, channel: discord.VoiceChannel, deafen: bool = False): track=self.current, replace=True, start=self.position, pause=self._paused ) - async def disconnect(self, requested=True): + async def disconnect(self, force=False): """ Disconnects this player from it's voice channel. """ @@ -185,7 +235,7 @@ async def disconnect(self, requested=True): await self.update_state(PlayerState.DISCONNECTING) guild_id = self.guild.id - if not requested: + if force: log.debug("Forcing player disconnect for %r due to player manager request.", self) self.node.event_handler( LavalinkIncomingOp.EVENT, @@ -199,15 +249,11 @@ async def disconnect(self, requested=True): }, ) - voice_ws = self.node.get_voice_ws(guild_id) - - if not voice_ws.socket.closed: + if not self.client.shards[self.guild.shard_id].is_closed(): await self.guild.change_voice_state(channel=None) - await self.node.destroy_guild(guild_id) - await self.close() - - self.manager.remove_player(self) + self.node.remove_player(self) + self.cleanup() def store(self, key, value): """ @@ -234,15 +280,11 @@ async def update_state(self, state: PlayerState): ws_rll_log.debug("Player %r changing state: %s -> %s", self, self.state.name, state.name) - old_state = self.state self.state = state if self._con_delay: self._con_delay = None - if state == PlayerState.READY: - self.reset_session() - async def handle_event(self, event: "node.LavalinkEvents", extra): """ Handles various Lavalink Events. @@ -419,181 +461,3 @@ async def seek(self, position: int): if self.current.seekable: position = max(min(position, self.current.length), 0) await self.node.seek(self.guild.id, position) - - -class PlayerManager: - def __init__(self, node_: "node.Node"): - self._player_dict = {} - self.voice_states = {} - self.bot = node_.bot - self.node = node_ - self.node.register_state_handler(self.node_state_handler) - - @property - def players(self) -> ValuesView[Player]: - return self._player_dict.values() - - @property - def guild_ids(self) -> KeysView[int]: - return self._player_dict.keys() - - async def create_player(self, channel: discord.VoiceChannel, deafen: bool = False) -> Player: - """ - Connects to a discord voice channel. - - This function is safe to repeatedly call as it will return an existing - player if there is one. - - Parameters - ---------- - channel - - Returns - ------- - Player - The created Player object. - """ - if self._already_in_guild(channel): - p = self.get_player(channel.guild.id) - await p.move_to(channel, deafen=deafen) - else: - p = Player(self, channel) - await p.connect(deafen=deafen) - self._player_dict[channel.guild.id] = p - await self.refresh_player_state(p) - return p - - def _already_in_guild(self, channel: discord.VoiceChannel) -> bool: - return channel.guild.id in self._player_dict - - def get_player(self, guild_id: int) -> Player: - """ - Gets a Player object from a guild ID. - - Parameters - ---------- - guild_id : int - Discord guild ID. - - Returns - ------- - Player - - Raises - ------ - KeyError - If that guild does not have a Player, e.g. is not connected to any - voice channel. - """ - if guild_id in self._player_dict: - return self._player_dict[guild_id] - raise KeyError("No such player for that guild.") - - def _ensure_player(self, channel_id: int): - channel = channel_finder_func(channel_id) - if channel is not None: - try: - p = self.get_player(channel.guild.id) - except KeyError: - log.debug("Received voice channel connection without a player.") - p = Player(self, channel) - self._player_dict[channel.guild.id] = p - return p, channel - - async def _remove_player(self, guild_id: int): - try: - p = self.get_player(guild_id) - except KeyError: - pass - else: - del self._player_dict[guild_id] - await p.disconnect(requested=False) - - async def node_state_handler(self, next_state: NodeState, old_state: NodeState): - ws_rll_log.debug("Received node state update: %s -> %s", old_state.name, next_state.name) - if next_state == NodeState.READY: - await self.update_player_states(PlayerState.READY) - elif next_state == NodeState.DISCONNECTING: - await self.update_player_states(PlayerState.DISCONNECTING) - elif next_state in (NodeState.CONNECTING, NodeState.RECONNECTING): - await self.update_player_states(PlayerState.NODE_BUSY) - - async def update_player_states(self, state: PlayerState): - for p in self.players: - await p.update_state(state) - - async def refresh_player_state(self, player: Player): - if self.node.ready: - await player.update_state(PlayerState.READY) - elif self.node.state == NodeState.DISCONNECTING: - await player.update_state(PlayerState.DISCONNECTING) - else: - await player.update_state(PlayerState.NODE_BUSY) - - async def on_socket_response(self, data): - raw_event = data.get("t") - try: - event = DiscordVoiceSocketResponses(raw_event) - except ValueError: - return - - guild_id = data["d"]["guild_id"] - if guild_id not in self.voice_states: - self.voice_states[guild_id] = {} - - if event == DiscordVoiceSocketResponses.VOICE_SERVER_UPDATE: - # Connected for the first time - socket_event_data = data["d"] - self.voice_states[guild_id].update({"guild_id": guild_id, "event": socket_event_data}) - elif event == DiscordVoiceSocketResponses.VOICE_STATE_UPDATE: - channel_id = data["d"]["channel_id"] - event_user_id = int(data["d"].get("user_id")) - - if event_user_id != user_id: - return - - if channel_id is None: - # We disconnected - p = self._player_dict.get(guild_id) - msg = "Received voice disconnect from discord, removing player." - if p: - msg += f" {p}" - ws_rll_log.info(msg) - self.voice_states[guild_id] = {} - await self._remove_player(int(guild_id)) - - else: - # After initial connection, get session ID - p, channel = self._ensure_player(int(channel_id)) - if channel != p.channel: - if p.channel: - p._last_channel_id = p.channel.id - p.channel = channel - - session_id = data["d"]["session_id"] - self.voice_states[guild_id]["session_id"] = session_id - else: - return - data = self.voice_states[guild_id] - if all(k in data for k in ["session_id", "guild_id", "event"]): - await self.node.send_lavalink_voice_update(**self.voice_states[guild_id]) - - async def disconnect(self): - """ - Disconnects all players. - """ - for p in tuple(self.players): - await p.disconnect(requested=False) - log.debug("Disconnected all players.") - - def remove_player(self, player: Player): - if player.state != PlayerState.DISCONNECTING: - log.error( - "Attempting to remove a player (%r) from player list with state: %s", - player, - player.state.name, - ) - return - guild_id = player.channel.guild.id - if guild_id in self._player_dict: - del self._player_dict[guild_id] diff --git a/lavalink/rest_api.py b/lavalink/rest_api.py index c1efb5e0..be2a22c6 100644 --- a/lavalink/rest_api.py +++ b/lavalink/rest_api.py @@ -1,13 +1,15 @@ import re from collections import namedtuple -from typing import Tuple, Union +from typing import TYPE_CHECKING, Tuple, Union from urllib.parse import quote, urlparse -from aiohttp import ClientSession from aiohttp.client_exceptions import ServerDisconnectedError from . import log -from .enums import * +from .enums import ExceptionSeverity, LoadType, PlayerState + +if TYPE_CHECKING: + from . import player_manager __all__ = ["Track", "RESTClient", "PlaylistInfo"] @@ -275,20 +277,16 @@ class RESTClient: Client class used to access the REST endpoints on a Lavalink node. """ - def __init__(self, node): - self.node = node - self._session = None - self._uri = "http://{}:{}/loadtracks?identifier=".format(node.host, node.port) - self._headers = {"Authorization": node.password} - - self.state = PlayerState.CONNECTING + def __init__(self, player: "player_manager.Player"): + self.player = player + self.node = player.node + self._session = self.node.session + self._uri = "http://{}:{}/loadtracks?identifier=".format(self.node.host, self.node.port) + self._headers = {"Authorization": self.node.password} + self.state = player.state self._warned = False - def reset_session(self): - if self._session is None or self._session.closed: - self._session = ClientSession(loop=self.node.loop) - def __check_node_ready(self): if self.state != PlayerState.READY: raise RuntimeError("Cannot execute REST request when node not ready.") @@ -387,8 +385,3 @@ async def search_sc(self, query) -> LoadResult: list of Track """ return await self.load_tracks("scsearch:{}".format(query)) - - async def close(self): - if self._session is not None: - await self._session.close() - log.debug("Closed REST session.") diff --git a/lavalink/utils.py b/lavalink/utils.py index 3c50e6b2..c6547388 100644 --- a/lavalink/utils.py +++ b/lavalink/utils.py @@ -1,6 +1,14 @@ +from typing import Union + +import discord + + def format_time(time): """Formats the given time into HH:MM:SS""" h, r = divmod(time / 1000, 3600) m, s = divmod(r, 60) return "%02d:%02d:%02d" % (h, m, s) + + +VoiceChannel = Union[discord.VoiceChannel, discord.StageChannel] diff --git a/tests/conftest.py b/tests/conftest.py index 40095bd9..e9399491 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections import namedtuple from types import SimpleNamespace @@ -59,11 +61,14 @@ def user(): @pytest.fixture def guild(): - Guild = namedtuple("Guild", "id name") - return Guild(987654321, "Testing") + Guild = MagicMock() + Guild.id = 987654321 + Guild.name = "Testing" + Guild.get_channel = lambda channel_id: voice_channel + yield Guild -@pytest.fixture() +@pytest.fixture def voice_channel(guild): VoiceChannel = namedtuple("VoiceChannel", "id guild name") return VoiceChannel(9999999999, guild, "Testing VC") @@ -117,7 +122,6 @@ async def node(bot): node_ = lavalink.node.Node( _loop=bot.loop, event_handler=MagicMock(), - voice_ws_func=bot._connection._get_websocket, host="localhost", password="password", port=2333, @@ -125,6 +129,7 @@ async def node(bot): num_shards=bot.shard_count, resume_key="Test", resume_timeout=60, + bot=bot, ) # node_.send = MagicMock(wraps=send) diff --git a/tests/test_lavalink.py b/tests/test_lavalink.py index 85f1cede..d0264cd3 100644 --- a/tests/test_lavalink.py +++ b/tests/test_lavalink.py @@ -9,9 +9,6 @@ async def test_initialize(bot): await lavalink.initialize(bot, "localhost", "password", 2333, 2333) - assert lavalink.player_manager.user_id == bot.user.id - assert lavalink.player_manager.channel_finder_func == bot.get_channel - assert len(lavalink.node._nodes) == bot.shard_count bot.add_listener.assert_called() diff --git a/tests/test_player_manager.py b/tests/test_player_manager.py index a5282d89..af067e80 100644 --- a/tests/test_player_manager.py +++ b/tests/test_player_manager.py @@ -25,51 +25,32 @@ def func(guild_id=guild.id): def voice_state_update(bot, voice_channel): def func(user_id=bot.user.id, channel_id=voice_channel.id, guild_id=voice_channel.guild.id): return { - "t": "VOICE_STATE_UPDATE", - "s": 84, - "op": 0, - "d": { - "user_id": str(user_id), - "suppress": False, - "session_id": "744d1ac65d00e31fb7ab29fc2436be3e", - "self_video": False, - "self_mute": False, - "self_deaf": False, - "mute": False, - "guild_id": str(guild_id), - "deaf": False, - "channel_id": str(channel_id), - }, + "user_id": str(user_id), + "suppress": False, + "session_id": "744d1ac65d00e31fb7ab29fc2436be3e", + "self_video": False, + "self_mute": False, + "self_deaf": False, + "mute": False, + "guild_id": str(guild_id), + "deaf": False, + "channel_id": str(channel_id), } return func @pytest.mark.asyncio -async def test_autoconnect( - initialize_lavalink, voice_channel, voice_server_update, voice_state_update -): +async def test_autoconnect(bot, voice_channel, voice_server_update, voice_state_update): node = lavalink.node.get_node(voice_channel.guild.id) + node._players_dict[voice_channel.guild.id] = lavalink.player_manager.Player(bot, voice_channel) + player = node.get_player(voice_channel.guild.id) + assert voice_channel.guild.id in set(node.guild_ids) + server = voice_server_update() state = voice_state_update() - await node.player_manager.on_socket_response(server) - - assert voice_channel.guild.id not in set(node.player_manager.guild_ids) - - await node.player_manager.on_socket_response(state) - - send_call = { - "op": "voiceUpdate", - "guildId": str(voice_channel.guild.id), - "sessionId": "744d1ac65d00e31fb7ab29fc2436be3e", - "event": { - "token": "e5bbc4a783a1af5b", - "guild_id": str(voice_channel.guild.id), - "endpoint": "us-west43.discord.gg:80", - }, - } - - node._MOCK_send.assert_called_with(send_call) + await player.on_voice_server_update(server) + await player.on_voice_state_update(state) assert len(lavalink.all_players()) == 1 - assert lavalink.get_player(voice_channel.guild.id).channel == voice_channel + assert node.get_player(voice_channel.guild.id) == player