From 19124848b68e7b1d5f793a479cc7533e50393b39 Mon Sep 17 00:00:00 2001 From: DanCardin Date: Wed, 3 Apr 2024 12:21:56 -0400 Subject: [PATCH] feat: Improve default postgres driver selection heuristics. --- docs/source/postgres.rst | 12 +++---- pyproject.toml | 2 +- .../container/postgres.py | 34 ++++++++++++++++--- .../container/redshift.py | 2 +- 4 files changed, 37 insertions(+), 13 deletions(-) diff --git a/docs/source/postgres.rst b/docs/source/postgres.rst index 882b493..225fc9e 100644 --- a/docs/source/postgres.rst +++ b/docs/source/postgres.rst @@ -7,13 +7,13 @@ Postgres and `asyncpg` for async fixtures. If you want to use a different driver, you can configure the `drivername` field using the `pmr_postgres_config` fixture: - ```python - from pytest_mock_resources import PostgresConfig + .. code-block:: python - @pytest.fixture - def pmr_postgres_config(): - PostgresConfig(drivername='postgresql+psycopg2') # but whatever drivername you require. - ``` + from pytest_mock_resources import PostgresConfig + + @pytest.fixture(scope='session') + def pmr_postgres_config(): + return PostgresConfig(drivername='postgresql+psycopg2') # but whatever driver you require. Note however, that the `asyncpg` driver **only** works with the async fixture, and the `psycopg2` driver **only** works with the synchronous fixture. These are inherent diff --git a/pyproject.toml b/pyproject.toml index 6ece8a3..4c1e8a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pytest-mock-resources" -version = "2.10.1" +version = "2.10.2" description = "A pytest plugin for easily instantiating reproducible mock resources." authors = [ "Omar Khan ", diff --git a/src/pytest_mock_resources/container/postgres.py b/src/pytest_mock_resources/container/postgres.py index cfb9e31..8cd6639 100644 --- a/src/pytest_mock_resources/container/postgres.py +++ b/src/pytest_mock_resources/container/postgres.py @@ -1,4 +1,5 @@ -from typing import ClassVar, Iterable +import sys +from typing import ClassVar, Iterable, Optional import sqlalchemy import sqlalchemy.exc @@ -7,6 +8,11 @@ from pytest_mock_resources.config import DockerContainerConfig, fallback from pytest_mock_resources.container.base import ContainerCheckFailed +if sys.version_info < (3, 8): + from importlib_metadata import Distribution +else: + from importlib.metadata import Distribution + class PostgresConfig(DockerContainerConfig): """Define the configuration object for postgres. @@ -48,7 +54,7 @@ class PostgresConfig(DockerContainerConfig): "username": "user", "password": "password", "root_database": "dev", - "drivername": "postgresql+psycopg2", + "drivername": None, } @fallback @@ -94,9 +100,7 @@ def check_fn(self): def get_sqlalchemy_engine(config, database_name, async_=False, autocommit=False, **engine_kwargs): # For backwards compatibility, our hardcoded default is psycopg2, and async fixtures # will not work with psycopg2, so we instead swap the default to the preferred async driver. - drivername = config.drivername - if async_ and drivername.endswith("psycopg2"): - drivername = drivername.replace("psycopg2", "asyncpg") + drivername = detect_driver(config.drivername, async_=async_) url = URL( drivername=drivername, @@ -118,3 +122,23 @@ def get_sqlalchemy_engine(config, database_name, async_=False, autocommit=False, engine = sqlalchemy.create_engine(url, **engine_kwargs) return engine + + +def detect_driver(drivername: Optional[str] = None, async_: bool = False) -> str: + if drivername: + return drivername + + if any(Distribution.discover(name="psycopg")): + return "postgresql+psycopg" + + if async_: + if any(Distribution.discover(name="asyncpg")): + return "postgresql+asyncpg" + else: + if any(Distribution.discover(name="psycopg2")): + return "postgresql+psycopg2" + + raise ValueError( # pragma: no cover + "No suitable driver found for Postgres. Please install `psycopg`, `psycopg2`, " + "`asyncpg`, or explicitly configure the `drivername=` field of `PostgresConfig`." + ) diff --git a/src/pytest_mock_resources/container/redshift.py b/src/pytest_mock_resources/container/redshift.py index 2aa36b0..799d2ed 100644 --- a/src/pytest_mock_resources/container/redshift.py +++ b/src/pytest_mock_resources/container/redshift.py @@ -43,5 +43,5 @@ class RedshiftConfig(PostgresConfig): "username": "user", "password": "password", "root_database": "dev", - "drivername": "postgresql+psycopg2", + "drivername": None, }