diff --git a/cogs/error.py b/cogs/error.py index 3f9f4af..318fd99 100644 --- a/cogs/error.py +++ b/cogs/error.py @@ -1,25 +1,25 @@ -# import logging +import logging -from discord import Message from discord.ext import commands +# for finding errors with the code. +# TODO: SPRAV TO UŽ!!!!! + class Error(commands.Cog): """Basic class for catching errors and sending a message""" def __init__(self, bot): self.bot = bot - # TODO: spravit logger - - # self.logger = logging.getLogger('discord') - # self.logger.setLevel(logging.WARN) - # handler = logging.FileHandler(filename='discord.log', encoding='utf-8', mode='w') - # handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s:%(name)s: %(message)s')) - # self.logger.addHandler(handler) + self.logger = logging.getLogger('discord') + self.logger.setLevel(logging.WARN) + handler = logging.FileHandler(filename='discord.log', encoding='utf-8', mode='w') + handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s:%(name)s: %(message)s')) + self.logger.addHandler(handler) @commands.Cog.listener() - async def on_command_error(self, ctx: commands.Context, error: commands.CommandError) -> Message | None: + async def on_command_error(self, ctx: commands.Context, error: commands.CommandError): if isinstance(error, commands.ExpectedClosingQuoteError): return await ctx.send(f"Pozor! Chybí tady: {error.close_quote} uvozovka!") @@ -35,15 +35,11 @@ async def on_command_error(self, ctx: commands.Context, error: commands.CommandE else: # self.logger.critical(f"{ctx.message.id}, {ctx.message.content} | {error}") - return await ctx.send( - f"O této chybě ještě nevím a nebyla zaznamenána. Napiš The Xero#1273 o této chybě.\n" - f"Text chyby: `{error}`\n" - f"Číslo chyby: `{ctx.message.id}`" - ) + print(error) - # @commands.Cog.listener() - # async def on_command(self, ctx: commands.Context): - # self.logger.info(f"{ctx.message.id} {ctx.message.content}") + @commands.Cog.listener() + async def on_command(self, ctx: commands.Context): + self.logger.info(f"{ctx.message.id} {ctx.message.content}") async def setup(bot): diff --git a/cogs/event.py b/cogs/event.py deleted file mode 100644 index 3041394..0000000 --- a/cogs/event.py +++ /dev/null @@ -1,279 +0,0 @@ -import datetime -import json - -import discord -from discord import Message -from discord.ext import commands, tasks - -from db_folder.sqldatabase import AioSQL - - -class EventSystem(commands.Cog): - """Class for event system, creating pools and sending a message on exact day""" - - def __init__(self, bot): - self.bot = bot - self.pool = self.bot.pool - self.caching = set() - - self.cache.start() - self.send_events.start() - - # Pro pretty-print dnů v týdnu, páč z nějakýho důvodu ten lokální na mašině nejede jak má. Řešeno tímto. - self.weekdays = { - "Monday": "Pondělí", - "Tuesday": "Úterý", - "Wednesday": "Středa", - "Thursday": "Čtvrtek", - "Friday": "Pátek", - "Saturday": "Sobota", - "Sunday": "Neděle" - } - - # Caching systém, oproti caching systému ve poll.py se tento vždy smaže pokud je event odeslán a zpracován. - @tasks.loop(minutes=30) - async def cache(self) -> set[int, ...]: - async with AioSQL(self.pool) as db: - query = "SELECT `EventEmbedID` FROM `EventPlanner`" - tuples = await db.query(query=query) - - self.caching = { - int(clean_variable) - for variable in tuples - for clean_variable in variable} - - return self.caching - - @cache.before_loop - async def before_cache(self): - await self.bot.wait_until_ready() - - # Ověřuje databázi jestli něco není starší než dané datum a pak jej pošle. Změněno na 1 minutu, něco mi tam shazuje - # connection k databázi - @tasks.loop(minutes=1) - async def send_events(self): - async with AioSQL(self.pool) as db: - - sql = "SELECT * FROM EventPlanner;" - result = await db.query(query=sql) - - for GuildID, EventID, EventTitle, EventDescription, EventDate, ChannelID in result: - if EventDate > datetime.datetime.now(): - continue - try: - sql = "SELECT ReactionUser FROM ReactionUsers WHERE EventEmbedID = %s;" - result = await db.query(query=sql, val=(EventID,)) - - members = { - await self.bot.fetch_user(int(x)) - for ID in result - for x in ID} - - embed = discord.Embed( - title=f"**Pořádá se akce:** \n{EventTitle}", - description=f"{EventDescription}", - colour=discord.Colour.gold()) - - file = discord.File("fotky/trojuhelnik.png", filename="trojuhelnik.png") - embed.set_thumbnail(url="attachment://trojuhelnik.png") - - if len(members) == 0: - embed.add_field(name="Účastníci", value="Nikdo nejede.") - else: - embed.add_field(name="Účastníci", value=f"{','.join(user.mention for user in members)}") - - channel = self.bot.get_channel(int(ChannelID)) - await channel.send(file=file, embed=embed) - - sql2 = "DELETE FROM EventPlanner WHERE EventEmbedID = %s;" - await db.execute(sql2, (EventID,), commit=True) - - msg = await channel.fetch_message(EventID) - await msg.delete() - except discord.errors.NotFound: - sql2 = "DELETE FROM EventPlanner WHERE EventEmbedID = %s;" - await db.execute(sql2, (EventID,), commit=True) - - @send_events.before_loop - async def before_send_events(self): - await self.bot.wait_until_ready() - - # help systém pro to. - @commands.group(invoke_without_command=True) - async def udalost(self, ctx: commands.Context) -> Message: - with open("text_json/cz_text.json") as f: - test = json.load(f) - - embed = discord.Embed.from_dict(test["udalost"]) - embed.set_footer(text=self.bot.user.name, icon_url=self.bot.user.avatar_url) - - return await ctx.send(embed=embed) - - @udalost.command() - async def create(self, ctx: commands.Context, title: str, description: str, eventdatetime: str) -> Message: - - datetime_formatted = datetime.datetime.strptime(eventdatetime, '%d.%m.%Y=2022 %H:%M') - - if datetime.datetime.now() > datetime_formatted: - return await ctx.send("Nemůžeš zakládat událost, která se stala v minulosti!") - - await ctx.message.delete() - embed = discord.Embed(title=title, description=description, colour=discord.Colour.gold()) - embed.add_field(name="Datum", - value=f"{self.weekdays[datetime_formatted.strftime('%A')]}, " - f"{datetime_formatted:%d.%m.%Y %H:%M}") - - embed.add_field(name="Ano, pojedu:", value="0 |", inline=False) - embed.add_field(name="Ne, nejedu:", value="0 |", inline=False) - embed.add_field(name="Ještě nevím:", value="0 |", inline=False) - reactions = ["✅", "❌", "❓"] - - sent = await ctx.send(embed=embed) - for reaction in reactions: - await sent.add_reaction(reaction) - - sql = """INSERT INTO `EventPlanner` ( - GuildID, - EventEmbedID, - EventTitle, - EventDescription, - EventDate, - ChannelID - ) VALUES (%s, %s, %s, %s, %s, %s)""" - val = (ctx.guild.id, sent.id, title, description, datetime_formatted, ctx.channel.id) - - async with AioSQL(self.pool) as db: - await db.execute(sql, val, commit=True) - - self.caching.add(sent.id) - - @udalost.command() - async def vypis(self, ctx: commands.Context) -> Message: - sql = """ - SELECT EventTitle, EventDescription, EventDate - FROM EventPlanner - WHERE GuildID = %s - ORDER BY EventDate; """ - - async with AioSQL(self.pool) as db: - result = await db.query(query=sql, val=(ctx.guild.id,)) - embed = discord.Embed(title="Výpis všech událostí", colour=discord.Colour.gold()) - - for title, description, date in result: - embed.add_field( - name=title, - value=f"{self.weekdays[date.strftime('%A')]}, {date: %d.%m.%Y %H:%M}\n{description}", - inline=False) - - return await ctx.send(embed=embed) - - # Smaže event z databáze pomocí ID embedu. Přijít na lepší způsob? - @udalost.command(aliases=["delete"]) - async def smazat(self, ctx: commands.Context, embed: Message): - async with AioSQL(self.pool) as db: - try: - sql = "DELETE FROM EventPlanner WHERE EventEmbedID = %s;" - await db.execute(sql, (embed.id,), commit=True) - - msg = await ctx.fetch_message(embed.id) - await msg.delete() - - await ctx.send("Úspěšně smazán event") - - except discord.errors.NotFound: - await ctx.send("Zkontroluj si číslo, páč tento není v mé paměti. Možná jsi to blbě napsal?") - - # To stejné, akorát s každou reakcí se dává záznam do databáze. Nějak to vylepšit? Přijít na způsob jak to udělat - @commands.Cog.listener() - async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent) -> Message: - if payload.message_id in self.caching: - - channel = self.bot.get_channel(payload.channel_id) - message = await channel.fetch_message(payload.message_id) - embed = message.embeds[0] - reaction = discord.utils.get(message.reactions, emoji=payload.emoji.name) - - vypis_hlasu = [ - user.display_name - async for user in reaction.users() - if not user.id == self.bot.user.id - ] - - match payload.emoji.name: - case "✅": - edit = embed.set_field_at( - 1, - name="Ano, pojedu:", - value=f"{len(vypis_hlasu)} | {', '.join(vypis_hlasu)}", - inline=False) - - async with AioSQL(self.pool) as db: - sql = """ INSERT INTO `ReactionUsers` ( - EventEmbedID, - ReactionUser - ) VALUES (%s, %s)""" - val = (payload.message_id, payload.user_id) - await db.execute(sql, val, commit=True) - - return await reaction.message.edit(embed=edit) - - case "❌": - edit = embed.set_field_at( - 2, - name="Ne, nejedu:", - value=f"{len(vypis_hlasu)} | {', '.join(vypis_hlasu)}", - inline=False) - - return await reaction.message.edit(embed=edit) - - case "❓": - edit = embed.set_field_at( - 3, - name="Ještě nevím:", - value=f"{len(vypis_hlasu)} | {', '.join(vypis_hlasu)}", inline=False) - - return await reaction.message.edit(embed=edit) - - @commands.Cog.listener() - async def on_raw_reaction_remove(self, payload: discord.RawReactionActionEvent) -> Message: - if payload.message_id in self.caching: - - channel = self.bot.get_channel(payload.channel_id) - message = await channel.fetch_message(payload.message_id) - embed = message.embeds[0] - reaction = discord.utils.get(message.reactions, emoji=payload.emoji.name) - - vypis_hlasu = [user.display_name - async for user in reaction.users() - if not user.id == self.bot.user.id] - - match payload.emoji.name: - case "✅": - edit = embed.set_field_at( - 1, - name="Ano, pojedu:", - value=f"{len(vypis_hlasu)} | {', '.join(vypis_hlasu)}", - inline=False) - - return await reaction.message.edit(embed=edit) - - case "❌": - edit = embed.set_field_at( - 2, - name="Ne, nejedu:", - value=f"{len(vypis_hlasu)} | {', '.join(vypis_hlasu)}", - inline=False) - - return await reaction.message.edit(embed=edit) - - case "❓": - edit = embed.set_field_at( - 3, - name="Ještě nevím:", - value=f"{len(vypis_hlasu)} | {', '.join(vypis_hlasu)}", inline=False) - - return await reaction.message.edit(embed=edit) - - -async def setup(bot): - await bot.add_cog(EventSystem(bot)) diff --git a/cogs/newpollstyle.py b/cogs/newpollstyle.py new file mode 100644 index 0000000..33ecf00 --- /dev/null +++ b/cogs/newpollstyle.py @@ -0,0 +1,43 @@ +from discord.ext import commands + +from db_folder.sqldatabase import PollDatabase, VoteButtonDatabase +from poll_design.poll import Poll +from poll_design.poll_view import PollView +from ui.poll_embed import PollEmbed, PollEmbedBase + + +def error_handling(answer: tuple[str]) -> str: + if len(answer) > Poll.MAX_OPTIONS: + return f"Zadal jsi příliš mnoho odpovědí, můžeš maximálně {Poll.MAX_OPTIONS}!" + elif len(answer) < Poll.MIN_OPTIONS: + return f"Zadal jsi příliš málo odpovědí, můžeš alespoň {Poll.MIN_OPTIONS}!" + + +class PollCreate(commands.Cog): + def __init__(self, bot): + self.bot = bot + + @commands.command() + async def poll(self, ctx: commands.Context, question: str, *answer: str): + message = await ctx.send(embed=PollEmbedBase("Dělám na tom, vydrž!")) + if error_handling(answer): + return await message.edit(embed=PollEmbedBase(error_handling(answer))) + + poll = Poll( + message_id=message.id, + channel_id=message.channel.id, + question=question, + options=answer, + user_id=ctx.message.author.id + ) + + embed = PollEmbed(poll) + view = PollView(poll, embed, db_poll=self.bot.pool) + await PollDatabase(self.bot.pool).add(poll) + await VoteButtonDatabase(self.bot.pool).add_options(poll) + + await message.edit(embed=embed, view=view) + + +async def setup(bot): + await bot.add_cog(PollCreate(bot)) diff --git a/cogs/poll.py b/cogs/poll.py deleted file mode 100644 index bf78b1a..0000000 --- a/cogs/poll.py +++ /dev/null @@ -1,130 +0,0 @@ -import datetime - -import discord -from discord import PartialEmoji -from discord.ext import commands, tasks - -from db_folder.sqldatabase import AioSQL - - -class Poll(commands.Cog): - """Class for Poll system""" - - def __init__(self, bot): - self.bot = bot - self.pool = self.bot.pool - - self.cache.start() - self.caching = set() - - # emoji na embedu : index v embedu - self.emoji = { - PartialEmoji(name="1️⃣"): 0, - PartialEmoji(name="2️⃣"): 1, - PartialEmoji(name="3️⃣"): 2, - PartialEmoji(name="4️⃣"): 3, - PartialEmoji(name="5️⃣"): 4, - PartialEmoji(name="6️⃣"): 5, - PartialEmoji(name="7️⃣"): 6, - PartialEmoji(name="8️⃣"): 7, - PartialEmoji(name="9️⃣"): 8, - PartialEmoji(name="🔟"): 9, - } - - # RawReaction pro pool systém, automaticky rozpozná jestli někdo reaguje a dá tak odpovídající reakci na tu anketu - async def reaction_add_remove(self, payload: discord.RawReactionActionEvent) -> discord.Message: - if payload.message_id in self.caching: - channel = self.bot.get_channel(payload.channel_id) - message = await channel.fetch_message(payload.message_id) - - embed = message.embeds[0] - reaction = discord.utils.get(message.reactions, emoji=payload.emoji.name) - - # index pro edit specifického řádku v embedu - i = self.emoji[payload.emoji] - - vypis_hlasu = [ - user.display_name - async for user in reaction.users() - if not user.id == self.bot.user.id] - - edit = embed.set_field_at( - i, - name=embed.fields[i].name, - value=f"**{len(vypis_hlasu)}** | {', '.join(vypis_hlasu)}", - inline=False) - - return await reaction.message.edit(embed=edit) - - @commands.command() - async def anketa(self, ctx: commands.Context, question: str, *answer: str) -> discord.Message: - await ctx.message.delete() - - if len(answer) > 10: - return await ctx.send("Zadal jsi příliš mnoho odpovědí, maximum je 10!") - - elif len(answer) <= 10: - embed = discord.Embed( - title="📊 " + question, - timestamp=ctx.message.created_at, - color=0xff0000) - embed.set_footer(text=f"Anketu vytvořil {ctx.message.author.display_name}") - - reactions = ['1️⃣', '2️⃣', '3️⃣', '4️⃣', '5️⃣', '6️⃣', '7️⃣', '8️⃣', '9️⃣', '🔟'] - - for x, option in enumerate(answer): - embed.add_field( - name=f"{reactions[x]} {option}", - value="**0** |", - inline=False) - - embed.set_author(name=ctx.author.display_name, icon_url=ctx.author.avatar_url) - embed.set_footer(text=self.bot.user.name, icon_url=self.bot.user.avatar_url) - - sent = await ctx.send(embed=embed) - - for reaction in reactions[:len(answer)]: - await sent.add_reaction(reaction) - - async with AioSQL(self.pool) as db: - sql = "INSERT INTO `Poll`(PollID, DateOfPoll) VALUES (%s, %s)" - val = (sent.id, datetime.date.today()) - - await db.execute(sql, val, commit=True) - - self.caching.add(sent.id) - - @commands.Cog.listener() - async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent): - await self.reaction_add_remove(payload=payload) - - @commands.Cog.listener() - async def on_raw_reaction_remove(self, payload: discord.RawReactionActionEvent): - await self.reaction_add_remove(payload=payload) - - # Caching systém pro databázi, ať discord bot nebombarduje furt databázi a vše udržuje ve své paměti - @tasks.loop(minutes=30) - async def cache(self) -> set[int, ...]: - async with AioSQL(self.pool) as db: - # Query pro to, aby se každý záznam, který je starší než měsíc, smazal - query2 = "DELETE FROM `Poll` WHERE `DateOfPoll` < NOW() - INTERVAL 30 DAY" - await db.execute(query2, commit=True) - - query = "SELECT `PollID` FROM `Poll`" - tuples = await db.query(query=query) - - # ořezání všeho co tam je, předtím to bylo ve tvaru [('987234', ''..)] - self.caching = { - int(clean_variable) - for variable in tuples - for clean_variable in variable} - - return self.caching - - @cache.before_loop - async def before_cache(self): - await self.bot.wait_until_ready() - - -async def setup(bot): - await bot.add_cog(Poll(bot)) diff --git a/cogs/utility.py b/cogs/utility.py index 236570b..37a7c9f 100644 --- a/cogs/utility.py +++ b/cogs/utility.py @@ -1,11 +1,10 @@ import datetime import json -from itertools import cycle import discord from discord import Message from discord import app_commands -from discord.ext import commands, tasks +from discord.ext import commands from discord.ext.commands import has_permissions @@ -15,14 +14,8 @@ class Utility(commands.Cog): def __init__(self, bot): self.bot = bot - self.news = cycle([ - f"Jsem na {len(self.bot.guilds)} serverech!", - "Pomoc? !help", - ]) - - self.pressence.start() - - def json_to_embed(self, root_name: str) -> discord.Embed: + @staticmethod + def json_to_embed(root_name: str) -> discord.Embed: with open("text_json/cz_text.json") as f: text = json.load(f) embed = discord.Embed.from_dict(text[root_name]) @@ -71,20 +64,6 @@ async def clear(self, ctx: commands.Context, limit: int) -> Message: else: return await ctx.send("Limit musí být někde mezi 1 nebo 99!") - @tasks.loop(seconds=10) - async def pressence(self): - # proč tady? https://stackoverflow.com/questions/59126137/how-to-change-discord-py-bot-activity - - await self.bot.change_presence( - activity=discord.Game( - name=next(self.news) - ) - ) - - @pressence.before_loop - async def before_cache(self): - await self.bot.wait_until_ready() - @commands.command() async def time(self, ctx: commands.Context): return await ctx.send(str(datetime.datetime.now())) diff --git a/db_folder/__init__.py b/db_folder/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/db_folder/sqldatabase.py b/db_folder/sqldatabase.py index 0b13238..cea7fee 100644 --- a/db_folder/sqldatabase.py +++ b/db_folder/sqldatabase.py @@ -1,37 +1,143 @@ import aiomysql +from poll_design.poll import Poll -class AioSQL: - """Async context manager for aiomysql, this produces minimal reproducible results.""" - def __init__(self, bot_pool): - self.database = bot_pool - self.connection = None - self.cursor = None +# TODO: Reformat it, redundant code all over here!!! - async def __aenter__(self): - self.connection = await self.database.acquire() - self.cursor = await self.connection.cursor() +class Crud: + def __init__(self, poll: aiomysql.pool.Pool): + self.poll = poll - return self - async def __aexit__(self, exc_type, exc_value, exc_traceback): - try: - await self.cursor.close() - await self.database.release(self.connection) +class PollDatabase(Crud): + def __init__(self, database_poll: aiomysql.pool.Pool): + super().__init__(database_poll) - except aiomysql.Error: - # TODO: use logging - self.database.rollback() - return True + async def add(self, discord_poll: Poll): + sql = """ + INSERT INTO `Poll`(message_id, channel_id, question, date_created_at, creator_user) + VALUES (%s, %s, %s, %s, %s)""" + values = ( + discord_poll.message_id, + discord_poll.channel_id, + discord_poll.question, + discord_poll.date_created_at, + discord_poll.user_id + ) - async def query(self, query: str, val=None): - await self.cursor.execute(query, val or ()) + async with self.poll.acquire() as conn: + cursor = await conn.cursor() + await cursor.execute(sql, values) + await conn.commit() - return await self.cursor.fetchall() + async def remove(self, discord_poll: Poll): + sql = """ + DELETE FROM `Poll` WHERE message_id = %s; + """ + value = discord_poll.message_id - async def execute(self, query: str, val=None, commit=True): - await self.cursor.execute(query, val or ()) + async with self.poll.acquire() as conn: + cursor = await conn.cursor() + await cursor.execute(sql, value) + await conn.commit() - if commit: - await self.connection.commit() + async def fetch_all_polls(self): + sql = """SELECT * FROM `Poll`""" + + async with self.poll.acquire() as conn: + cursor = await conn.cursor() + await cursor.execute(sql) + polls = await cursor.fetchall() + + return polls + + +class VoteButtonDatabase(Crud): + def __init__(self, pool: aiomysql.pool.Pool): + super().__init__(pool) + + async def add_options(self, discord_poll: Poll): + sql = """ + INSERT INTO `VoteButtons`(message_id, answers) VALUES (%s, %s) + """ + values = [ + (discord_poll.message_id, vote_option) + for vote_option in discord_poll.options + ] + + async with self.poll.acquire() as conn: + cursor = await conn.cursor() + await cursor.executemany(sql, values) + await conn.commit() + + async def add_user(self, message_id, user, index): + sql = """ + INSERT INTO `Answers`(message_id, vote_user, iter_index) VALUES (%s, %s, %s) + """ + values = (message_id, user, index) + + async with self.poll.acquire() as conn: + cursor = await conn.cursor() + await cursor.execute(sql, values) + await conn.commit() + + async def remove_user(self, message_id, user, index): + sql = "DELETE FROM `Answers` WHERE message_id = %s AND vote_user = %s AND iter_index = %s" + value = (message_id, user, index) + + async with self.poll.acquire() as conn: + cursor = await conn.cursor() + await cursor.execute(sql, value) + await conn.commit() + + async def fetch_all_users(self, message_id, index) -> set: + sql = """ + SELECT vote_user FROM `Answers` WHERE message_id = %s AND iter_index = %s + """ + values = (message_id, index) + + async with self.poll.acquire() as conn: + cursor = await conn.cursor() + await cursor.execute(sql, values) + users_voted_for = await cursor.fetchall() + + clean_users_voted_for = set( + user + for user_tuple in users_voted_for + for user in user_tuple + ) + + return clean_users_voted_for + + async def toggle_vote(self, message_id, user, index): + users = await self.fetch_all_users(message_id, index) + + if user not in users: + await self.add_user(message_id, user, index) + else: + await self.remove_user(message_id, user, index) + + +class AnswersDatabase(Crud): + def __init__(self, pool: aiomysql.pool.Pool): + super().__init__(pool) + + async def collect_all_answers(self, message_id): + sql = """ + SELECT answers FROM `VoteButtons` WHERE message_id = %s + """ + value = (message_id,) + + async with self.poll.acquire() as conn: + cursor = await conn.cursor() + await cursor.execute(sql, value) + tuple_of_tuples_db = await cursor.fetchall() + + answers = tuple( + answer + for tupl in tuple_of_tuples_db + for answer in tupl + ) + + return answers diff --git a/main.py b/main.py index fa01a06..87e4c52 100644 --- a/main.py +++ b/main.py @@ -5,12 +5,18 @@ import asyncio import os +from typing import Optional +import aiomysql.pool import discord from aiomysql import create_pool from discord.ext import commands from dotenv import load_dotenv +from db_folder.sqldatabase import PollDatabase, AnswersDatabase +from poll_design.poll import Poll +from poll_design.poll_view import PollView + load_dotenv("password.env") DISCORD_TOKEN = os.getenv("DISCORD_TOKEN") USER = os.getenv("USER_DATABASE") @@ -18,13 +24,49 @@ HOST = os.getenv("HOST") DATABASE = os.getenv("DATABASE") -# Co jsou intents? https://discordpy.readthedocs.io/en/stable/intents.html -intents = discord.Intents.all() -bot = commands.Bot( - command_prefix=commands.when_mentioned_or("!"), - intents=intents, - help_command=None) +class Potkan_Jachym(commands.Bot): + def __init__(self): + # Co jsou intents? https://discordpy.readthedocs.io/en/stable/intents.html + intents = discord.Intents.all() + self.pool: Optional[aiomysql.pool.Pool] = None + + super().__init__( + command_prefix=commands.when_mentioned_or("!"), + intents=intents, + help_command=None + ) + + async def _fetch_polls(self): + pools_in_db = await PollDatabase(self.pool).fetch_all_polls() + + for message_id, channel_id, question, _, _ in pools_in_db: + channel = self.get_channel(channel_id) + message = await channel.fetch_message(message_id) + answer = await AnswersDatabase(self.pool).collect_all_answers(message_id) + + poll = Poll( + message_id=message_id, + channel_id=channel_id, + question=question, + options=answer) + + self.add_view(PollView(poll=poll, embed=message.embeds[0], db_poll=self.pool)) + + @commands.Cog.listener() + async def on_ready(self): + self.pool = await create_pool( + user=USER, + password=PASSWORD, + host=HOST, + db=DATABASE, + ) + await self._fetch_polls() + + print("ready!") + + +bot = Potkan_Jachym() async def load_extensions(): @@ -39,18 +81,9 @@ async def load_extensions(): async def main(): - bot.pool = await create_pool(user=USER, password=PASSWORD, host=HOST, db=DATABASE) async with bot: await load_extensions() await bot.start(DISCORD_TOKEN) -@bot.event -async def on_ready(): - guild = discord.Object(id=765657737001828393) - bot.tree.copy_global_to(guild=guild) - await bot.tree.sync(guild=discord.Object(id=765657737001828393)) - print("ready!") - - asyncio.run(main()) diff --git a/poll_design/button.py b/poll_design/button.py new file mode 100644 index 0000000..9beb704 --- /dev/null +++ b/poll_design/button.py @@ -0,0 +1,66 @@ +import aiomysql.pool +import discord + +from db_folder.sqldatabase import VoteButtonDatabase +from poll_design.poll import Poll +from ui.poll_embed import PollEmbed + + +class ButtonBackend(discord.ui.Button): + def __init__(self, + custom_id: str, + poll: Poll, + embed: PollEmbed, + index: int, + label: str, + db_poll: aiomysql.pool.Pool) -> None: + super().__init__(label=label) + self.custom_id = custom_id + self.poll = poll + self.embed = embed + self.index = index + self.db_poll = db_poll + + self.users = set() + + def button_id(self): + return self.custom_id + + def message_id(self): + return self.message_id + + def index(self): + return self.index + + async def _load_users(self): + all_users = await VoteButtonDatabase(self.db_poll).fetch_all_users( + self.poll.message_id, + self.index) + + return all_users + + async def edit_embed(self, interaction: discord.Interaction) -> discord.Embed: + users_id = await self._load_users() + members = set( + interaction.guild.get_member(user_id) + for user_id in users_id + ) + + edit = self.embed.set_field_at( + index=self.index, + name=self.embed.fields[self.index].name, + value=f"**{len(members)}** | {', '.join(member.name for member in members)}", + inline=False) + + return edit + + async def callback(self, interaction: discord.Interaction): + await VoteButtonDatabase(self.db_poll).toggle_vote( + self.poll.message_id, + interaction.user.id, + self.index + ) + + edited_embed = await self.edit_embed(interaction) + + await interaction.response.edit_message(embed=edited_embed) diff --git a/poll_design/poll.py b/poll_design/poll.py new file mode 100644 index 0000000..647315a --- /dev/null +++ b/poll_design/poll.py @@ -0,0 +1,56 @@ +from datetime import datetime +from typing import Optional + +from discord import Message + + +class Poll: + MAX_OPTIONS = 10 + MIN_OPTIONS = 2 + + def __init__( + self, + message_id: Message.id, + channel_id: int, + question: str, + options: tuple[str, ...], + user_id: Optional[int] = None + ): + self.message_id = message_id + self.channel_id = channel_id + self.question = question + self.options = options + self.date_created_at = datetime.now().strftime("%Y-%m-%d") + self.user_id = user_id + + def message_id(self) -> int: + return self.message_id + + def channel_id(self) -> int: + return self.channel_id + + def question(self) -> str: + return self.question + + def options(self) -> tuple[str, ...]: + return self.options + + def created_at(self) -> str: + return self.date_created_at + + def user_id(self) -> int: + return self.user_id + + def delete(self): + # connection to database, make new tables and view + pass + + @classmethod + async def create_poll(cls, message: Message, question: str, *answers) -> "Poll": + poll = Poll( + message_id=message.id, + channel_id=message.channel.id, + question=question, + options=answers) + + return poll diff --git a/poll_design/poll_view.py b/poll_design/poll_view.py new file mode 100644 index 0000000..49a6601 --- /dev/null +++ b/poll_design/poll_view.py @@ -0,0 +1,30 @@ +import aiomysql.pool +import discord + +from poll_design.button import ButtonBackend +from poll_design.poll import Poll + + +class PollView(discord.ui.View): + def __init__(self, poll: Poll, embed, db_poll: aiomysql.pool.Pool): + super().__init__(timeout=None) + self.poll = poll + self.embed = embed + self.db_poll = db_poll + self._add_vote_buttons() + + def _add_vote_buttons(self): + for index, option in enumerate(self.poll.options): + self.add_item(ButtonBackend( + custom_id=f"{index}:{self.poll.message_id}", + label=f"{index + 1}", + poll=self.poll, + embed=self.embed, + index=index, + db_poll=self.db_poll + )) + + +class PollInitialization(discord.ui.View): + # this class should handle initialization of all polls + pass diff --git a/ui/poll_embed.py b/ui/poll_embed.py new file mode 100644 index 0000000..3bd14b5 --- /dev/null +++ b/ui/poll_embed.py @@ -0,0 +1,28 @@ +import discord +from discord.colour import Color + +from poll_design.poll import Poll + + +class PollEmbedBase(discord.Embed): + def __init__(self, question) -> None: + super().__init__( + title=f"📊 {question}", + colour=Color.blue() + ) + + +class PollEmbed(PollEmbedBase): + def __init__(self, poll: Poll): + super().__init__(poll.question) + self.answers = poll.options + self.reactions = ['1️⃣', '2️⃣', '3️⃣', '4️⃣', '5️⃣', '6️⃣', '7️⃣', '8️⃣', '9️⃣', '🔟'] + self._add_options() + + def _add_options(self): + for index, option in enumerate(self.answers): + self.add_field( + name=f"{self.reactions[index]} {option}", + value="**0** |", + inline=False + )