Skip to content

Commit

Permalink
Use Enum for database procedures & refactor
Browse files Browse the repository at this point in the history
Allows for better control over procedures and removing the hacky "returns" parameter of dbcallprocedure.

By introducing types of procedures (ones that return bool, int, or no values), introducing fitting function overloads which only accept BoolProcedures but also return a bool value, for example, was simple.

Allows for better extensibility in case of future types of procedures with other return types, or such that return a list of values.
  • Loading branch information
MajorTanya committed Mar 14, 2023
1 parent cbfbe09 commit 2fc646f
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 20 deletions.
15 changes: 7 additions & 8 deletions DTbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from discord import app_commands
from discord.ext import commands

from util.utils import checkdbforuser, dbcallprocedure
from util.utils import DBProcedure, checkdbforuser, dbcallprocedure

intents = discord.Intents.default()
intents.members = True
Expand Down Expand Up @@ -57,7 +57,7 @@ async def setup_hook(self) -> None:
await self.tree.sync()

async def on_guild_join(self, guild: discord.Guild):
dbcallprocedure(self.db_cnx, 'AddNewServer', params=(guild.id, guild.member_count))
dbcallprocedure(self.db_cnx, DBProcedure.AddNewServer, params=(guild.id, guild.member_count))

async def on_message(self, message: discord.Message):
if (message.author == self.user) or message.author.bot:
Expand All @@ -67,15 +67,14 @@ async def on_message(self, message: discord.Message):
finally:
pass

async def on_app_command_completion(self, interaction: discord.Interaction, command: app_commands.Command):
result = dbcallprocedure(self.db_cnx, 'CheckAppCommandExist', returns=True,
params=(command.qualified_name, '@res'))
async def on_app_command_completion(self, _: discord.Interaction, command: app_commands.Command):
result = dbcallprocedure(self.db_cnx, DBProcedure.CheckAppCommandExist, params=(command.qualified_name, '@res'))
if result:
dbcallprocedure(self.db_cnx, 'IncrementAppCommandUsage', params=(command.qualified_name,))
dbcallprocedure(self.db_cnx, DBProcedure.IncrementAppCommandUsage, params=(command.qualified_name,))
else:
dbcallprocedure(self.db_cnx, 'AddNewAppCommand', params=(command.qualified_name,))
dbcallprocedure(self.db_cnx, DBProcedure.AddNewAppCommand, params=(command.qualified_name,))
# because the command was used this one time, we increment the default value (0) by 1
dbcallprocedure(self.db_cnx, 'IncrementAppCommandUsage', params=(command.qualified_name,))
dbcallprocedure(self.db_cnx, DBProcedure.IncrementAppCommandUsage, params=(command.qualified_name,))

async def on_ready(self):
# online confimation
Expand Down
4 changes: 2 additions & 2 deletions dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from discord.ext import commands, tasks

from DTbot import DTbot
from util.utils import dbcallprocedure
from util.utils import DBProcedure, dbcallprocedure


@app_commands.guilds(DTbot.DEV_GUILD)
Expand Down Expand Up @@ -188,7 +188,7 @@ async def shutdownbot(self, interaction: discord.Interaction, passcode: str):
async def refreshservers(self, interaction: discord.Interaction):
await interaction.response.defer(ephemeral=True)
for guild in self.bot.guilds:
dbcallprocedure(self.bot.db_cnx, 'AddNewServer', params=(guild.id, guild.member_count))
dbcallprocedure(self.bot.db_cnx, DBProcedure.AddNewServer, params=(guild.id, guild.member_count))
await interaction.followup.send('Server list refreshed', ephemeral=True)

async def sync(self, *, dev_sync: bool = False, global_sync: bool = False):
Expand Down
4 changes: 2 additions & 2 deletions general.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from linklist import changelog_link
from util.AniListMediaQuery import AniListMediaQuery
from util.PaginatorSession import PaginatorSession
from util.utils import dbcallprocedure, even_out_embed_fields
from util.utils import DBProcedure, dbcallprocedure, even_out_embed_fields

anilist_cooldown = app_commands.Cooldown(80, 60)

