Skip to content

Commit

Permalink
Move from psycopg2 -> psycopg3 (#98)
Browse files Browse the repository at this point in the history
This PR moves our PG driver from psycopg2 to psycopg3. We are changing
this because psycopg3 supports asyncio while psycopg2 does not.

Generally, the changes consisted of:
* changing the driver name (URL scheme) for db URLs from `postgresql` to
`postgresql+psycopg` (note, the package name for psycopg3 is `psycopg`)
* changing PG error code field name from `pgcode` to `sqlstate`

Assorted other changes
* changed the psycopg2-specific code in `_notification_listener` to the
psycopg3 API syntax
* clean up exception handling in cases where we are re-raising a caught
exception
* change CI/CD workflow to complete the entire python test matrix
instead of failing fast on first error

---------

Co-authored-by: Qian Li <[email protected]>
  • Loading branch information
devhawk and qianl15 authored Sep 10, 2024
1 parent 95ef717 commit 3c75b50
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 159 deletions.
1 change: 1 addition & 0 deletions .github/workflows/unit-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
integration:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
services:
Expand Down
17 changes: 6 additions & 11 deletions dbos/application_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
import sqlalchemy.exc as sa_exc
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import Session, sessionmaker

Expand Down Expand Up @@ -36,7 +35,7 @@ def __init__(self, config: ConfigFile):

# If the application database does not already exist, create it
postgres_db_url = sa.URL.create(
"postgresql",
"postgresql+psycopg",
username=config["database"]["username"],
password=config["database"]["password"],
host=config["database"]["hostname"],
Expand All @@ -55,7 +54,7 @@ def __init__(self, config: ConfigFile):

# Create a connection pool for the application database
app_db_url = sa.URL.create(
"postgresql",
"postgresql+psycopg",
username=config["database"]["username"],
password=config["database"]["password"],
host=config["database"]["hostname"],
Expand Down Expand Up @@ -97,11 +96,9 @@ def record_transaction_output(
)
)
except DBAPIError as dbapi_error:
if dbapi_error.orig.pgcode == "23505": # type: ignore
if dbapi_error.orig.sqlstate == "23505": # type: ignore
raise DBOSWorkflowConflictIDError(output["workflow_uuid"])
raise dbapi_error
except Exception as e:
raise e
raise

def record_transaction_error(self, output: TransactionResultInternal) -> None:
try:
Expand All @@ -122,11 +119,9 @@ def record_transaction_error(self, output: TransactionResultInternal) -> None:
)
)
except DBAPIError as dbapi_error:
if dbapi_error.orig.pgcode == "23505": # type: ignore
if dbapi_error.orig.sqlstate == "23505": # type: ignore
raise DBOSWorkflowConflictIDError(output["workflow_uuid"])
raise dbapi_error
except Exception as e:
raise e
raise

