Skip to content

Commit

Permalink
Fix pyright warnings and ignore what we can not fix
Browse files Browse the repository at this point in the history
  • Loading branch information
uittenbroekrobbert committed May 29, 2024
1 parent 22360e3 commit fccce5c
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 34 deletions.
4 changes: 3 additions & 1 deletion tad/services/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def assign_task(self, task: Task, user: User) -> Task:
task.user_id = user.id
return self.repository.save(task)

def move_task(self, task_id: int, status_id: int, previous_sibling_id: int, next_sibling_id: int) -> Task:
def move_task(
self, task_id: int, status_id: int, previous_sibling_id: int | None = None, next_sibling_id: int | None = None
) -> Task:
"""
Updates the task with the given task_id
:param task_id: the id of the task
Expand Down
16 changes: 9 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from collections.abc import Generator
from multiprocessing import Process
from time import sleep
from typing import Any

import pytest
import uvicorn
from _pytest.fixtures import SubRequest
from fastapi.testclient import TestClient
from playwright.sync_api import sync_playwright
from playwright.sync_api import Page, Playwright, sync_playwright
from sqlmodel import Session
from tad.core.config import settings
from tad.core.db import get_engine
Expand All @@ -19,18 +21,18 @@ class TestSettings:
HTTP_SERVER_PORT: int = 8000


def run_server():
def run_server() -> None:
uvicorn.run(app, host=TestSettings.HTTP_SERVER_HOST, port=TestSettings.HTTP_SERVER_PORT)


def wait_for_server_ready(url: str, timeout: int = 30):
def wait_for_server_ready(url: str, timeout: int = 30) -> None:
# todo we can not use playwright because it gives async errors, so we need another
# wait to check the server for being up
sleep(5)


@pytest.fixture(scope="module")
def server():
def server() -> Generator[Any, Any, Any]:
# todo (robbert) use a better way to get the test database in the app configuration
os.environ["APP_DATABASE_FILE"] = "database.sqlite3.test"
process = Process(target=run_server)
Expand All @@ -45,12 +47,12 @@ def server():


@pytest.fixture(scope="session")
def get_session() -> Session:
def get_session() -> Generator[Session, Any, Any]:
with Session(get_engine()) as session:
yield session


def pytest_configure():
def pytest_configure() -> None:
"""
Called after the Session object has been created and
before performing collection and entering the run test loop.
Expand All @@ -73,7 +75,7 @@ def playwright():


@pytest.fixture(params=["chromium", "firefox", "webkit"])
def browser(playwright, request):
def browser(playwright: Playwright, request: SubRequest) -> Generator[Page, Any, Any]:
browser = getattr(playwright, request.param).launch(headless=True)
context = browser.new_context()
page = context.new_page()
Expand Down
22 changes: 11 additions & 11 deletions tests/database_test_utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
from typing import Any

from sqlalchemy import text
from sqlalchemy.engine.result import ScalarResult, TupleResult
from sqlmodel import Session, SQLModel
from tad.core.db import get_engine

# todo (robbert) it seems pytest runs all methods in the tests folder,
# this should be moved or marked as helper functions


def enrich_with_default_values(specification: dict[str, Any]) -> dict[str, Any]:
def enrich_with_default_values(specification: dict[str, str | int]) -> dict[str, str | int]:
"""
If a known table dictionary is given, like a task or status, default values will be added
and an enriched dictionary is returned.
:param specification: the dictionary to be enriched
:return: an enriched dictionary
"""
default_specification = {}
default_specification: dict[str, str | int] = {}
if specification["table"] == "task":
default_specification["title"] = "Test task " + str(specification["id"])
default_specification["description"] = "Test task description " + str(specification["id"])
Expand All @@ -43,21 +42,21 @@ def fix_missing_relations(specification: dict[str, Any]) -> None:
create_db_entries([status_specification])


def get_items(specification: dict[str, Any]) -> TupleResult | ScalarResult:
def get_items(specification: dict[str, str | int]) -> Any:
"""
Create a query based on the dictionary specification and return the result
:param specification: a dictionary with a table specification
:return: the results of the query
"""
values = ", ".join(
key + "=" + str(val) if str(val).isnumeric() else str("'" + val + "'")
for key, val in specification.items()
if key != "table"
key + "=" + str(val) if str(val).isnumeric() else str('"' + val + '"') # type: ignore
for key, val in specification.items() # type: ignore
if key != "table" # type: ignore
)
table = specification["table"]
statement = f"SELECT * FROM {table} WHERE {values}" # noqa S608
with Session(get_engine()) as session:
return session.exec(text(statement)).all()
return session.exec(text(statement)).all() # type: ignore


def item_exists(specification: dict[str, Any]) -> bool:
Expand Down Expand Up @@ -90,15 +89,16 @@ def create_db_entries(specifications: list[dict[str, Any]]) -> None:
table = specification.pop("table")
keys = ", ".join(key for key in specification)
values = ", ".join(
str(val) if str(val).isnumeric() else str("'" + val + "'") for val in specification.values()
str(val) if str(val).isnumeric() else str("'" + val + "'")
for val in specification.values() # type: ignore
)
statement = f"INSERT INTO {table} ({keys}) VALUES ({values})" # noqa S608
with Session(get_engine()) as session:
session.exec(text(statement))
session.exec(text(statement)) # type: ignore
session.commit()


