diff --git a/cogs/error.py b/cogs/error.py index 5c5eae8..2e86a93 100644 --- a/cogs/error.py +++ b/cogs/error.py @@ -3,8 +3,6 @@ from discord.ext import commands -# for finding errors with the code. - class Error(commands.Cog): """Basic class for catching errors and sending a message""" @@ -19,21 +17,19 @@ def __init__(self, bot): @commands.Cog.listener() async def on_command_error(self, ctx: commands.Context, error: commands.CommandError): - if isinstance(error, commands.ExpectedClosingQuoteError): - return await ctx.send(f"Pozor! Chybí tady: {error.close_quote} uvozovka!") - - elif isinstance(error, commands.MissingPermissions): - return await ctx.send("Chybí ti požadovaná práva!") + match error: + case commands.ExpectedClosingQuoteError(): + return await ctx.send(f"Pozor! Chybí tady: {error} uvozovka!") - elif isinstance(error, commands.CommandNotFound): - pass + case commands.MissingPermissions(): + return await ctx.send("Chybí ti požadovaná práva!") - elif isinstance(error, commands.MissingRequiredArgument): - return await ctx.send("Chybí ti povinný argument, zkontroluj si ho znova!") + case commands.CommandNotFound(): + pass - else: - self.logger.critical(f"{ctx.message.id}, {ctx.message.content} | {error}") - print(error) + case _: + self.logger.critical(f"{ctx.message.id}, {ctx.message.content} | {error}") + print(error) @commands.Cog.listener() async def on_command(self, ctx: commands.Context): diff --git a/cogs/morserovka.py b/cogs/morserovka.py index 8cfda13..73da9b3 100644 --- a/cogs/morserovka.py +++ b/cogs/morserovka.py @@ -1,4 +1,5 @@ -from discord import Message +import discord +from discord import Message, app_commands from discord.ext import commands @@ -34,24 +35,31 @@ class Morse(commands.Cog): def __init__(self, bot): self.bot = bot - @commands.command(aliases=["encrypt"]) - async def zasifruj(self, ctx: commands.Context, message: str) -> Message: - await ctx.message.delete() + @app_commands.command( + name="zasifruj", + description="Zašifruj text do morserovky!") + @app_commands.describe( + message="Věta nebo slovo pro zašifrování") + async def zasifruj(self, interaction: discord.Interaction, message: str) -> Message: try: cipher = "/".join(self.MORSE_CODE_DICT.get(letter.upper()) for letter in message) - return await ctx.send(cipher) + return await interaction.response.send_message(cipher) except TypeError: - return await ctx.send("Asi jsi nezadal správný text. Text musí být bez speciálních znaků!") + return await interaction.response.send_message( + "Asi jsi nezadal správný text. Text musí být bez speciálních znaků!") - @commands.command(aliases=["decrypt"]) - async def desifruj(self, ctx: commands.Context, message: str) -> Message: - await ctx.message.delete() + @app_commands.command( + name="desifruj", + description="Dešifruj text z morserovky!") + @app_commands.describe( + message="Věta nebo slovo pro dešifrování") + async def desifruj(self, interaction: discord.Interaction, message: str) -> Message: try: decipher = ''.join(self.REVERSED_MORSE_CODE_DICT.get(letter) for letter in message.split("/")) - return await ctx.send(decipher) + return await interaction.response.send_message(decipher) except TypeError: decipher = ''.join(self.REVERSED_MORSE_CODE_DICT.get(x) for x in message.split("|")) - return await ctx.send(decipher) + return await interaction.response.send_message(decipher) async def setup(bot): diff --git a/cogs/poll_command.py b/cogs/poll_command.py index 4010e8e..cbb73a7 100644 --- a/cogs/poll_command.py +++ b/cogs/poll_command.py @@ -1,4 +1,7 @@ +import discord +from discord import app_commands from discord.ext import commands +from loguru import logger from src.db_folder.databases import PollDatabase, VoteButtonDatabase from src.jachym import Jachym @@ -11,7 +14,7 @@ from src.ui.poll_view import PollView -def error_handling(answer: tuple[str]) -> str: +def error_handling(answer: list[str]) -> str: if len(answer) > Poll.MAX_OPTIONS: return f"Zadal jsi příliš mnoho odpovědí, můžeš maximálně {Poll.MAX_OPTIONS}!" elif len(answer) < Poll.MIN_OPTIONS: @@ -19,26 +22,30 @@ def error_handling(answer: tuple[str]) -> str: class PollCreate(commands.Cog): - COOLDOWN = 10 - def __init__(self, bot: Jachym): self.bot = bot - @commands.command(aliases=["anketa"]) - @commands.cooldown(1, COOLDOWN, commands.BucketType.user) - async def pool(self, ctx: commands.Context, question: str, *answer: str): - await ctx.message.delete() + @app_commands.command( + name="anketa", + description="Anketa pro hlasování. Jsou vidět všichni hlasovatelé.") + @app_commands.describe( + question="Otázka, na kterou potřebuješ vědět odpověď", + answer='Odpovědi, rozděluješ odpovědi uvozovkou ("), maximálně pouze 10 možností') + async def pool(self, interaction: discord.Interaction, question: str, answer: str) -> discord.Message: + + await interaction.response.send_message(embed=PollEmbedBase("Dělám na tom, vydrž!")) + message = await interaction.original_response() - message = await ctx.send(embed=PollEmbedBase("Dělám na tom, vydrž!")) - if error_handling(answer): - return await message.edit(embed=PollEmbedBase(error_handling(answer))) + answers = answer.split(sep='"') + if error_handling(answers): + return await message.edit(embed=PollEmbedBase(error_handling(answers))) poll = Poll( message_id=message.id, channel_id=message.channel.id, question=question, - options=answer, - user_id=ctx.message.author.id + options=answers, + user_id=interaction.user.id ) embed = PollEmbed(poll) @@ -48,14 +55,23 @@ async def pool(self, ctx: commands.Context, question: str, *answer: str): self.bot.active_discord_polls.add(poll) await self.bot.set_presence() - - await message.edit(embed=embed, view=view) + logger.info(f"Successfully added Pool - {message.id}") + return await message.edit(embed=embed, view=view) @pool.error async def pool_error(self, ctx: commands.Context, error): if isinstance(error, commands.CommandOnCooldown): await ctx.send(embed=CooldownErrorEmbed(error.retry_after)) + @commands.command() + async def anketa(self, ctx): + return await ctx.send( + "Ahoj! Tahle funkce teď už bohužel nebude fungovat :(\n" + "Ale neboj se! Do Jáchyma jsou už teď implementovány slash commands, takže místo vykříčníku teď dej /, " + "kde najdeš všechny funkce co teď Jáchym má! :)\n" + "Ještě jedna maličká věc - já jsem vyvíjený už pomalu třetí rok a můj autor by po mně chtěl, abych Ti poslal odkaz na formulář, kde by rád zpětnou vazbu na mě, jestli odvádím dobrou práci: https://forms.gle/1dFq84Ng39vdkxVQ7\n" + "Moc ti děkuji! A díky, že mě používáš! :)") + async def setup(bot): await bot.add_cog(PollCreate(bot)) diff --git a/cogs/sync_command.py b/cogs/sync_command.py new file mode 100644 index 0000000..2510c0d --- /dev/null +++ b/cogs/sync_command.py @@ -0,0 +1,76 @@ +from typing import Literal, Optional, TYPE_CHECKING + +import discord +from discord.ext import commands +from discord.ext.commands import Greedy, Context + +if TYPE_CHECKING: + from src.jachym import Jachym + + +class SyncSlashCommands(commands.Cog): + def __init__(self, bot: "Jachym"): + self.bot = bot + + @commands.command() + @commands.guild_only() + @commands.is_owner() + async def sync( + self, + ctx: Context, + guilds: Greedy[discord.Guild], + spec: Optional[Literal["-", "*", "^"]] = None) -> None: + """ + A command to sync all slash commands to servers user requires. Works like this: + !sync + global sync - syncs all slash commands with all guilds + !sync - + sync current guild + !sync * + copies all global app commands to current guild and syncs + !sync ^ + clears all commands from the current guild target and syncs (removes guild commands) + !sync id_1 id_2 + syncs guilds with id 1 and 2 + + Args: + ctx: commands.Context + guilds: Greedy[discord.Object] + spec: Optional[Literal] + + Returns: Synced slash command + + """ + + if not guilds: + if spec == "-": + synced = await self.bot.tree.sync(guild=ctx.guild) + elif spec == "*": + self.bot.tree.copy_global_to(guild=ctx.guild) + synced = await self.bot.tree.sync(guild=ctx.guild) + elif spec == "^": + self.bot.tree.clear_commands(guild=ctx.guild) + await self.bot.tree.sync(guild=ctx.guild) + synced = [] + else: + synced = await self.bot.tree.sync() + + await ctx.send( + f"Synced {len(synced)} commands {'globally' if spec is None else 'to the current guild.'}" + ) + return + + ret = 0 + for guild in guilds: + try: + await self.bot.tree.sync(guild=guild) + except discord.HTTPException: + pass + else: + ret += 1 + + await ctx.send(f"Synced the tree to {ret}/{len(guilds)}.") + + +async def setup(bot): + await bot.add_cog(SyncSlashCommands(bot)) diff --git a/cogs/utility.py b/cogs/utility.py index b8954b4..59664e3 100644 --- a/cogs/utility.py +++ b/cogs/utility.py @@ -1,6 +1,7 @@ import datetime -from discord import Message +import discord +from discord import Message, app_commands from discord.ext import commands from src.ui.embeds import EmbedFromJSON @@ -12,18 +13,24 @@ class Utility(commands.Cog): def __init__(self, bot): self.bot = bot - @commands.command(pass_context=True, aliases=['help']) - async def pomoc(self, ctx: commands.Context) -> Message: + @app_commands.command( + name="pomoc", + description="Pomocníček, který ti pomůže s různými věcmi.") + async def pomoc(self, interaction: discord.Interaction) -> Message: embed = EmbedFromJSON().add_fields_from_json("help") - return await ctx.send(embed=embed) + return await interaction.response.send_message(embed=embed, ephemeral=True) - @commands.command(pass_context=True) - async def rozcestnik(self, ctx: commands.Context) -> Message: + @app_commands.command( + name="rozcestnik", + description="Všechny věci, co skaut potřebuje. Odkazy na webové stránky.") + async def rozcestnik(self, interaction: discord.Interaction) -> Message: embed = EmbedFromJSON().add_fields_from_json("rozcestnik") - return await ctx.send(embed=embed) + return await interaction.response.send_message(embed=embed, ephemeral=True) - @commands.command(pass_context=True) - async def ping(self, ctx: commands.Context) -> Message: + @app_commands.command( + name="ping", + description="Něco trvá dlouho? Koukni se, jestli není vysoká latence") + async def ping(self, interaction: discord.Interaction) -> Message: ping = round(self.bot.latency * 1000) if ping < 200: message = f'🟢 {ping} milisekund.' @@ -32,7 +39,7 @@ async def ping(self, ctx: commands.Context) -> Message: else: message = f'🔴 {ping} milisekund.' - return await ctx.send(message) + return await interaction.response.send_message(message, ephemeral=True) @commands.command(pass_context=True, aliases=["smazat"]) @commands.has_permissions(administrator=True) diff --git a/requirements.txt b/requirements.txt index 41472f3..79fb711 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ python-dotenv==0.17.1 -aiomysql>=0.0.22 \ No newline at end of file +aiomysql>=0.0.22 +pytest>=7.3.1 +loguru>=0.7.0 \ No newline at end of file diff --git a/src/db_folder/databases.py b/src/db_folder/databases.py index b36ae10..35913aa 100644 --- a/src/db_folder/databases.py +++ b/src/db_folder/databases.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Optional, Union, AsyncIterator +from typing import Optional, AsyncIterator, TYPE_CHECKING import aiomysql import discord.errors @@ -7,24 +7,27 @@ from src.ui.poll import Poll +if TYPE_CHECKING: + from ..jachym import Jachym + class Crud(ABC): def __init__(self, poll: aiomysql.pool.Pool): self.poll = poll - async def commit_value(self, sql: str, value: tuple): + async def commit_value(self, sql: str, value: tuple) -> None: async with self.poll.acquire() as conn: cursor = await conn.cursor() await cursor.execute(sql, value) await conn.commit() - async def commit_many_values(self, sql: str, values: list[tuple]): + async def commit_many_values(self, sql: str, values: list[tuple]) -> None: async with self.poll.acquire() as conn: cursor = await conn.cursor() await cursor.executemany(sql, values) await conn.commit() - async def fetch_all_values(self, sql: str, value: Optional[tuple] = None): + async def fetch_all_values(self, sql: str, value: Optional[tuple] = None) -> list[tuple]: async with self.poll.acquire() as conn: cursor = await conn.cursor() await cursor.execute(sql, value) @@ -37,7 +40,7 @@ class PollDatabase(Crud): def __init__(self, database_poll: aiomysql.pool.Pool): super().__init__(database_poll) - async def add(self, discord_poll: Poll): + async def add(self, discord_poll: Poll) -> None: sql = "INSERT INTO `Poll`(message_id, channel_id, question, date_created_at, creator_user) " \ "VALUES (%s, %s, %s, %s, %s)" values = ( @@ -50,7 +53,7 @@ async def add(self, discord_poll: Poll): await self.commit_value(sql, values) - async def remove(self, message_id: int): + async def remove(self, message_id: int) -> None: sql = "DELETE FROM `Poll` WHERE message_id = %s" value = (message_id,) @@ -70,7 +73,7 @@ async def fetch_all_answers(self, message_id) -> tuple[str, ...]: return answers - async def fetch_all_polls(self, bot) -> AsyncIterator[Union[Poll, Message]]: + async def fetch_all_polls(self, bot: "Jachym") -> AsyncIterator[Poll | Message]: sql = "SELECT * FROM `Poll`" polls = await self.fetch_all_values(sql) diff --git a/src/helpers.py b/src/helpers.py index 36d7665..e0564ec 100644 --- a/src/helpers.py +++ b/src/helpers.py @@ -1,16 +1,17 @@ import time from functools import wraps +from loguru import logger -def timeit(func): + +def timeit(func: callable): @wraps(func) async def async_wrapper(*args, **kwargs): - print(f"{func.__name__} starting...") + logger.info(f"{func.__name__} starting...") start = time.time() result = await func(*args, **kwargs) duration = time.time() - start - print(f'{func.__name__} took {duration:.2f} seconds') - + logger.info(f'{func.__name__} took {duration:.2f} seconds') return result return async_wrapper diff --git a/src/jachym.py b/src/jachym.py index c4c77f5..8e85d5f 100644 --- a/src/jachym.py +++ b/src/jachym.py @@ -6,6 +6,7 @@ from aiomysql import create_pool from discord.ext import commands from dotenv import load_dotenv +from loguru import logger from src.db_folder.databases import PollDatabase from src.helpers import timeit @@ -17,6 +18,7 @@ class Jachym(commands.Bot): MY_BIRTHDAY = "27.12.2020" + OWNER_ID = 337971071485607936 def __init__(self): # https://discordpy.readthedocs.io/en/stable/intents.html @@ -26,7 +28,8 @@ def __init__(self): super().__init__( command_prefix=commands.when_mentioned_or("!"), intents=discord.Intents.all(), - help_command=None + help_command=None, + owner_id=self.OWNER_ID ) @timeit @@ -37,14 +40,23 @@ async def _fetch_pools_from_database(self) -> None: self.add_view(PollView(poll=poll, embed=message.embeds[0], db_poll=self.pool)) self.active_discord_polls.add(poll) - print(f"There are now {len(self.active_discord_polls)} active pools!") + logger.success(f"There are now {len(self.active_discord_polls)} active pools!") async def set_presence(self): activity_name = f"Jsem na {len(self.guilds)} serverech a mám spuštěno {len(self.active_discord_polls)} anket!" await self.change_presence(activity=discord.Game(name=activity_name)) - @commands.Cog.listener() - async def on_ready(self): + async def load_extensions(self): + for filename in listdir("cogs/"): + if filename.endswith(".py"): + try: + await self.load_extension(f"cogs.{filename[:-3]}") + logger.success(f"{filename[:-3]} has loaded successfully") + except Exception as error: + logger.error(error) + + async def setup_hook(self): + logger.info("Getting setup ready...") self.pool = await create_pool( user=getenv("USER_DATABASE"), password=getenv("PASSWORD"), @@ -54,16 +66,10 @@ async def on_ready(self): maxsize=20) await self._fetch_pools_from_database() - await self.set_presence() - print("Ready!") + logger.success("Setup ready!") - async def load_extensions(self): - for filename in listdir("cogs/"): - if filename.endswith(".py"): - try: - await self.load_extension(f"cogs.{filename[:-3]}") - - print(f"{filename[:-3]} has loaded successfully") - except Exception as error: - raise error + @commands.Cog.listener() + async def on_ready(self): + await self.set_presence() + logger.success("Bot online!") diff --git a/src/ui/button.py b/src/ui/button.py index 6f172c2..afa46bc 100644 --- a/src/ui/button.py +++ b/src/ui/button.py @@ -7,6 +7,10 @@ class ButtonBackend(discord.ui.Button): + """ + Button class to edit a poll embed with + """ + def __init__(self, custom_id: str, poll: Poll, diff --git a/src/ui/poll.py b/src/ui/poll.py index 50333b5..13b90a8 100644 --- a/src/ui/poll.py +++ b/src/ui/poll.py @@ -3,6 +3,9 @@ class Poll: + """ + Slot class for each Pool object. + """ MAX_OPTIONS = 10 MIN_OPTIONS = 2 @@ -20,7 +23,7 @@ def __init__( message_id: int, channel_id: int, question: str, - options: tuple[str, ...], + options: list[str, ...], user_id: Optional[int] = None, date_created: Optional[Union[datetime, str]] = datetime.now().strftime("%Y-%m-%d") ): @@ -44,7 +47,7 @@ def question(self) -> str: return self._question @property - def options(self) -> tuple[str, ...]: + def options(self) -> list[str, ...]: return self._options @property diff --git a/tests/test_pool.py b/tests/test_pool.py new file mode 100644 index 0000000..c0dc6bc --- /dev/null +++ b/tests/test_pool.py @@ -0,0 +1,31 @@ +import datetime + +import pytest + +from src.ui.poll import Poll + +MESSAGE_ID = 123456789 +CHANNEL_ID = 123456789 +QUESTION = "Test No. 1" +OPTIONS = ["1", "2", "3"] +USER_ID = 123456789 + + +@pytest.mark.parametrize( + "date_test", + [ + datetime.datetime.now(), + datetime.date.today(), + datetime.datetime.now().strftime("%Y-%m-%d"), + ] +) +def test_datetime(date_test): + pool = Poll( + MESSAGE_ID, + CHANNEL_ID, + QUESTION, + OPTIONS, + USER_ID, + date_test + ) + assert pool.created_at == datetime.date.today()