Skip to content

Commit

Permalink
add static typecheck system to poetry (#114)
Browse files Browse the repository at this point in the history
* update deps for typing, type models.py

* move to pyright

* black

* fully type basic.py as example

* github action time

* fix version at 1.1.316 and specify uqcsbot folder

* switch to poetry for pyright

* clean up optionals in models.py

* fixed jimmy's comments & typed 9 more modules

* fix poetry lockfile error
  • Loading branch information
andrewj-brown authored Jun 30, 2023
1 parent 04b90fe commit e1d3c25
Show file tree
Hide file tree
Showing 23 changed files with 438 additions and 232 deletions.
33 changes: 33 additions & 0 deletions .github/workflows/typecheck.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: Static Type Check

on: [pull_request]

env:
PYTHON_VERSION: '3.10'
POETRY_VERSION: '1.4.2'

jobs:
types:
runs-on: ubuntu-latest
steps:
- name: Checkout repo
uses: actions/checkout@v3

- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: ${{ env.PYTHON_VERSION }}

- name: Install Poetry
uses: snok/install-poetry@v1
with:
version: ${{ env.POETRY_VERSION }}
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true

- name: Install dependencies
run: poetry install --no-interaction

- name: Type with pyright
run: poetry run pyright uqcsbot
287 changes: 192 additions & 95 deletions poetry.lock

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,33 @@ pytest = "^7.3.1"
pytest-datafiles = "^3.0.0"
python-dotenv = "^1.0.0"
black = "^23.3.0"
pyright = "^1.1.316"
types-requests = "^2.30.0.0"
types-beautifulsoup4 = "^4.12.0.4"
types-python-dateutil = "^2.8.19.12"
types-pytz = "^2023.3.0.0"

[build-system]
requires = ["poetry-core>=1.3.0"]
build-backend = "poetry.core.masonry.api"

[tool.pyright]
strict = ["**"]
exclude = [
"**/advent.py",
"**/bot.py",
"**/error_handler.py",
"**/events.py",
"**/gaming.py",
"**/haiku.py",
"**/member_counter.py",
"**/remindme.py",
"**/snailrace.py",
"**/starboard.py",
"**/uptime.py",
"**/whatsdue.py",
"**/working_on.py",
"**/utils/command_utils.py",
"**/utils/snailrace_utils.py",
"**/utils/uq_course_utils.py"
]
14 changes: 6 additions & 8 deletions uqcsbot/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@ async def main():
intents.members = True
intents.message_content = True

DISCORD_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")

DATABASE_URI = os.environ.get("POSTGRES_URI_BOT")
if DATABASE_URI == None:
# If the database env variable is not defined, default to SQLite in memory db.
DATABASE_URI = "sqlite:///"
if (discord_token := os.environ.get("DISCORD_BOT_TOKEN")) is None:
raise RuntimeError("Bot token is not set!")
if (database_uri := os.environ.get("POSTGRES_URI_BOT")) is None:
database_uri = "sqlite:///"

# If you need to override the allowed mentions that can be done on a per message basis, but default to off
allowed_mentions = discord.AllowedMentions.none()
Expand Down Expand Up @@ -72,11 +70,11 @@ async def main():
for cog in cogs:
await bot.load_extension(f"uqcsbot.{cog}")

db_engine = create_engine(DATABASE_URI, echo=True)
db_engine = create_engine(database_uri, echo=True)
Base.metadata.create_all(db_engine)
bot.set_db_engine(db_engine)

await bot.start(DISCORD_TOKEN)
await bot.start(discord_token)


asyncio.run(main())
24 changes: 18 additions & 6 deletions uqcsbot/advent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
from uqcsbot.bot import UQCSBot
from uqcsbot.models import AOCWinner
from uqcsbot.utils.command_utils import loading_status
from uqcsbot.utils.err_log_utils import FatalErrorWithLog

# Leaderboard API URL with placeholders for year and code.
LEADERBOARD_URL = "https://adventofcode.com/{year}/leaderboard/private/view/{code}.json"
# Session cookie (will expire in approx 30 days).
# See: https://github.com/UQComputingSociety/uqcsbot-discord/wiki/Tokens-and-Environment-Variables#aoc_session_id
SESSION_ID = os.environ.get("AOC_SESSION_ID")

# UQCS leaderboard ID.
UQCS_LEADERBOARD = 989288

# Days in Advent of Code. List of numbers 1 to 25.
ADVENT_DAYS = list(range(1, 25 + 1))

# Puzzles are unlocked at midnight EST.
EST_TIMEZONE = timezone(timedelta(hours=-5))

Expand Down Expand Up @@ -159,6 +159,10 @@ def sort_key(sort: SortMode) -> Callable[["Member"], Any]:
class Advent(commands.Cog):
CHANNEL_NAME = "contests"

# Session cookie (will expire in approx 30 days).
# See: https://github.com/UQComputingSociety/uqcsbot-discord/wiki/Tokens-and-Environment-Variables#aoc_session_id
SESSION_ID: str = ""

def __init__(self, bot: UQCSBot):
self.bot = bot
self.bot.schedule_task(
Expand All @@ -179,6 +183,13 @@ def __init__(self, bot: UQCSBot):
month=12,
)

if os.environ.get("AOC_SESSION_ID") is not None:
SESSION_ID = os.environ.get("AOC_SESSION_ID")
else:
raise FatalErrorWithLog(
bot, "Unable to find AoC session ID. Not loading advent cog."
)

def star_char(self, num_stars: int):
"""
Given a number of stars (0, 1, or 2), returns its leaderboard
Expand Down Expand Up @@ -334,7 +345,7 @@ def parse_arguments(self, argv: List[str]) -> Namespace:
def usage_error(message, *args, **kwargs):
raise ValueError(message)

parser.error = usage_error # type: ignore
parser.error = usage_error

args = parser.parse_args(argv)

Expand All @@ -343,14 +354,14 @@ def usage_error(message, *args, **kwargs):

return args

def get_leaderboard(self, year: int, code: int) -> Dict:
def get_leaderboard(self, year: int, code: int) -> Optional[Dict]:
"""
Returns a json dump of the leaderboard
"""
try:
response = requests.get(
LEADERBOARD_URL.format(year=year, code=code),
cookies={"session": SESSION_ID},
cookies={"session": self.SESSION_ID},
)
return response.json()
except ValueError as exception: # json.JSONDecodeError
Expand Down Expand Up @@ -543,4 +554,5 @@ async def advent_winners(

async def setup(bot: UQCSBot):
cog = Advent(bot)

await bot.add_cog(cog)
10 changes: 5 additions & 5 deletions uqcsbot/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ async def on_ready(self):
)

@commands.Cog.listener()
async def on_member_join(self, member):
async def on_member_join(self, member: discord.Member):
"""Member join listener"""
channel = member.guild.system_channel
if (channel := member.guild.system_channel) is None:
return
# On user joining, a system join message will appear in the system channel
# This should prevent the bot waving on a user message when #general is busy
async for msg in channel.history(limit=5):
Expand Down Expand Up @@ -83,7 +84,7 @@ def format_repo_message(self, repos: List[str]) -> str:
:param repos: list of strings of repo names
:return: a single string with a formatted message containing repo info for the given names
"""
repo_strings = []
repo_strings: List[str] = []
for potential_repo in repos:
repo_strings.append(self.find_repo(potential_repo))
return "".join(repo_strings)
Expand All @@ -98,8 +99,7 @@ async def repo_list(self, interaction: discord.Interaction):
+ self.format_repo_message(list(REPOS.keys()))
)