def init_db(specifications=None) -> None:
def init_db(specifications: list[dict[str, str | int]] | None = None) -> None:
"""
Drop all database tables and create them. Then fill the database with the
entries from the array of dictionaries with table specifications.
Expand Down
6 changes: 3 additions & 3 deletions tests/e2e/test_move_task.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from playwright.sync_api import expect
from playwright.sync_api import Page, expect

from tests.database_test_utils import init_db


def test_move_task_to_column(server, browser) -> None:
def test_move_task_to_column(server: str, browser: Page) -> None:
"""
Test moving a task in the browser to another column and verify that after a reload
it is in the right column.
Expand Down Expand Up @@ -36,7 +36,7 @@ def test_move_task_to_column(server, browser) -> None:
expect(card).to_be_visible()


def test_move_task_order_in_same_column(server, browser) -> None:
def test_move_task_order_in_same_column(server: str, browser: Page) -> None:
"""
Test moving a task in the browser below another task and verify that after a reload
it is in the right position in the column.
Expand Down
12 changes: 7 additions & 5 deletions tests/services/test_statuses_service.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from collections.abc import Sequence
from collections.abc import Generator, Sequence
from typing import Any
from unittest.mock import patch

import pytest
from tad.models import Status
from tad.repositories.statuses import StatusesRepository
from tad.services.statuses import StatusesService


Expand All @@ -11,26 +13,26 @@ def __init__(self):
pass

def find_by_id(self, status_id: int) -> Status:
return Status(id=status_id)
return Status(id=status_id, name="Test status", sort_order=1)

def find_all(self) -> Sequence[Status]:
return [self.find_by_id(1)]


@pytest.fixture(scope="module")
def mock_statuses_repository():
def mock_statuses_repository() -> Generator[MockStatusesRepository, Any, Any]:
with patch("tad.services.statuses.StatusesRepository"):
mock_statuses_repository = MockStatusesRepository()
yield mock_statuses_repository


def test_get_status(mock_statuses_repository):
def test_get_status(mock_statuses_repository: StatusesRepository):
statuses_service = StatusesService(mock_statuses_repository)
status: Status = statuses_service.get_status(1)
assert status.id == 1


def test_get_statuses(mock_statuses_repository):
def test_get_statuses(mock_statuses_repository: StatusesRepository):
statuses_service = StatusesService(mock_statuses_repository)
statuses = statuses_service.get_statuses()
assert len(statuses) == 1
16 changes: 9 additions & 7 deletions tests/services/test_tasks_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@

class MockStatusesRepository:
def __init__(self):
self._statuses = []
self._statuses: list[Status] = []
self.reset()

def reset(self):
self._statuses.clear()
self._statuses.append(Status(id=1, name="todo", sort_order=1))
self._statuses.append(Status(id=2, name="in_progress", sort_order=1))
self._statuses.append(Status(id=3, name="review", sort_order=1))
Expand All @@ -26,10 +27,11 @@ def find_by_id(self, status_id: int) -> Status:

class MockTasksRepository:
def __init__(self):
self._tasks = []
self._tasks: list[Task] = []
self.reset()

def reset(self):
self._tasks.clear()
self._tasks.append(Task(id=1, title="Test 1", description="Description 1", status_id=1, sort_order=10))
self._tasks.append(Task(id=2, title="Test 2", description="Description 2", status_id=1, sort_order=20))
self._tasks.append(Task(id=3, title="Test 3", description="Description 3", status_id=1, sort_order=30))
Expand All @@ -46,7 +48,7 @@ def find_by_id(self, task_id: int) -> Task:
return next(filter(lambda x: x.id == task_id, self._tasks))

def save(self, task: Task) -> Task:
pass # objects are saved implicit because, there is no real repository
return task


@pytest.fixture(scope="module")
Expand All @@ -70,18 +72,18 @@ def tasks_service_with_mock(mock_tasks_repository: TasksRepository, mock_statuse
return tasks_service


def test_get_tasks(tasks_service_with_mock, mock_tasks_repository):
def test_get_tasks(tasks_service_with_mock: TasksService, mock_tasks_repository: TasksRepository):
assert len(tasks_service_with_mock.get_tasks(1)) == 3


def test_assign_task(tasks_service_with_mock, mock_tasks_repository):
def test_assign_task(tasks_service_with_mock: TasksService, mock_tasks_repository: TasksRepository):
task1: Task = mock_tasks_repository.find_by_id(1)
user1: User = User(id=1)
user1: User = User(id=1, name="User 1", avatar="none.jpg")
tasks_service_with_mock.assign_task(task1, user1)
assert task1.user_id == 1


def test_move_task(tasks_service_with_mock, mock_tasks_repository):
def test_move_task(tasks_service_with_mock: TasksService, mock_tasks_repository: MockTasksRepository):
# test changing order
mock_tasks_repository.reset()
assert mock_tasks_repository.find_by_id(1).sort_order == 10
Expand Down

0 comments on commit fccce5c

Please sign in to comment.