Skip to content

Commit

Permalink
Merge pull request #84 from schireson/dc/ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
DanCardin authored May 23, 2023
2 parents d308a19 + f7fcf1f commit d1cbea3
Show file tree
Hide file tree
Showing 22 changed files with 898 additions and 1,083 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
- name: Install poetry
uses: abatilo/[email protected]
with:
poetry-version: "1.2"
poetry-version: 1.2.2

- name: Set up cache
uses: actions/cache@v3
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- name: Run image
uses: abatilo/[email protected]
with:
poetry-version: 1.2.0
poetry-version: 1.2.2

- name: Publish
env:
Expand Down
7 changes: 2 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,12 @@ test:
coverage xml

lint:
flake8 src tests || exit 1
isort --check-only src tests || exit 1
pydocstyle src tests || exit 1
ruff src tests || exit 1
black --check src tests || exit 1
mypy src tests || exit 1
bandit -r src --skip B101 || exit 11

format:
isort src tests
ruff --fix src tests
black src tests

publish: build
Expand Down
1,653 changes: 728 additions & 925 deletions poetry.lock

Large diffs are not rendered by default.

43 changes: 27 additions & 16 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pytest-alembic"
version = "0.10.4"
version = "0.10.5"
description = "A pytest plugin for verifying alembic migrations."
authors = [
"Dan Cardin <[email protected]>",
Expand All @@ -25,33 +25,22 @@ alembic = "*"
sqlalchemy = "*"

[tool.poetry.dev-dependencies]
bandit = "*"
asyncpg = "*"
black = {version = "22.3.0", python = ">=3.6.2"}
coverage = {version = ">=6.4.4", extras = ["toml"], python = ">=3.7"}
flake8 = ">=3.9"
isort = ">=5"
greenlet = "*"
mypy = {version = ">=0.900", python = ">=3.5"}
psycopg2-binary = "*"
pydocstyle = {version = "*", python = ">=3.5"}
pytest = {version = ">=6.2"}
pytest-asyncio = "*"
pytest-mock-resources = {version = ">=2.6.3", extras = ["docker"], python = ">=3.7"}
types-dataclasses = "^0.1.7"
ruff = {version = '0.0.269', python = ">3.7"}
sqlalchemy = {version = ">=1.4", extras = ["asyncio"]}
greenlet = "*"
asyncpg = "*"
types-dataclasses = "^0.1.7"

[tool.poetry.plugins.pytest11]
pytest_alembic = "pytest_alembic.plugin"

[tool.isort]
profile = 'black'
known_first_party = 'pytest_alembic,tests'
line_length = 100
float_to_top = true
order_by_type = false
use_parentheses = true

[tool.black]
line_length = 100

Expand All @@ -64,6 +53,28 @@ exclude_lines = [
"if __name__ == .__main__.:",
]

[tool.ruff]
src = ["src", "tests"]
target-version = "py37"
line-length = 100
select = [
"A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ARG", "BLE",
"DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH",
"PIE", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID",
"TRY", "UP", "YTT"
]
ignore = ["E501", "S101", "D1", "TRY200", "B007", "B904"]

[tool.ruff.isort]
order-by-type = false

[tool.ruff.per-file-ignores]
"src/pytest_alembic/tests/**/*.py" = ["S101", "BLE001"]
"**/tests/**/*.py" = ["D", "S", "N801", "N802", "N806", "T201", "E501"]

[tool.ruff.pydocstyle]
convention = "google"

[tool.coverage.run]
source = ["src"]
branch = true
Expand Down
6 changes: 0 additions & 6 deletions setup.cfg

This file was deleted.

9 changes: 4 additions & 5 deletions src/pytest_alembic/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union

import alembic.config

if TYPE_CHECKING:
from pytest_alembic.revision_data import RevisionSpec


@dataclass
class Config:
Expand Down Expand Up @@ -148,7 +151,3 @@ def duplicate_alembic_config(config: alembic.config.Config):
config_args=config.config_args,
attributes=config.attributes,
)


# isort: split
from pytest_alembic.revision_data import RevisionSpec # noqa
31 changes: 13 additions & 18 deletions src/pytest_alembic/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@ def upgrade(self, revision):
"""Upgrade to the given `revision`."""

