Skip to content

Commit

Permalink
Merge pull request #126 from DalgoT4D/125-explore-doesnt-work-if-the-…
Browse files Browse the repository at this point in the history
…ssl-certificate-is-inline

handle inline certificates
  • Loading branch information
fatchat authored Dec 9, 2024
2 parents 140a1f2 + 9daffab commit 444e4b7
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 2 deletions.
37 changes: 35 additions & 2 deletions dbt_automation/utils/postgres.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""helpers for postgres"""

import os
import tempfile
from logging import basicConfig, getLogger, INFO
import psycopg2
from sshtunnel import SSHTunnelForwarder
Expand Down Expand Up @@ -36,11 +37,43 @@ def get_connection(conn_info):
"user",
"password",
"database",
"sslmode",
"sslrootcert",
]:
if key in conn_info:
connect_params[key] = conn_info[key]

# ssl_mode is an alias for sslmode
if "ssl_mode" in conn_info:
conn_info["sslmode"] = conn_info["ssl_mode"]

if "sslmode" in conn_info:
# sslmode can be a string or a boolean or a dict
if isinstance(conn_info["sslmode"], str):
# "require", "disable", "verify-ca", "verify-full"
connect_params["sslmode"] = conn_info["sslmode"]
elif isinstance(conn_info["sslmode"], bool):
# true = require, false = disable
connect_params["sslmode"] = (
"require" if conn_info["sslmode"] else "disable"
)
elif (
isinstance(conn_info["sslmode"], dict)
and "mode" in conn_info["sslmode"]
):
# mode is "require", "disable", "verify-ca", "verify-full" etc
connect_params["sslmode"] = conn_info["sslmode"]["mode"]
if "ca_certificate" in conn_info["sslmode"]:
# connect_params['sslcert'] needs a file path but
# conn_info['sslmode']['ca_certificate']
# is a string (i.e. the actual certificate). so we write
# it to disk and pass the file path
with tempfile.NamedTemporaryFile(delete=False) as fp:
fp.write(conn_info["sslmode"]["ca_certificate"].encode())
connect_params["sslrootcert"] = fp.name
connect_params["sslcert"] = fp.name

if "sslrootcert" in conn_info:
connect_params["sslrootcert"] = conn_info["sslrootcert"]

connection = psycopg2.connect(**connect_params)
return connection

Expand Down
113 changes: 113 additions & 0 deletions tests/utils/test_postgres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from unittest.mock import patch, ANY
from dbt_automation.utils.postgres import PostgresClient


def test_get_connection_1():
"""tests PostgresClient.get_connection"""
with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect:
PostgresClient.get_connection(
{"host": "HOST", "port": 1234, "user": "USER", "password": "PASSWORD"}
)
mock_connect.assert_called_once()
mock_connect.assert_called_with(
host="HOST",
port=1234,
user="USER",
password="PASSWORD",
)


def test_get_connection_2():
"""tests PostgresClient.get_connection"""
with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect:
PostgresClient.get_connection(
{
"host": "HOST",
"port": 1234,
"user": "USER",
"password": "PASSWORD",
"database": "DATABASE",
}
)
mock_connect.assert_called_once()
mock_connect.assert_called_with(
host="HOST",
port=1234,
user="USER",
password="PASSWORD",
database="DATABASE",
)


def test_get_connection_3():
"""tests PostgresClient.get_connection"""
with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect:
PostgresClient.get_connection(
{
"sslmode": "verify-ca",
"sslrootcert": "/path/to/cert",
}
)
mock_connect.assert_called_once()
mock_connect.assert_called_with(
sslmode="verify-ca",
sslrootcert="/path/to/cert",
)


def test_get_connection_4():
"""tests PostgresClient.get_connection"""
with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect:
PostgresClient.get_connection(
{
"sslmode": True,
"sslrootcert": "/path/to/cert",
}
)
mock_connect.assert_called_once()
mock_connect.assert_called_with(
sslmode="require",
sslrootcert="/path/to/cert",
)


def test_get_connection_5():
"""tests PostgresClient.get_connection"""
with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect:
PostgresClient.get_connection(
{
"sslmode": False,
"sslrootcert": "/path/to/cert",
}
)
mock_connect.assert_called_once()
mock_connect.assert_called_with(
sslmode="disable",
sslrootcert="/path/to/cert",
)


def test_get_connection_6():
"""tests PostgresClient.get_connection"""
with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect:
PostgresClient.get_connection(
{
"sslmode": {
"mode": "disable",
}
}
)
mock_connect.assert_called_once()
mock_connect.assert_called_with(
sslmode="disable",
)


def test_get_connection_7():
"""tests PostgresClient.get_connection"""
with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect:
PostgresClient.get_connection(
{"sslmode": {"mode": "disable", "ca_certificate": "LONG-CERTIFICATE"}}
)
mock_connect.assert_called_once()
mock_connect.assert_called_with(sslmode="disable", sslrootcert=ANY, sslcert=ANY)

0 comments on commit 444e4b7

Please sign in to comment.