Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

103 get connection should pass parameters on to psycopg2 #104

Merged
123 changes: 87 additions & 36 deletions dbt_automation/utils/postgres.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""helpers for postgres"""

import os
from logging import basicConfig, getLogger, INFO
import psycopg2
import os
from sshtunnel import SSHTunnelForwarder
from dbt_automation.utils.columnutils import quote_columnname
from dbt_automation.utils.interfaces.warehouse_interface import WarehouseInterface

Expand All @@ -15,38 +16,80 @@ class PostgresClient(WarehouseInterface):
"""a postgres client that can be used as a context manager"""

@staticmethod
def get_connection(host: str, port: str, user: str, password: str, database: str):
"""returns a psycopg connection"""
connection = psycopg2.connect(
host=host,
port=port,
user=user,
password=password,
database=database,
)
def get_connection(conn_info):
"""
returns a psycopg connection
parameters are
host: str
port: str
user: str
password: str
database: str
sslmode: require | disable | prefer | allow | verify-ca | verify-full
sslrootcert: /path/to/cert
...
"""
connect_params = {}
for key in [
"host",
"port",
"user",
"password",
"database",
"sslmode",
"sslrootcert",
]:
if key in conn_info:
connect_params[key] = conn_info[key]
connection = psycopg2.connect(**connect_params)
return connection

def __init__(self, conn_info: dict):
self.name = "postgres"
self.cursor = None
self.tunnel = None
self.connection = None

if conn_info is None: # take creds from env
conn_info = {
"host": os.getenv("DBHOST"),
"port": os.getenv("DBPORT"),
"username": os.getenv("DBUSER"),
"user": os.getenv("DBUSER"),
"password": os.getenv("DBPASSWORD"),
"database": os.getenv("DBNAME"),
}

self.connection = PostgresClient.get_connection(
conn_info.get("host"),
conn_info.get("port"),
conn_info.get("username"),
conn_info.get("password"),
conn_info.get("database"),
)
self.cursor = None
if "ssh_host" in conn_info:
self.tunnel = SSHTunnelForwarder(
(conn_info["ssh_host"], conn_info["ssh_port"]),
remote_bind_address=(conn_info["host"], conn_info["port"]),
# ...and credentials
ssh_pkey=conn_info.get("ssh_pkey"),
ssh_username=conn_info.get("ssh_username"),
ssh_password=conn_info.get("ssh_password"),
ssh_private_key_password=conn_info.get("ssh_private_key_password"),
)
self.tunnel.start()
conn_info["host"] = "localhost"
conn_info["port"] = self.tunnel.local_bind_port
self.connection = PostgresClient.get_connection(conn_info)

else:
self.connection = PostgresClient.get_connection(conn_info)
self.conn_info = conn_info

def __del__(self):
"""destructor"""
if self.cursor is not None:
self.cursor.close()
self.cursor = None
if self.connection is not None:
self.connection.close()
self.connection = None
if self.tunnel is not None:
self.tunnel.stop()
self.tunnel = None

def runcmd(self, statement: str):
Ishankoradia marked this conversation as resolved.
Show resolved Hide resolved
"""runs a command"""
if self.cursor is None:
Expand Down Expand Up @@ -75,7 +118,7 @@ def get_tables(self, schema: str) -> list:
def get_schemas(self) -> list:
"""returns the list of schema names in the given database connection"""
resultset = self.execute(
f"""
"""
SELECT nspname
FROM pg_namespace
WHERE nspname NOT LIKE 'pg_%' AND nspname != 'information_schema';
Expand Down Expand Up @@ -132,15 +175,15 @@ def get_table_columns(self, schema: str, table: str) -> list:
)
return [{"name": x[0], "data_type": x[1]} for x in resultset]

def get_columnspec(self, schema: str, table: str):
def get_columnspec(self, schema: str, table_id: str):
"""get the column schema for this table"""
return [
x[0]
for x in self.execute(
f"""SELECT column_name
FROM information_schema.columns
WHERE table_schema = '{schema}'
AND table_name = '{table}'
AND table_name = '{table_id}'
"""
)
]
Expand Down Expand Up @@ -193,33 +236,41 @@ def json_extract_op(self, json_column: str, json_field: str, sql_column: str):

def close(self):
try:
self.connection.close()
if self.cursor is not None:
self.cursor.close()
self.cursor = None
if self.tunnel is not None:
self.tunnel.stop()
self.tunnel = None
if self.connection is not None:
self.connection.close()
self.connection = None
except Exception:
logger.error("something went wrong while closing the postgres connection")

return True

def generate_profiles_yaml_dbt(self, project_name, default_schema):
"""Generates the profiles.yml dictionary object for dbt"""
if project_name is None or default_schema is None:
raise ValueError("project_name and default_schema are required")

target = "prod"

"""
<project_name>:
Generates the profiles.yml dictionary object for dbt
<project_name>:
outputs:
prod:
dbname:
host:
password:
prod:
dbname:
host:
password:
port: 5432
user: airbyte_user
schema:
schema:
threads: 4
type: postgres
target: prod
"""
if project_name is None or default_schema is None:
raise ValueError("project_name and default_schema are required")

target = "prod"

profiles_yml = {
f"{project_name}": {
"outputs": {
Expand All @@ -228,7 +279,7 @@ def generate_profiles_yaml_dbt(self, project_name, default_schema):
"host": self.conn_info["host"],
"password": self.conn_info["password"],
"port": int(self.conn_info["port"]),
"user": self.conn_info["username"],
"user": self.conn_info["user"],
"schema": default_schema,
"threads": 4,
"type": "postgres",
Expand Down
12 changes: 12 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
bcrypt==4.1.2
cachetools==5.3.1
certifi==2023.7.22
cffi==1.16.0
charset-normalizer==3.3.0
coverage==7.3.2
cryptography==42.0.5
dbt-automation @ git+https://github.com/DalgoT4D/dbt-automation.git
exceptiongroup==1.1.3
google-api-core==2.12.0
google-auth==2.23.2
Expand All @@ -14,22 +18,30 @@ grpcio==1.59.0
grpcio-status==1.59.0
idna==3.4
iniconfig==2.0.0
numpy==1.26.0
packaging==23.2
pandas==2.1.1
paramiko==3.4.0
pluggy==1.3.0
proto-plus==1.22.3
protobuf==4.24.4
psycopg2-binary==2.9.7
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycparser==2.22
PyNaCl==1.5.0
pytest==7.4.3
pytest-cov==4.1.0
pytest-env==1.1.1
python-dateutil==2.8.2
python-dotenv==1.0.0
pytz==2023.3.post1
PyYAML==6.0.1
requests==2.31.0
rsa==4.9
six==1.16.0
sshtunnel==0.4.0
tomli==2.0.1
tqdm==4.66.1
tzdata==2023.3
urllib3==2.0.6
2 changes: 1 addition & 1 deletion tests/warehouse/test_postgres_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class TestPostgresOperations:
{
"host": os.environ.get("TEST_PG_DBHOST"),
"port": os.environ.get("TEST_PG_DBPORT"),
"username": os.environ.get("TEST_PG_DBUSER"),
"user": os.environ.get("TEST_PG_DBUSER"),
"database": os.environ.get("TEST_PG_DBNAME"),
"password": os.environ.get("TEST_PG_DBPASSWORD"),
},
Expand Down
Loading