@repo_group.command(name="find")
@app_commands.describe(name="Name of the repo to find")
@repo_group.command(name="find", description="Name of the repo to find")
async def repo_find(self, interaction: discord.Interaction, name: str):
"""Finds a specific UQCS GitHub repository"""
await interaction.response.send_message(
Expand Down
49 changes: 38 additions & 11 deletions uqcsbot/bot.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,41 @@
import logging
import os
from typing import List
from typing import List, Optional, Tuple, Any, Callable, Coroutine

import discord
from discord.ext import commands
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from datetime import datetime
from aiohttp import web
from pytz import timezone

ADMIN_ALERTS = "admin-alerts"
"""
TODO: TYPE ISSUES IN THIS FILE:
- apscheduler has no stubs. They're planned for the 4.0 release... in the future.
- aiohttp handler witchery
"""


class UQCSBot(commands.Bot):
"""An extended bot client to add extra functionality."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self._scheduler = AsyncIOScheduler()
self.start_time = datetime.now()

def schedule_task(self, func, *args, **kwargs):
# Important channel names & constants go here
self.ADMIN_ALERTS_CNAME = "admin-alerts"
self.GENERAL_CNAME = "general"
self.BOT_TIMEZONE = timezone("Australia/Brisbane")

self.uqcs_server: discord.Guild

def schedule_task(
self, func: Callable[..., Coroutine[Any, Any, None]], *args: Any, **kwargs: Any
):
"""Schedule a function to be run at a later time. A wrapper for apscheduler add_job."""
self._scheduler.add_job(func, *args, **kwargs)

Expand All @@ -35,15 +50,17 @@ async def admin_alert(
self,
title: str,
colour: discord.Colour,
description: str = None,
footer: str = None,
fields: List[tuple] = None,
description: Optional[str] = None,
footer: Optional[str] = None,
fields: Optional[List[Tuple[str, str]]] = None,
fields_inline: bool = True,
):
"""Sends an alert to the admin channel for logging."""
admin_channel = discord.utils.get(self.uqcs_server.channels, name=ADMIN_ALERTS)
admin_channel = discord.utils.get(
self.uqcs_server.channels, name=self.ADMIN_ALERTS_CNAME
)

if admin_channel == None:
if admin_channel == None or not isinstance(admin_channel, discord.TextChannel):
return

admin_message = discord.Embed(title=title, colour=colour)
Expand Down Expand Up @@ -75,12 +92,22 @@ def handle(request):
async def on_ready(self):
"""Once the bot is loaded and has connected, run these commands first."""
self._scheduler.start()

if (user := self.user) is None:
raise RuntimeError("Ready... but not logged in!")
self.safe_user = user

logging.info(
f'Bot online and logged in: [Name="{self.user.name}", ID={self.user.id}]'
f'Bot online and logged in: [Name="{self.safe_user.id}", ID={self.safe_user.id}]'
)

# Get the UQCS server object and store it centrally
self.uqcs_server = self.get_guild(int(os.environ.get("SERVER_ID")))
if (server_id := os.environ.get("SERVER_ID")) is None:
raise RuntimeError("Server ID is not set!")
if (server := self.get_guild(int(server_id))) is None:
raise RuntimeError("Unable to find server with id {server_id}")
self.uqcs_server: discord.Guild = server

logging.info(f"Active in the {self.uqcs_server} server.")

# Sync the app comamand tree with servers.
Expand Down
2 changes: 1 addition & 1 deletion uqcsbot/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async def cat(self, interaction: discord.Interaction):
order = deque([pink, red, yellow, green, cyan, blue])
# randomly shifts starting colout
shift = randrange(0, 5)
for i in range(shift):
for _ in range(shift):
order.append(order.popleft())

cat = "\n".join(
Expand Down
20 changes: 10 additions & 10 deletions uqcsbot/cowsay.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def draw_cow(
"""

# Set the tongue if the cow is dead or if the tongue is set to True.
tongue = "U" if tongue or mood == "Dead" else " "
tongue_out = "U" if tongue or mood == "Dead" else " "

# Set the bubble connection based on whether the cow is thinking or
# speaking.
Expand All @@ -155,9 +155,9 @@ def draw_cow(

# Draw the cow.
cow = f" {bubble_connect} ^__^\n"
cow += f" {bubble_connect} ({cow_eyes})\_______\n"
cow += f" (__)\ )\/\ \n"
cow += f" {tongue} ||----w |\n"
cow += f" {bubble_connect} ({cow_eyes})\\_______\n"
cow += f" (__)\\ )\\/\\ \n"
cow += f" {tongue_out} ||----w |\n"
cow += f" || ||\n"
return cow

Expand Down Expand Up @@ -186,10 +186,10 @@ def draw_tux(
tux += f" .--. \n"
tux += f" |{tux_eyes} | \n"
tux += f" |:_/ | \n"
tux += f" // \ \ \n"
tux += f" // \\ \\ \n"
tux += f" (| | ) \n"
tux += f" /'\_ _/`\ \n"
tux += f" \___)=(___/ \n"
tux += f" /'\\_ _/`\\ \n"
tux += f" \\___)=(___/ \n"
return tux

@staticmethod
Expand All @@ -215,7 +215,7 @@ def sanitise_emotes(message: str) -> str:
"""

# Regex to match emotes.
emotes: List[str] = re.findall("<a?:\w+:\d+>", message)
emotes: List[str] = re.findall(r"<a?:\w+:\d+>", message)

# Replace each emote with its name.
for emote in emotes:
Expand All @@ -242,7 +242,7 @@ def word_wrap(message: str, wrap: int) -> List[str]:
# As requested by the audience, you can manually break lines by
# adding "\n" anywhere in the message and it will be respected.
if "\\n" in word:
parts: str = word.split("\\n", 1)
parts = word.split("\\n", 1)

# The `\n` is by itself, so start a new line.
if parts[0] == "" and parts[1] == "":
Expand Down Expand Up @@ -271,7 +271,7 @@ def word_wrap(message: str, wrap: int) -> List[str]:
# the list of words to be processed.
if len(word) > wrap:
# Cut the word to the remaining space on the line.
cut_word: str = word[: (wrap - len(line))]
cut_word = word[: (wrap - len(line))]

# Add the rest of the word to the list of words to be processed.
words.insert(index, word[len(cut_word) :])
Expand Down
4 changes: 4 additions & 0 deletions uqcsbot/error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from discord.ext.commands.errors import MissingRequiredArgument
import logging

"""
TODO: this is bundled with advent.py and should be removed.
"""


class ErrorHandler(commands.Cog):
@commands.Cog.listener()
Expand Down
Loading

0 comments on commit e1d3c25

Please sign in to comment.