def upgrade(rev, _):
return self.script._upgrade_revs(revision, rev)
return self.script._upgrade_revs(revision, rev) # noqa: SLF001

self._run_env(upgrade, revision)

def downgrade(self, revision):
"""Downgrade to the given `revision`."""

def downgrade(rev, _):
return self.script._downgrade_revs(revision, rev)
return self.script._downgrade_revs(revision, rev) # noqa: SLF001

self._run_env(downgrade, revision)

Expand Down Expand Up @@ -134,16 +134,13 @@ def table_insert(
table = _tablename or tablename

if table is None:
raise ValueError(
"No table name provided as either `table` argument, or '__tablename__' key in `data`."
)
message = "No table name provided as either `table` argument, or '__tablename__' key in `data`."
raise ValueError(message)

try:
with contextlib.suppress(ValueError):
# Attempt to parse the schema out of the tablename
schema, table = table.split(".", 1)
except ValueError:
# However, if it doesn't work, both `table` and `schema` are in scope, so failure is fine.
pass
schema, table = table.split(".", 1)

table = self.table(revision, table, schema=schema, connection=connection)
values = {k: v for k, v in item.items() if k != "__tablename__"}
Expand All @@ -163,7 +160,7 @@ def run_task(self, fn, **kwargs):
try:
from sqlalchemy.ext.asyncio import AsyncEngine
except ImportError: # pragma: no cover
AsyncEngine = None
AsyncEngine = None # noqa: N806

if AsyncEngine and isinstance(self.connection, AsyncEngine):
import asyncio
Expand All @@ -177,11 +174,9 @@ async def run(engine):
return result

return asyncio.run(run(self.connection))
else:
if isinstance(self.connection, Engine):
with self.connection.connect() as connection:
result = fn(connection=connection, **kwargs)
else:
result = fn(connection=self.connection, **kwargs)

return result

if isinstance(self.connection, Engine):
with self.connection.connect() as connection:
return fn(connection=connection, **kwargs)

return fn(connection=self.connection, **kwargs)
7 changes: 4 additions & 3 deletions src/pytest_alembic/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@dataclass
class AlembicHistory:
map: RevisionMap
map: RevisionMap # noqa: A003
revisions: List[str]
revision_indices: Dict[str, int]
revisions_by_index: Dict[int, str]
Expand Down Expand Up @@ -41,7 +41,8 @@ def validate_revision(self, revision):
revision = "heads"

if revision not in self.revision_indices:
raise ValueError(f"Revision {revision} is not a valid revision in alembic's history")
message = f"Revision {revision} is not a valid revision in alembic's history"
raise ValueError(message)
return revision

def previous_revision(self, revision: str) -> Optional[str]:
Expand All @@ -64,7 +65,7 @@ def revision_range(self, current_revision: str, dest_revision: str) -> List[str]
def revision_window(self, current_revision: str, dest_revision: str) -> List[Tuple[str, str]]:
revision_range = self.revision_range(current_revision, dest_revision)
return list(
zip( # type: ignore
zip( # type: ignore[arg-type]
*(
collections.deque(itertools.islice(it, i), 0) or it
for i, it in enumerate(itertools.tee(revision_range, 2))
Expand Down
2 changes: 1 addition & 1 deletion src/pytest_alembic/plugin/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List


class AlembicTestFailure(AssertionError):
class AlembicTestFailure(AssertionError): # noqa: N818
def __init__(self, message, context=None):
super().__init__(message)
self.context = context
Expand Down
12 changes: 6 additions & 6 deletions src/pytest_alembic/plugin/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ def create_alembic_fixture(raw_config=None):
>>> alembic = create_alembic_fixture({'file': 'migrations.ini'})
"""

@pytest.fixture
def _(alembic_engine):
@pytest.fixture()
def alembic_fixture(alembic_engine):
config = Config.from_raw_config(raw_config)
with pytest_alembic.runner(config=config, engine=alembic_engine) as runner:
yield runner

return _
return alembic_fixture


@pytest.fixture
@pytest.fixture()
def alembic_runner(alembic_config, alembic_engine):
"""Produce the primary alembic migration context in which to execute alembic tests.
Expand All @@ -66,7 +66,7 @@ def alembic_runner(alembic_config, alembic_engine):
yield runner


@pytest.fixture
@pytest.fixture()
def alembic_config() -> Union[Dict[str, Any], alembic.config.Config, Config]:
"""Override this fixture to configure the exact alembic context setup required.
Expand Down Expand Up @@ -119,7 +119,7 @@ def alembic_config() -> Union[Dict[str, Any], alembic.config.Config, Config]:
return {}


@pytest.fixture
@pytest.fixture()
def alembic_engine():
"""Override this fixture to provide pytest-alembic powered tests with a database handle."""
return sqlalchemy.create_engine("sqlite:///")
18 changes: 10 additions & 8 deletions src/pytest_alembic/plugin/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@ class PytestAlembicPlugin:
# way to support both <7 and >=7 without weird nonsense like this.
if pytest_version_tuple and pytest_version_tuple[0] >= 7:

def pytest_collect_file(self, file_path, path, parent):
def pytest_collect_file(self, file_path, path, parent): # noqa: ARG002
if self.should_register(file_path):
return TestCollector.from_parent(parent, path=file_path)
return None

else:

def pytest_collect_file(self, path, parent): # type: ignore
def pytest_collect_file(self, path, parent): # type: ignore[misc]
if self.should_register(Path(path)):
return TestCollector.from_parent(parent, fspath=path)
return None

def should_register(self, path):
tests_path = PurePath(
Expand All @@ -35,10 +37,9 @@ def should_register(self, path):
or "tests/conftest.py"
)
relative_path = path.relative_to(self.config.rootpath)
if relative_path == tests_path:
if not self.registered:
self.registered = True
return True
if relative_path == tests_path and not self.registered:
self.registered = True
return True

return False

Expand Down Expand Up @@ -122,7 +123,7 @@ class OptionResolver:
excluded_tests: Optional[List[str]] = None

@classmethod
def collect_test_definitions(cls, default=True, experimental=True):
def collect_test_definitions(cls, *, default=True, experimental=True): # noqa: ARG003
import pytest_alembic.tests
import pytest_alembic.tests.experimental

Expand Down Expand Up @@ -197,7 +198,8 @@ def tests(self):

if invalid_tests:
invalid_str = ", ".join(sorted(invalid_tests))
raise ValueError(f"The following tests were unrecognized: {invalid_str}")
message = f"The following tests were unrecognized: {invalid_str}"
raise ValueError(message)

return [self.available_tests[t] for t in selected_tests]

Expand Down
22 changes: 10 additions & 12 deletions src/pytest_alembic/revision_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from dataclasses import dataclass
from typing import Dict, List, Union
from typing import Dict, List, TYPE_CHECKING, Union

if TYPE_CHECKING:
from pytest_alembic.config import Config


@dataclass
Expand Down Expand Up @@ -39,23 +42,18 @@ def from_config(cls, config: "Config"):
at_revision_data=RevisionSpec.parse(config.at_revision_data),
)

def get(self, revision: str, revision_data: Union[Dict, List[Dict]]):
def get(self, revision_data: Union[Dict, List[Dict]]):
if isinstance(revision_data, Dict):
yield revision_data
else:
for item in revision_data:
yield item
yield from revision_data

def get_before(self, revision: str) -> Union[Dict, List[Dict]]:
def get_before(self, revision: str) -> List[Dict]:
"""Yield the individual data insertions which should occur before the given revision."""
before_revision_data = self.before_revision_data.get(revision)
return self.get(revision, before_revision_data)
return list(self.get(before_revision_data))

def get_at(self, revision: str) -> Union[Dict, List[Dict]]:
"""Yield the individual data insertions which should occur upon reaching the given revision."""
"""Yield individual data insertions which should occur upon reaching the given revision."""
at_revision_data = self.at_revision_data.get(revision)
return self.get(revision, at_revision_data)


# isort: split
from pytest_alembic.config import Config # noqa
return list(self.get(at_revision_data))
Loading

0 comments on commit d1cbea3

Please sign in to comment.