Skip to content

Commit

Permalink
Have separate paratemeter for template and regular database name - cl…
Browse files Browse the repository at this point in the history
…oses #672
  • Loading branch information
fizyk committed Feb 14, 2024
1 parent d821171 commit fc56165
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 29 deletions.
2 changes: 2 additions & 0 deletions newsfragments/672.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Have separate parameters for template database name and database name in DatabaseJanitor.
It'll make it much clearer to understand the code and Janitor's behaviour.
16 changes: 13 additions & 3 deletions pytest_postgresql/factories/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]:
pg_user = proc_fixture.user
pg_password = proc_fixture.password
pg_options = proc_fixture.options
pg_db = dbname or proc_fixture.dbname
pg_db = dbname
pg_template = None
if not dbname:
pg_db = proc_fixture.dbname
pg_template = f"{pg_db}_tmpl"
pg_load = load or []
if pg_load:
warnings.warn(
Expand All @@ -75,9 +79,15 @@ def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]:
),
category=DeprecationWarning,
)

with DatabaseJanitor(
pg_user, pg_host, pg_port, pg_db, proc_fixture.version, pg_password, isolation_level
user=pg_user,
host=pg_host,
port=pg_port,
dbname=pg_db,
template_dbname=pg_template,
version=proc_fixture.version,
password=pg_password,
isolation_level=isolation_level,
) as janitor:
db_connection: Connection = psycopg.connect(
dbname=pg_db,
Expand Down
2 changes: 1 addition & 1 deletion pytest_postgresql/factories/noprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def postgresql_noproc_fixture(request: FixtureRequest) -> Iterator[NoopExecutor]
user=noop_exec.user,
host=noop_exec.host,
port=noop_exec.port,
dbname=template_dbname,
template_dbname=template_dbname,
version=noop_exec.version,
password=noop_exec.password,
) as janitor:
Expand Down
2 changes: 1 addition & 1 deletion pytest_postgresql/factories/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def postgresql_proc_fixture(
user=postgresql_executor.user,
host=postgresql_executor.host,
port=postgresql_executor.port,
dbname=template_dbname,
template_dbname=template_dbname,
version=postgresql_executor.version,
password=postgresql_executor.password,
) as janitor:
Expand Down
44 changes: 23 additions & 21 deletions pytest_postgresql/janitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ def __init__(
user: str,
host: str,
port: Union[str, int],
dbname: str,
version: Union[str, float, Version], # type: ignore[valid-type]
dbname: Optional[str] = None,
template_dbname: Optional[str] = None,
password: Optional[str] = None,
isolation_level: "Optional[psycopg.IsolationLevel]" = None,
connection_timeout: int = 60,
Expand All @@ -49,7 +50,10 @@ def __init__(
self.password = password
self.host = host
self.port = port
# At least one of the dbname or template_dbname has to be filled.
assert any([dbname, template_dbname])
self.dbname = dbname
self.template_dbname = template_dbname
self._connection_timeout = connection_timeout
self.isolation_level = isolation_level
if not isinstance(version, Version):
Expand All @@ -59,36 +63,33 @@ def __init__(

def init(self) -> None:
"""Create database in postgresql."""
template_name = f"{self.dbname}_tmpl"
with self.cursor() as cur:
if self.dbname.endswith("_tmpl"):
result = False
else:
cur.execute(
"SELECT EXISTS "
"(SELECT datname FROM pg_catalog.pg_database WHERE datname= %s);",
(template_name,),
)
row = cur.fetchone()
result = (row is not None) and row[0]
if not result:
if self.is_template():
cur.execute(f'CREATE DATABASE "{self.template_dbname}";')
elif self.template_dbname is None:
cur.execute(f'CREATE DATABASE "{self.dbname}";')
else:
# All template database does not allow connection:
self._dont_datallowconn(cur, template_name)
self._dont_datallowconn(cur, self.template_dbname)
# And make sure no-one is left connected to the template database.
# Otherwise Creating database from template will fail
self._terminate_connection(cur, template_name)
cur.execute(f'CREATE DATABASE "{self.dbname}" TEMPLATE "{template_name}";')
# Otherwise, Creating database from template will fail
self._terminate_connection(cur, self.template_dbname)
cur.execute(f'CREATE DATABASE "{self.dbname}" TEMPLATE "{self.template_dbname}";')

def is_template(self) -> bool:
"""Determine whether the DatabaseJanitor maintains template or database."""
return self.dbname is None

def drop(self) -> None:
"""Drop database in postgresql."""
# We cannot drop the database while there are connections to it, so we
# terminate all connections first while not allowing new connections.
db_to_drop = self.template_dbname if self.is_template() else self.dbname
assert db_to_drop
with self.cursor() as cur:
self._dont_datallowconn(cur, self.dbname)
self._terminate_connection(cur, self.dbname)
cur.execute(f'DROP DATABASE IF EXISTS "{self.dbname}";')
self._dont_datallowconn(cur, db_to_drop)
self._terminate_connection(cur, db_to_drop)
cur.execute(f'DROP DATABASE IF EXISTS "{db_to_drop}";')

@staticmethod
def _dont_datallowconn(cur: Cursor, dbname: str) -> None:
Expand All @@ -113,12 +114,13 @@ def load(self, load: Union[Callable, str, Path]) -> None:
* a callable that expects: host, port, user, dbname and password arguments.
"""
db_to_load = self.template_dbname if self.is_template() else self.dbname
_loader = build_loader(load)
_loader(
host=self.host,
port=self.port,
user=self.user,
dbname=self.dbname,
dbname=db_to_load,
password=self.password,
)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_janitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
@pytest.mark.parametrize("version", (VERSION, 10, "10"))
def test_version_cast(version: Any) -> None:
"""Test that version is cast to Version object."""
janitor = DatabaseJanitor("user", "host", "1234", "database_name", version)
janitor = DatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=version)
assert janitor.version == VERSION


@patch("pytest_postgresql.janitor.psycopg.connect")
def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None:
"""Test that the cursor requests the postgres database."""
janitor = DatabaseJanitor("user", "host", "1234", "database_name", 10)
janitor = DatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=10)
with janitor.cursor():
connect_mock.assert_called_once_with(
dbname="postgres", user="user", password=None, host="host", port="1234"
Expand All @@ -32,7 +32,7 @@ def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None:
@patch("pytest_postgresql.janitor.psycopg.connect")
def test_cursor_connects_with_password(connect_mock: MagicMock) -> None:
"""Test that the cursor requests the postgres database."""
janitor = DatabaseJanitor("user", "host", "1234", "database_name", 10, "some_password")
janitor = DatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=10, password="some_password")
with janitor.cursor():
connect_mock.assert_called_once_with(
dbname="postgres", user="user", password="some_password", host="host", port="1234"
Expand Down

0 comments on commit fc56165

Please sign in to comment.