Expand Down Expand Up @@ -264,7 +264,7 @@ async def xp(self, interaction: discord.Interaction, user: discord.Member | disc
if user.bot:
return await interaction.response.send_message("Bots don't get XP. :robot:")
await interaction.response.defer()
xp = dbcallprocedure(self.bot.db_cnx, 'GetUserXp', returns=True, params=(user.id, '@res'))
xp = dbcallprocedure(self.bot.db_cnx, DBProcedure.GetUserXp, params=(user.id, '@res'))
if xp > 0:
await interaction.followup.send(f"**{user.display_name}** has `{xp}` XP.")
else:
Expand Down
96 changes: 88 additions & 8 deletions util/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,97 @@
import enum
import random
import time
import typing

import discord
import mariadb
import mariadb # type: ignore

TProcedures = typing.Literal['AddNewAppCommand', 'CheckAppCommandExist', 'IncrementAppCommandUsage',
'AddNewUser', 'CheckUserExist', 'CheckXPTime', 'GetUserXp', 'IncreaseXP', 'AddNewServer']

class DBProcedure(enum.StrEnum):
GetUserXp = 'GetUserXp'
CheckXPTime = 'CheckXPTime'
CheckAppCommandExist = 'CheckAppCommandExist'
CheckUserExist = 'CheckUserExist'
AddNewAppCommand = 'AddNewAppCommand'
IncrementAppCommandUsage = 'IncrementAppCommandUsage'
AddNewUser = 'AddNewUser'
IncreaseXP = 'IncreaseXP'
AddNewServer = 'AddNewServer'

def dbcallprocedure(pool: mariadb.ConnectionPool, procedure: TProcedures, *, returns: bool = False, params: tuple = ()):
@classmethod
def bool_procedures(cls) -> list[typing.Self]:
return [DBProcedure.CheckUserExist, DBProcedure.CheckAppCommandExist]

@classmethod
def int_procedures(cls) -> list[typing.Self]:
return [DBProcedure.GetUserXp, DBProcedure.CheckXPTime]

@classmethod
def returning_procedures(cls) -> list[typing.Self]:
return [*DBProcedure.bool_procedures(), *DBProcedure.int_procedures()]

@classmethod
def non_returning_procedures(cls) -> list[typing.Self]:
return [
DBProcedure.AddNewAppCommand,
DBProcedure.IncrementAppCommandUsage,
DBProcedure.AddNewUser,
DBProcedure.IncreaseXP,
DBProcedure.AddNewServer,
]


_BoolProcedures = typing.Literal[DBProcedure.CheckUserExist, DBProcedure.CheckAppCommandExist]
_IntProcedures = typing.Literal[DBProcedure.GetUserXp, DBProcedure.CheckXPTime]
_NoReturnProcedures = typing.Literal[
DBProcedure.AddNewAppCommand,
DBProcedure.IncrementAppCommandUsage,
DBProcedure.AddNewUser,
DBProcedure.IncreaseXP,
DBProcedure.AddNewServer,
]
_ReturnProcedures = typing.Literal[
DBProcedure.CheckUserExist,
DBProcedure.CheckAppCommandExist,
DBProcedure.GetUserXp,
DBProcedure.CheckXPTime,
]


@typing.overload
def dbcallprocedure(pool: mariadb.ConnectionPool, procedure: _BoolProcedures, *,
params: tuple[typing.Any, ...] = ()) -> bool:
...


@typing.overload
def dbcallprocedure(pool: mariadb.ConnectionPool, procedure: _IntProcedures, *,
params: tuple[typing.Any, ...] = ()) -> int:
...


@typing.overload
def dbcallprocedure(pool: mariadb.ConnectionPool, procedure: _NoReturnProcedures, *,
params: tuple[typing.Any, ...] = ()) -> None:
...


def dbcallprocedure(pool: mariadb.ConnectionPool, procedure: DBProcedure, *,
params: tuple[typing.Any, ...] = ()) -> bool | int | None:
"""Calls a stored procedure with the given parameters.
Parameters
-----------
pool : mariadb.ConnectionPool
The mariadb.ConnectionPool to get a connection from
procedure : DBProcedure
The Stored Procedure to call
params : tuple
A tuple of parameters to supply to the Stored Procedure (Default: ())
"""
pconn: mariadb.Connection
result = None
returns = procedure in DBProcedure.returning_procedures()
with pool.get_connection() as pconn:
with pconn.cursor() as cursor:
cursor.callproc(procedure, params)
Expand All @@ -24,17 +104,17 @@ def dbcallprocedure(pool: mariadb.ConnectionPool, procedure: TProcedures, *, ret


def checkdbforuser(pool: mariadb.ConnectionPool, message: discord.Message):
result = dbcallprocedure(pool, 'CheckUserExist', returns=True, params=(message.author.id, '@res'))
result = dbcallprocedure(pool, DBProcedure.CheckUserExist, params=(message.author.id, '@res'))
if result:
# entry for this user ID exists, proceed to check for last XP gain time, possibly awarding some new XP
last_xp_gain = dbcallprocedure(pool, 'CheckXPTime', returns=True, params=(message.author.id, '@res'))
last_xp_gain = dbcallprocedure(pool, DBProcedure.CheckXPTime, params=(message.author.id, '@res'))
unix_now = int(time.time())
if unix_now - last_xp_gain > 120:
# user got XP more than two minutes ago, award between 15 and 25 XP and update last XP gain time
dbcallprocedure(pool, 'IncreaseXP', params=(message.author.id, random.randint(15, 25), unix_now))
dbcallprocedure(pool, DBProcedure.IncreaseXP, params=(message.author.id, random.randint(15, 25), unix_now))
else:
# user is unknown to the database, add it with user ID and default in the other fields
dbcallprocedure(pool, 'AddNewUser', params=(message.author.id,))
dbcallprocedure(pool, DBProcedure.AddNewUser, params=(message.author.id,))


def even_out_embed_fields(embed: discord.Embed):
Expand Down

0 comments on commit 2fc646f

Please sign in to comment.