Skip to content

Commit

Permalink
Adds a bunch of random type casts to fix latest version of pyright
Browse files Browse the repository at this point in the history
  • Loading branch information
lexicalunit committed Nov 2, 2023
1 parent c6be156 commit e97dfbd
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 89 deletions.
17 changes: 10 additions & 7 deletions src/spellbot/models/channel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

from sqlalchemy import BigInteger, Column, DateTime, String
from sqlalchemy.orm import relationship
Expand All @@ -21,11 +21,14 @@ class Channel(Base):

__tablename__ = "channels"

xid = Column(
BigInteger,
primary_key=True,
nullable=False,
doc="The external Discord ID for a channel",
xid: int = cast(
int,
Column(
BigInteger,
primary_key=True,
nullable=False,
doc="The external Discord ID for a channel",
),
)
created_at = Column(
DateTime,
Expand Down Expand Up @@ -139,7 +142,7 @@ def to_dict(self) -> dict[str, Any]:
"guild_xid": self.guild_xid,
"name": self.name,
"default_seats": self.default_seats,
"default_format": GameFormat(self.default_format),
"default_format": GameFormat(cast(int, self.default_format)),
"auto_verify": self.auto_verify,
"unverified_only": self.unverified_only,
"verified_only": self.verified_only,
Expand Down
74 changes: 43 additions & 31 deletions src/spellbot/models/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from datetime import datetime
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, cast

import discord
from dateutil import tz
Expand Down Expand Up @@ -68,12 +68,15 @@ class Game(Base):
nullable=False,
doc="The external Discord ID of the associated guild",
)
channel_xid = Column(
BigInteger,
ForeignKey("channels.xid", ondelete="CASCADE"),
index=True,
nullable=False,
doc="The external Discord ID of the associated channel",
channel_xid: int = cast(
int,
Column(
BigInteger,
ForeignKey("channels.xid", ondelete="CASCADE"),
index=True,
nullable=False,
doc="The external Discord ID of the associated channel",
),
)
message_xid = Column(
BigInteger,
Expand All @@ -87,27 +90,36 @@ class Game(Base):
nullable=True,
doc="The external Discord ID of an associated voice channel",
)
seats = Column(
Integer,
index=True,
nullable=False,
doc="The number of seats (open or occupied) available at this game",
seats: int = cast(
int,
Column(
Integer,
index=True,
nullable=False,
doc="The number of seats (open or occupied) available at this game",
),
)
status = Column(
Integer(),
default=GameStatus.PENDING.value,
server_default=text(str(GameStatus.PENDING.value)),
index=True,
nullable=False,
doc="Pending or started status of this game",
status: int = cast(
int,
Column(
Integer(),
default=GameStatus.PENDING.value,
server_default=text(str(GameStatus.PENDING.value)),
index=True,
nullable=False,
doc="Pending or started status of this game",
),
)
format = Column(
Integer(),
default=GameFormat.COMMANDER.value,
server_default=text(str(GameFormat.COMMANDER.value)),
index=True,
nullable=False,
doc="The Magic: The Gathering format for this game",
format: int = cast(
int,
Column(
Integer(),
default=GameFormat.COMMANDER.value,
server_default=text(str(GameFormat.COMMANDER.value)),
index=True,
nullable=False,
doc="The Magic: The Gathering format for this game",
),
)
spelltable_link = Column(
String(255),
Expand Down Expand Up @@ -150,7 +162,7 @@ def player_xids(self) -> list[int]:
@property
def started_at_timestamp(self) -> int:
assert self.started_at is not None
return int(self.started_at.replace(tzinfo=tz.UTC).timestamp())
return int(cast(datetime, self.started_at).replace(tzinfo=tz.UTC).timestamp())

def show_links(self, dm: bool = False) -> bool:
return True if dm else self.guild.show_links
Expand All @@ -159,7 +171,7 @@ def show_links(self, dm: bool = False) -> bool:
def embed_title(self) -> str:
if self.status == GameStatus.STARTED.value:
return "**Your game is ready!**"
remaining = int(self.seats) - len(self.players)
remaining = int(cast(int, self.seats)) - len(self.players)
plural = "s" if remaining > 1 else ""
return f"**Waiting for {remaining} more player{plural} to join...**"

Expand Down Expand Up @@ -210,7 +222,7 @@ def placeholders(self) -> dict[str, str]:
"game_start": game_start,
}
for i, player in enumerate(self.players):
placeholders[f"player_name_{i+1}"] = player.name
placeholders[f"player_name_{i+1}"] = cast(str, player.name)
return placeholders

def apply_placeholders(self, placeholders: dict[str, str], text: str) -> str:
Expand All @@ -224,13 +236,13 @@ def embed_players(self) -> str:
for player in self.players:
points_str = ""
if self.status == GameStatus.STARTED.value:
points = player.points(self.id)
points = player.points(cast(int, self.id))
if points:
points_str = f" ({points} point{'s' if points > 1 else ''})"

power_level_str = ""
if self.status == GameStatus.PENDING.value:
config = player.config(self.guild_xid) or {}
config = player.config(cast(int, self.guild_xid)) or {}
power_level = config.get("power_level", None)
if power_level:
power_level_str = f" (power level: {power_level})"
Expand Down
8 changes: 4 additions & 4 deletions src/spellbot/models/guild.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Iterable, cast

from sqlalchemy import BigInteger, Boolean, Column, DateTime, String, false
from sqlalchemy.orm import relationship
Expand All @@ -17,7 +17,7 @@ class Guild(Base):

__tablename__ = "guilds"

xid = Column(BigInteger, primary_key=True, nullable=False)
xid: int = cast(int, Column(BigInteger, primary_key=True, nullable=False))
created_at = Column(
DateTime,
nullable=False,
Expand Down Expand Up @@ -86,11 +86,11 @@ def to_dict(self) -> dict[str, Any]:
"show_links": self.show_links,
"voice_create": self.voice_create,
"channels": sorted(
[channel.to_dict() for channel in self.channels],
[channel.to_dict() for channel in cast("Iterable[Channel]", self.channels)],
key=lambda c: c["xid"],
),
"awards": sorted(
[award.to_dict() for award in self.awards],
[award.to_dict() for award in cast("Iterable[GuildAward]", self.awards)],
key=lambda c: c["id"],
),
}
19 changes: 11 additions & 8 deletions src/spellbot/models/play.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

from sqlalchemy import BigInteger, Column, DateTime, ForeignKey, Integer

Expand Down Expand Up @@ -38,13 +38,16 @@ class Play(Base):
nullable=False,
doc="The external Discord ID of the user who played a game",
)
game_id = Column(
Integer,
ForeignKey("games.id", ondelete="CASCADE"),
primary_key=True,
nullable=False,
index=True,
doc="The SpellBot game ID of the game the user played",
game_id = cast(
int,
Column(
Integer,
ForeignKey("games.id", ondelete="CASCADE"),
primary_key=True,
nullable=False,
index=True,
doc="The SpellBot game ID of the game the user played",
),
)
points = Column(
Integer,
Expand Down
14 changes: 7 additions & 7 deletions src/spellbot/services/games.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

from datetime import datetime, timedelta
from typing import Any, Optional
from typing import Any, Optional, cast

import discord
import pytz
Expand Down Expand Up @@ -252,7 +252,7 @@ def make_ready(self, spelltable_link: Optional[str]) -> int:

if not player_xids:
DatabaseSession.commit()
return self.game.id
return cast(int, self.game.id)

# upsert into plays
DatabaseSession.execute(
Expand Down Expand Up @@ -290,7 +290,7 @@ def make_ready(self, spelltable_link: Optional[str]) -> int:
)

DatabaseSession.commit()
return self.game.id
return cast(int, self.game.id)

@sync_to_async()
@tracer.wrap()
Expand All @@ -312,7 +312,7 @@ def watch_notes(self, player_xids: list[int]) -> dict[int, Optional[str]]:
)
.all()
)
return {watch.user_xid: watch.note for watch in watched}
return {cast(int, watch.user_xid): cast(Optional[str], watch.note) for watch in watched}

@sync_to_async()
@tracer.wrap()
Expand All @@ -326,11 +326,11 @@ def set_voice(self, voice_xid: int) -> None:
def filter_blocked_list(self, author_xid: int, other_xids: list[int]) -> list[int]:
"""Given an author, filters out any blocked players from a list of others."""
users_author_has_blocked = [
row.blocked_user_xid
cast(int, row.blocked_user_xid)
for row in DatabaseSession.query(Block).filter(Block.user_xid == author_xid)
]
users_who_blocked_author_or_other = [
row.user_xid
cast(int, row.user_xid)
for row in DatabaseSession.query(Block).filter(
Block.blocked_user_xid.in_([author_xid, *other_xids]),
)
Expand Down Expand Up @@ -485,7 +485,7 @@ def message_xids(self, game_ids: list[int]) -> list[int]:
def dequeue_players(self, player_xids: list[int]) -> list[int]:
"""Removes the given players from any queues that they're in; returns changed game ids."""
queues = DatabaseSession.query(Queue).filter(Queue.user_xid.in_(player_xids)).all()
game_ids = {queue.game_id for queue in queues}
game_ids = {cast(int, queue.game_id) for queue in queues}
for queue in queues:
DatabaseSession.delete(queue)
DatabaseSession.commit()
Expand Down
6 changes: 3 additions & 3 deletions src/spellbot/services/guilds.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, Optional
from typing import Any, Optional, cast

import discord
import pytz
Expand Down Expand Up @@ -59,7 +59,7 @@ def select(self, guild_xid: int) -> bool:
@sync_to_async()
def should_voice_create(self) -> bool:
assert self.guild
return self.guild.voice_create
return cast(bool, self.guild.voice_create)

@sync_to_async()
def set_motd(self, message: Optional[str] = None) -> None:
Expand All @@ -85,7 +85,7 @@ def toggle_voice_create(self) -> None:
@sync_to_async()
def current_name(self) -> str:
assert self.guild
return self.guild.name or ""
return cast(Optional[str], self.guild.name) or ""

@sync_to_async()
def voice_category_prefixes(self) -> list[str]:
Expand Down
6 changes: 3 additions & 3 deletions src/spellbot/services/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from datetime import datetime
from typing import Any, Optional, Union
from typing import Any, Optional, Union, cast

import discord
import pytz
Expand Down Expand Up @@ -147,7 +147,7 @@ def is_banned(self, target_xid: Optional[int] = None) -> bool:
return bool(row[0]) if row else False

assert self.user
return self.user.banned
return cast(bool, self.user.banned)

@sync_to_async()
def block(self, author_xid: int, target_xid: int) -> None:
Expand Down Expand Up @@ -206,7 +206,7 @@ def unwatch(self, guild_xid: int, user_xid: int) -> None:
@sync_to_async()
def blocklist(self, user_xid: int) -> list[int]:
return [
block.blocked_user_xid
cast(int, block.blocked_user_xid)
for block in DatabaseSession.query(Block).filter(Block.user_xid == user_xid).all()
]

Expand Down
4 changes: 2 additions & 2 deletions tests/cogs/test_lfg_cog_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async def test_concurrent_lfg_requests_different_channels(self, bot: SpellBot) -
# At leat one game is out of order, this is good!
messages_out_of_order = True
break
message_xid = game.message_xid
message_xid = cast(Optional[int], game.message_xid)
assert messages_out_of_order

async def test_concurrent_lfg_requests_same_channel(
Expand Down Expand Up @@ -110,5 +110,5 @@ def get_next_message(*args: Any, **kwargs: Any) -> discord.Message:
# At leat one game is out of order, this is good!
messages_out_of_order = True
break
message_xid = game.message_xid
message_xid = cast(Optional[int], game.message_xid)
assert messages_out_of_order
2 changes: 1 addition & 1 deletion tests/mocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def mock_discord_channel(
if guild:
discord_channel.guild = guild
else:
discord_channel.guild = mock_discord_guild(channel.guild)
discord_channel.guild = mock_discord_guild(cast(Guild, channel.guild))
discord_channel.fetch_message = AsyncMock()
discord_channel.get_partial_message = MagicMock()
discord_channel.permissions_for = MagicMock()
Expand Down
Loading

0 comments on commit e97dfbd

Please sign in to comment.