@staticmethod
def check_transaction_execution(
Expand Down
10 changes: 5 additions & 5 deletions dbos/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def _execute_workflow(
status["status"] = "ERROR"
status["error"] = utils.serialize(error)
dbos.sys_db.update_workflow_status(status)
raise error
raise

return output

Expand All @@ -217,7 +217,7 @@ def _execute_workflow_wthread(
dbos.logger.error(
f"Exception encountered in asynchronous workflow: {traceback.format_exc()}"
)
raise e
raise


def _execute_workflow_id(dbos: "DBOS", workflow_id: str) -> "WorkflowHandle[Any]":
Expand Down Expand Up @@ -493,7 +493,7 @@ def invoke_tx(*args: Any, **kwargs: Any) -> Any:
)
break
except DBAPIError as dbapi_error:
if dbapi_error.orig.pgcode == "40001": # type: ignore
if dbapi_error.orig.sqlstate == "40001": # type: ignore
# Retry on serialization failure
ctx.get_current_span().add_event(
"Transaction Serialization Failure",
Expand All @@ -505,13 +505,13 @@ def invoke_tx(*args: Any, **kwargs: Any) -> Any:
max_retry_wait_seconds,
)
continue
raise dbapi_error
raise
except Exception as error:
# Don't record the error if it was already recorded
if not has_recorded_error:
txn_output["error"] = utils.serialize(error)
dbos.app_db.record_transaction_error(txn_output)
raise error
raise
return output

fi = get_or_create_func_info(func)
Expand Down
2 changes: 1 addition & 1 deletion dbos/dbos_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def get_dbos_database_url(config_file_path: str = "dbos-config.yaml") -> str:
"""
dbos_config = load_config(config_file_path)
db_url = URL.create(
"postgresql",
"postgresql+psycopg",
username=dbos_config["database"]["username"],
password=dbos_config["database"]["password"],
host=dbos_config["database"]["hostname"],
Expand Down
118 changes: 51 additions & 67 deletions dbos/system_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Sequence, Set, TypedDict, cast

import psycopg2
import psycopg
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from alembic import command
Expand Down Expand Up @@ -154,7 +154,7 @@ def __init__(self, config: ConfigFile):

# If the system database does not already exist, create it
postgres_db_url = sa.URL.create(
"postgresql",
"postgresql+psycopg",
username=config["database"]["username"],
password=config["database"]["password"],
host=config["database"]["hostname"],
Expand All @@ -172,7 +172,7 @@ def __init__(self, config: ConfigFile):
engine.dispose()

system_db_url = sa.URL.create(
"postgresql",
"postgresql+psycopg",
username=config["database"]["username"],
password=config["database"]["password"],
host=config["database"]["hostname"],
Expand All @@ -196,7 +196,7 @@ def __init__(self, config: ConfigFile):
)
command.upgrade(alembic_cfg, "head")

self.notification_conn: Optional[psycopg2.extensions.connection] = None
self.notification_conn: Optional[psycopg.connection.Connection] = None
self.notifications_map: Dict[str, threading.Condition] = {}
self.workflow_events_map: Dict[str, threading.Condition] = {}

Expand Down Expand Up @@ -565,11 +565,9 @@ def record_operation_result(
with self.engine.begin() as c:
c.execute(sql)
except DBAPIError as dbapi_error:
if dbapi_error.orig.pgcode == "23505": # type: ignore
if dbapi_error.orig.sqlstate == "23505": # type: ignore
raise DBOSWorkflowConflictIDError(result["workflow_uuid"])
raise dbapi_error
except Exception as e:
raise e
raise

def check_operation_execution(
self, workflow_uuid: str, function_id: int, conn: Optional[sa.Connection] = None
Expand Down Expand Up @@ -623,11 +621,9 @@ def send(
)
except DBAPIError as dbapi_error:
# Foreign key violation
if dbapi_error.orig.pgcode == "23503": # type: ignore
if dbapi_error.orig.sqlstate == "23503": # type: ignore
raise DBOSNonExistentWorkflowError(destination_uuid)
raise dbapi_error
except Exception as e:
raise e
raise
output: OperationResultInternal = {
"workflow_uuid": workflow_uuid,
"function_id": function_id,
Expand Down Expand Up @@ -729,69 +725,59 @@ def recv(
return message

def _notification_listener(self) -> None:
notification_cursor: Optional[psycopg2.extensions.cursor] = None
while self._run_background_processes:
try:
# Listen to notifications
self.notification_conn = psycopg2.connect(
self.engine.url.render_as_string(hide_password=False)
# since we're using the psycopg connection directly, we need a url without the "+psycopg" suffix
url = sa.URL.create(
"postgresql", **self.engine.url.translate_connect_args()
)
self.notification_conn.set_isolation_level(
psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT
# Listen to notifications
self.notification_conn = psycopg.connect(
url.render_as_string(hide_password=False), autocommit=True
)
notification_cursor = self.notification_conn.cursor()

notification_cursor.execute("LISTEN dbos_notifications_channel")
notification_cursor.execute("LISTEN dbos_workflow_events_channel")
self.notification_conn.execute("LISTEN dbos_notifications_channel")
self.notification_conn.execute("LISTEN dbos_workflow_events_channel")

while self._run_background_processes:
if select.select([self.notification_conn], [], [], 60) == (
[],
[],
[],
):
continue
else:
self.notification_conn.poll()
while self.notification_conn.notifies:
notify = self.notification_conn.notifies.pop(0)
channel = notify.channel
dbos_logger.debug(
f"Received notification on channel: {channel}, payload: {notify.payload}"
)
if channel == "dbos_notifications_channel":
if (
notify.payload
and notify.payload in self.notifications_map
):
condition = self.notifications_map[notify.payload]
condition.acquire()
condition.notify_all()
condition.release()
dbos_logger.debug(
f"Signaled notifications condition for {notify.payload}"
)
elif channel == "dbos_workflow_events_channel":
if (
notify.payload
and notify.payload in self.workflow_events_map
):
condition = self.workflow_events_map[notify.payload]
condition.acquire()
condition.notify_all()
condition.release()
dbos_logger.debug(
f"Signaled workflow_events condition for {notify.payload}"
)
else:
dbos_logger.error(f"Unknown channel: {channel}")
gen = self.notification_conn.notifies(timeout=60)
for notify in gen:
channel = notify.channel
dbos_logger.debug(
f"Received notification on channel: {channel}, payload: {notify.payload}"
)
if channel == "dbos_notifications_channel":
if (
notify.payload
and notify.payload in self.notifications_map
):
condition = self.notifications_map[notify.payload]
condition.acquire()
condition.notify_all()
condition.release()
dbos_logger.debug(
f"Signaled notifications condition for {notify.payload}"
)
elif channel == "dbos_workflow_events_channel":
if (
notify.payload
and notify.payload in self.workflow_events_map
):
condition = self.workflow_events_map[notify.payload]
condition.acquire()
condition.notify_all()
condition.release()
dbos_logger.debug(
f"Signaled workflow_events condition for {notify.payload}"
)
else:
dbos_logger.error(f"Unknown channel: {channel}")
except Exception as e:
if self._run_background_processes:
dbos_logger.error(f"Notification listener error: {e}")
time.sleep(1)
# Then the loop will try to reconnect and restart the listener
finally:
if notification_cursor is not None:
notification_cursor.close()
if self.notification_conn is not None:
self.notification_conn.close()

Expand Down Expand Up @@ -848,11 +834,9 @@ def set_event(
)
)
except DBAPIError as dbapi_error:
if dbapi_error.orig.pgcode == "23505": # type: ignore
if dbapi_error.orig.sqlstate == "23505": # type: ignore
raise DBOSDuplicateWorkflowEventError(workflow_uuid, key)
raise dbapi_error
except Exception as e:
raise e
raise
output: OperationResultInternal = {
"workflow_uuid": workflow_uuid,
"function_id": function_id,
Expand Down
Loading

0 comments on commit 3c75b50

Please sign in to comment.