Skip to content

Commit

Permalink
[Issue 2665] Implement db connection pools (#2828)
Browse files Browse the repository at this point in the history
## Summary
Maybe Fixes #2665 

### Time to review: __5 mins__

## Changes proposed
> What was added, updated, or removed in this PR.

Added connection pools to `/analytics`, replacing single instance db
connections. Implementation follows pattern in
`/api/src/adapters/db/clients/`.

## Context for reviewers
> Testing instructions, background context, more in-depth details of the
implementation, and anything else you'd like to call out or ask
reviewers. Explain how the changes were verified.

This is the latest step in a series of attempts to resolve failed db
connections from /analytics step functions. See ticket history and
comments for details.

## Additional information
> Screenshots, GIF demos, code examples or output to help show the
changes working as expected.
  • Loading branch information
DavidDudas-Intuitial authored Nov 13, 2024
1 parent f17ff75 commit 2e5d6a8
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 58 deletions.
12 changes: 6 additions & 6 deletions analytics/src/analytics/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from analytics.datasets.issues import GitHubIssues
from analytics.etl.github import GitHubProjectConfig, GitHubProjectETL
from analytics.etl.utils import load_config
from analytics.integrations import db, etldb, slack
from analytics.integrations import etldb, slack
from analytics.integrations.db import PostgresDbClient
from analytics.metrics.base import BaseMetric, Unit
from analytics.metrics.burndown import SprintBurndown
from analytics.metrics.burnup import SprintBurnup
Expand Down Expand Up @@ -208,9 +209,8 @@ def show_and_or_post_results(
@import_app.command(name="test_connection")
def test_connection() -> None:
"""Test function that ensures the DB connection works."""
engine = db.get_db()
# connection method from sqlalchemy
connection = engine.connect()
client = PostgresDbClient()
connection = client.connect()

# Test INSERT INTO action
result = connection.execute(
Expand All @@ -234,14 +234,14 @@ def export_json_to_database(delivery_file: Annotated[str, ISSUE_FILE_ARG]) -> No
logger.info("Beginning import")

# Get the database engine and establish a connection
engine = db.get_db()
client = PostgresDbClient()

# Load data from the sprint board
issues = GitHubIssues.from_json(delivery_file)

issues.to_sql(
output_table="github_project_data",
engine=engine,
engine=client.engine(),
replace_table=True,
)
rows = len(issues.to_dict())
Expand Down
81 changes: 55 additions & 26 deletions analytics/src/analytics/integrations/db.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,77 @@
# pylint: disable=invalid-name, line-too-long
"""Get a connection to the database using a SQLAlchemy engine object."""

from typing import Any, cast

import boto3
from sqlalchemy import Engine, create_engine
import psycopg
from sqlalchemy import Connection, Engine, create_engine, pool

from config import DBSettings, get_db_settings

# The variables used in the connection url are pulled from local.env
# and configured in the DBSettings class found in config.py


class PostgresDbClient:
    """An implementation of a Postgres db client.

    Wraps a SQLAlchemy engine backed by a QueuePool of psycopg
    connections, replacing the previous single-engine `get_db()` helper.
    Each pooled connection is created with freshly generated credentials
    (a short-lived IAM auth token outside local environments).
    """

    def __init__(self, config: DBSettings | None = None) -> None:
        """Construct a class instance, loading settings when none are given."""
        if not config:
            config = get_db_settings()
        self._engine = self._configure_engine(config)

    def _configure_engine(self, config: DBSettings) -> Engine:
        """Configure db engine to use short-lived IAM tokens for access."""

        # inspired by /api/src/adapters/db/clients/postgres_client.py
        def get_conn() -> psycopg.Connection:
            """Get a psycopg connection with freshly generated credentials."""
            return psycopg.connect(**get_connection_parameters(config))

        # pre_ping validates each pooled connection on checkout, so stale
        # connections (e.g. ones whose IAM token expired server-side) are
        # transparently recycled instead of surfacing as failed queries.
        # The previous single-engine implementation set pool_pre_ping=True;
        # this preserves that behavior with the explicit pool.
        conn_pool = pool.QueuePool(
            cast(Any, get_conn),
            max_overflow=5,
            pool_size=10,
            pre_ping=True,
        )

        return create_engine(
            "postgresql+psycopg://",
            pool=conn_pool,
            hide_parameters=True,
        )

    def connect(self) -> Connection:
        """Get a new database connection object from the pool."""
        return self._engine.connect()

    def engine(self) -> Engine:
        """Get a reference to the underlying db engine."""
        return self._engine
def get_connection_parameters(config: DBSettings) -> dict[str, Any]:
    """Get parameters for db connection."""
    # Local environments authenticate with the configured password;
    # everywhere else a short-lived IAM auth token is generated instead.
    if config.local_env is True:
        token = config.password
    else:
        token = generate_iam_auth_token(config)

    params: dict[str, Any] = {
        "host": config.db_host,
        "dbname": config.name,
        "user": config.user,
        "password": token,
        "port": config.port,
        "connect_timeout": 20,
        "sslmode": config.ssl_mode,
    }
    return params


def generate_iam_auth_token(config: DBSettings) -> str:
    """Generate a short-lived RDS IAM auth token to use as the db password.

    Raises
    ------
    ValueError
        If no AWS region is configured in the db settings.
    """
    if config.aws_region is None:
        msg = "AWS region needs to be configured for DB IAM auth"
        raise ValueError(msg)
    client = boto3.client("rds", region_name=config.aws_region)
    return client.generate_db_auth_token(
        DBHostname=config.db_host,
        Port=config.port,
        DBUsername=config.user,
        Region=config.aws_region,
    )
23 changes: 9 additions & 14 deletions analytics/src/analytics/integrations/etldb/etldb.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
"""Define EtlDb as an abstraction layer for database connections."""

import contextlib
from enum import Enum

from sqlalchemy import Connection

from analytics.integrations import db
from analytics.integrations.db import PostgresDbClient


class EtlDb:
"""Encapsulate etl database connections."""

def __init__(self, effective: str | None = None) -> None:
    """Construct instance.

    Parameters
    ----------
    effective:
        Optional effective date string, expected in "%Y-%m-%d" format.
    """
    try:
        self._db_client = PostgresDbClient()
    except RuntimeError as e:
        message = f"Failed to instantiate database engine: {e}"
        raise RuntimeError(message) from e

    # Connection is created lazily on first use (see connection()).
    self._connection: Connection | None = None
    self.effective_date = effective
    self.dateformat = "%Y-%m-%d"

def __del__(self) -> None:
    """Release database resources when the instance is garbage-collected."""
    self.disconnect()

def connection(self) -> Connection:
"""Get a connection object from the db engine."""
"""Get a connection object from the database engine."""
if self._connection is None:
try:
self._connection = self._db_engine.connect()
self._connection = self._db_client.connect()
except RuntimeError as e:
message = f"Failed to connect to database: {e}"
raise RuntimeError(message) from e
Expand All @@ -36,11 +36,6 @@ def commit(self, connection: Connection) -> None:
"""Commit an open transaction."""
connection.commit()

def disconnect(self) -> None:
    """Dispose of the db engine's pooled connections.

    Fixed to dispose via self._db_client: __init__ now assigns
    self._db_client (not self._db_engine), so the old attribute access
    raised AttributeError on every call — silently swallowed by
    suppress(Exception) — and the engine was never actually disposed.
    """
    with contextlib.suppress(Exception):
        self._db_client.engine().dispose()


class EtlChangeType(Enum):
"""An enum to describe ETL change types."""
Expand Down
56 changes: 44 additions & 12 deletions analytics/src/analytics/integrations/etldb/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from pathlib import Path

from psycopg.errors import InsufficientPrivilege
from sqlalchemy import text
from sqlalchemy.exc import OperationalError, ProgrammingError

from analytics.datasets.etl_dataset import EtlDataset, EtlEntityType
from analytics.integrations.etldb.deliverable_model import EtlDeliverableModel
Expand Down Expand Up @@ -33,8 +35,13 @@ def init_db() -> None:
text(sql),
)
db.commit(cursor)
except RuntimeError as e:
message = f"Failed to initialize db: {e}"
except (
RuntimeError,
ProgrammingError,
OperationalError,
InsufficientPrivilege,
) as e:
message = f"FATAL: Failed to initialize db: {e}"
raise RuntimeError(message) from e


Expand All @@ -55,8 +62,13 @@ def sync_db(dataset: EtlDataset, effective: str) -> None:
try:
ghid_map[EtlEntityType.QUAD] = sync_quads(db, dataset)
print(f"quad row(s) processed: {len(ghid_map[EtlEntityType.QUAD])}")
except RuntimeError as e:
message = f"Failed to sync quad data: {e}"
except (
RuntimeError,
ProgrammingError,
OperationalError,
InsufficientPrivilege,
) as e:
message = f"FATAL: Failed to sync quad data: {e}"
raise RuntimeError(message) from e

# sync deliverable data to db resulting in row id for each deliverable
Expand All @@ -69,32 +81,52 @@ def sync_db(dataset: EtlDataset, effective: str) -> None:
print(
f"deliverable row(s) processed: {len(ghid_map[EtlEntityType.DELIVERABLE])}",
)
except RuntimeError as e:
message = f"Failed to sync deliverable data: {e}"
except (
RuntimeError,
ProgrammingError,
OperationalError,
InsufficientPrivilege,
) as e:
message = f"FATAL: Failed to sync deliverable data: {e}"
raise RuntimeError(message) from e

# sync sprint data to db resulting in row id for each sprint
try:
ghid_map[EtlEntityType.SPRINT] = sync_sprints(db, dataset, ghid_map)
print(f"sprint row(s) processed: {len(ghid_map[EtlEntityType.SPRINT])}")
except RuntimeError as e:
message = f"Failed to sync sprint data: {e}"
except (
RuntimeError,
ProgrammingError,
OperationalError,
InsufficientPrivilege,
) as e:
message = f"FATAL: Failed to sync sprint data: {e}"
raise RuntimeError(message) from e

# sync epic data to db resulting in row id for each epic
try:
ghid_map[EtlEntityType.EPIC] = sync_epics(db, dataset, ghid_map)
print(f"epic row(s) processed: {len(ghid_map[EtlEntityType.EPIC])}")
except RuntimeError as e:
message = f"Failed to sync epic data: {e}"
except (
RuntimeError,
ProgrammingError,
OperationalError,
InsufficientPrivilege,
) as e:
message = f"FATAL: Failed to sync epic data: {e}"
raise RuntimeError(message) from e

# sync issue data to db resulting in row id for each issue
try:
issue_map = sync_issues(db, dataset, ghid_map)
print(f"issue row(s) processed: {len(issue_map)}")
except RuntimeError as e:
message = f"Failed to sync issue data: {e}"
except (
RuntimeError,
ProgrammingError,
OperationalError,
InsufficientPrivilege,
) as e:
message = f"FATAL: Failed to sync issue data: {e}"
raise RuntimeError(message) from e


Expand Down

0 comments on commit 2e5d6a8

Please sign in to comment.