From 56acd6d7080c1bd4eb3f21dbdebe87e1d0f3c5e5 Mon Sep 17 00:00:00 2001
From: David Dudas
Date: Fri, 8 Nov 2024 20:20:45 -0800
Subject: [PATCH] [Issue 2665] Analytics jobs conditionally use IAM token as Postgres pwd (#2799)

## Summary

Fixes #2665

### Time to review: __2 mins__

## Changes proposed

> What was added, updated, or removed in this PR.

Adds a switch in `analytics/integrations/db.py` that determines which value to use as the DB password when connecting to Postgres: either the value from `local.env` or an IAM auth token, depending on the `ENVIRONMENT` environment variable. Also adds more explicit exception handling, so that connection and sync failures are easier to spot in CI.

## Context for reviewers

> Testing instructions, background context, more in-depth details of the implementation, and anything else you'd like to call out or ask reviewers. Explain how the changes were verified.

This is a follow-up to two previous PRs, https://github.com/HHS/simpler-grants-gov/pull/2786 and https://github.com/HHS/simpler-grants-gov/pull/2796, and is part of an effort to get the `analytics` step functions to connect successfully to the Postgres DB in a CI environment.

## Additional information

> Screenshots, GIF demos, code examples or output to help show the changes working as expected.
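By way of example, the credential switch behaves as below. This is a standalone sketch, not the merged code: the real logic lives in `DBSettings.local_env` and `get_db()` in the diff, and `generate_iam_auth_token` is stubbed here for illustration.

```python
import os
from typing import Callable, Optional


def resolve_db_password(
    password: Optional[str],
    generate_iam_auth_token: Callable[[], str],
) -> Optional[str]:
    """Pick the Postgres credential: local.env password locally, IAM token otherwise."""
    # Mirrors DBSettings.local_env: ENVIRONMENT defaults to "local";
    # any other value (e.g. "dev" in CI) selects the IAM token path.
    local_env = os.getenv("ENVIRONMENT", "local") == "local"
    return password if local_env else generate_iam_auth_token()


# In CI, ENVIRONMENT is set to a non-local value, so the IAM token is
# used even when a password variable happens to be present.
os.environ["ENVIRONMENT"] = "dev"
print(resolve_db_password("local-password", lambda: "<iam-token>"))  # <iam-token>
```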
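The error-handling change applies one pattern per sync step in `etldb/main.py`: wrap the step in `try`/`except` and re-raise with a step-specific message so the failing step is obvious in CI logs. The PR repeats this inline for each step; a hypothetical helper (not in the PR) showing the shape:

```python
from typing import Callable, Sized


def run_sync_step(step_name: str, action: Callable[[], Sized]) -> None:
    """Run one sync step, re-raising failures with a step-specific message."""
    try:
        result = action()
        print(f"{step_name} row(s) processed: {len(result)}")
    except RuntimeError as e:
        # "from e" preserves the original traceback, while the new message
        # puts the failing step's name at the top of the CI error output.
        message = f"Failed to sync {step_name} data: {e}"
        raise RuntimeError(message) from e


run_sync_step("quad", lambda: {"quad-ghid-1": 1})  # quad row(s) processed: 1
```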
---
 analytics/config.py                            |  1 +
 analytics/src/analytics/integrations/db.py     |  2 +-
 .../src/analytics/integrations/etldb/etldb.py  | 10 ++-
 .../src/analytics/integrations/etldb/main.py   | 66 +++++++++++++------
 4 files changed, 56 insertions(+), 23 deletions(-)

diff --git a/analytics/config.py b/analytics/config.py
index 01ad71d92..6d07bc3d5 100644
--- a/analytics/config.py
+++ b/analytics/config.py
@@ -19,6 +19,7 @@ class DBSettings(PydanticBaseEnvConfig):
     slack_bot_token: str = Field(alias="ANALYTICS_SLACK_BOT_TOKEN")
     reporting_channel_id: str = Field(alias="ANALYTICS_REPORTING_CHANNEL_ID")
     aws_region: Optional[str] = Field(None, alias="AWS_REGION")
+    local_env: bool = True if os.getenv("ENVIRONMENT", "local") == "local" else False

 def get_db_settings() -> DBSettings:
     return DBSettings()
diff --git a/analytics/src/analytics/integrations/db.py b/analytics/src/analytics/integrations/db.py
index be46df2d9..7ad722c71 100644
--- a/analytics/src/analytics/integrations/db.py
+++ b/analytics/src/analytics/integrations/db.py
@@ -24,7 +24,7 @@ def get_db() -> Engine:
     """
     db = get_db_settings()
     # inspired by simpler-grants-gov/blob/main/api/src/adapters/db/clients/postgres_client.py
-    token = generate_iam_auth_token(db) if db.password is None else db.password
+    token = db.password if db.local_env is True else generate_iam_auth_token(db)
     return create_engine(
         f"postgresql+psycopg://{db.user}:{token}@{db.db_host}:{db.port}",
diff --git a/analytics/src/analytics/integrations/etldb/etldb.py b/analytics/src/analytics/integrations/etldb/etldb.py
index 7a25faed3..6404d7e8e 100644
--- a/analytics/src/analytics/integrations/etldb/etldb.py
+++ b/analytics/src/analytics/integrations/etldb/etldb.py
@@ -1,5 +1,6 @@
 """Define EtlDb as an abstraction layer for database connections."""

+import contextlib
 from enum import Enum

 from sqlalchemy import Connection
@@ -24,7 +25,11 @@ def __del__(self) -> None:
     def connection(self) -> Connection:
         """Get a connection object from the db engine."""
         if self._connection is None:
-            self._connection = self._db_engine.connect()
+            try:
+                self._connection = self._db_engine.connect()
+            except RuntimeError as e:
+                message = f"Failed to connect to database: {e}"
+                raise RuntimeError(message) from e
         return self._connection

     def commit(self, connection: Connection) -> None:
@@ -33,7 +38,8 @@ def commit(self, connection: Connection) -> None:

     def disconnect(self) -> None:
         """Dispose of db connection."""
-        self._db_engine.dispose()
+        with contextlib.suppress(Exception):
+            self._db_engine.dispose()


 class EtlChangeType(Enum):
diff --git a/analytics/src/analytics/integrations/etldb/main.py b/analytics/src/analytics/integrations/etldb/main.py
index 11f790bda..634b24593 100644
--- a/analytics/src/analytics/integrations/etldb/main.py
+++ b/analytics/src/analytics/integrations/etldb/main.py
@@ -26,12 +26,16 @@ def init_db() -> None:
         sql = f.read()

     # execute sql
-    db = EtlDb()
-    cursor = db.connection()
-    cursor.execute(
-        text(sql),
-    )
-    db.commit(cursor)
+    try:
+        db = EtlDb()
+        cursor = db.connection()
+        cursor.execute(
+            text(sql),
+        )
+        db.commit(cursor)
+    except RuntimeError as e:
+        message = f"Failed to initialize db: {e}"
+        raise RuntimeError(message) from e


 def sync_db(dataset: EtlDataset, effective: str) -> None:
@@ -48,28 +52,50 @@ def sync_db(dataset: EtlDataset, effective: str) -> None:
     db = EtlDb(effective)

     # sync quad data to db resulting in row id for each quad
-    ghid_map[EtlEntityType.QUAD] = sync_quads(db, dataset)
-    print(f"quad row(s) processed: {len(ghid_map[EtlEntityType.QUAD])}")
+    try:
+        ghid_map[EtlEntityType.QUAD] = sync_quads(db, dataset)
+        print(f"quad row(s) processed: {len(ghid_map[EtlEntityType.QUAD])}")
+    except RuntimeError as e:
+        message = f"Failed to sync quad data: {e}"
+        raise RuntimeError(message) from e

     # sync deliverable data to db resulting in row id for each deliverable
-    ghid_map[EtlEntityType.DELIVERABLE] = sync_deliverables(
-        db,
-        dataset,
-        ghid_map,
-    )
-    print(f"deliverable row(s) processed: {len(ghid_map[EtlEntityType.DELIVERABLE])}")
+    try:
+        ghid_map[EtlEntityType.DELIVERABLE] = sync_deliverables(
+            db,
+            dataset,
+            ghid_map,
+        )
+        print(
+            f"deliverable row(s) processed: {len(ghid_map[EtlEntityType.DELIVERABLE])}",
+        )
+    except RuntimeError as e:
+        message = f"Failed to sync deliverable data: {e}"
+        raise RuntimeError(message) from e

     # sync sprint data to db resulting in row id for each sprint
-    ghid_map[EtlEntityType.SPRINT] = sync_sprints(db, dataset, ghid_map)
-    print(f"sprint row(s) processed: {len(ghid_map[EtlEntityType.SPRINT])}")
+    try:
+        ghid_map[EtlEntityType.SPRINT] = sync_sprints(db, dataset, ghid_map)
+        print(f"sprint row(s) processed: {len(ghid_map[EtlEntityType.SPRINT])}")
+    except RuntimeError as e:
+        message = f"Failed to sync sprint data: {e}"
+        raise RuntimeError(message) from e

     # sync epic data to db resulting in row id for each epic
-    ghid_map[EtlEntityType.EPIC] = sync_epics(db, dataset, ghid_map)
-    print(f"epic row(s) processed: {len(ghid_map[EtlEntityType.EPIC])}")
+    try:
+        ghid_map[EtlEntityType.EPIC] = sync_epics(db, dataset, ghid_map)
+        print(f"epic row(s) processed: {len(ghid_map[EtlEntityType.EPIC])}")
+    except RuntimeError as e:
+        message = f"Failed to sync epic data: {e}"
+        raise RuntimeError(message) from e

     # sync issue data to db resulting in row id for each issue
-    issue_map = sync_issues(db, dataset, ghid_map)
-    print(f"issue row(s) processed: {len(issue_map)}")
+    try:
+        issue_map = sync_issues(db, dataset, ghid_map)
+        print(f"issue row(s) processed: {len(issue_map)}")
+    except RuntimeError as e:
+        message = f"Failed to sync issue data: {e}"
+        raise RuntimeError(message) from e


 def sync_deliverables(db: EtlDb, dataset: EtlDataset, ghid_map: dict) -> dict: