-
Notifications
You must be signed in to change notification settings - Fork 15
/
connections.py
206 lines (169 loc) · 7.05 KB
/
connections.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional, Union
from dbt.adapters.contracts.connection import AdapterResponse, Credentials
from dbt.adapters.events.logging import AdapterLogger
from dbt.adapters.sql import SQLConnectionManager
from dbt_common.exceptions import DbtDatabaseError, DbtRuntimeError
from dbt_common.helper_types import Port
from mashumaro.jsonschema.annotations import Maximum, Minimum
import psycopg2
from typing_extensions import Annotated
logger = AdapterLogger("Postgres")
@dataclass
class PostgresCredentials(Credentials):
host: str
user: str
# Annotated is used by mashumaro for jsonschema generation
port: Annotated[Port, Minimum(0), Maximum(65535)]
password: str # on postgres the password is mandatory
connect_timeout: int = 10
role: Optional[str] = None
search_path: Optional[str] = None
keepalives_idle: int = 0 # 0 means to use the default value
sslmode: Optional[str] = None
sslcert: Optional[str] = None
sslkey: Optional[str] = None
sslrootcert: Optional[str] = None
application_name: Optional[str] = "dbt"
retries: int = 1
_ALIASES = {"dbname": "database", "pass": "password"}
@property
def type(self):
return "postgres"
@property
def unique_field(self):
return self.host
def _connection_keys(self):
return (
"host",
"port",
"user",
"database",
"schema",
"connect_timeout",
"role",
"search_path",
"keepalives_idle",
"sslmode",
"sslcert",
"sslkey",
"sslrootcert",
"application_name",
"retries",
)
class PostgresConnectionManager(SQLConnectionManager):
TYPE = "postgres"
@contextmanager
def exception_handler(self, sql):
try:
yield
except psycopg2.DatabaseError as e:
logger.debug("Postgres error: {}".format(str(e)))
try:
self.rollback_if_open()
except psycopg2.Error:
logger.debug("Failed to release connection!")
pass
raise DbtDatabaseError(str(e).strip()) from e
except Exception as e:
logger.debug("Error running SQL: {}", sql)
logger.debug("Rolling back transaction.")
self.rollback_if_open()
if isinstance(e, DbtRuntimeError):
# during a sql query, an internal to dbt exception was raised.
# this sounds a lot like a signal handler and probably has
# useful information, so raise it without modification.
raise
raise DbtRuntimeError(e) from e
@classmethod
def open(cls, connection):
if connection.state == "open":
logger.debug("Connection is already open, skipping open.")
return connection
credentials = cls.get_credentials(connection.credentials)
kwargs = {}
# we don't want to pass 0 along to connect() as postgres will try to
# call an invalid setsockopt() call (contrary to the docs).
if credentials.keepalives_idle:
kwargs["keepalives_idle"] = credentials.keepalives_idle
# psycopg2 doesn't support search_path officially,
# see https://github.com/psycopg/psycopg2/issues/465
search_path = credentials.search_path
if search_path is not None and search_path != "":
# see https://postgresql.org/docs/9.5/libpq-connect.html
kwargs["options"] = "-c search_path={}".format(search_path.replace(" ", "\\ "))
if credentials.sslmode:
kwargs["sslmode"] = credentials.sslmode
if credentials.sslcert is not None:
kwargs["sslcert"] = credentials.sslcert
if credentials.sslkey is not None:
kwargs["sslkey"] = credentials.sslkey
if credentials.sslrootcert is not None:
kwargs["sslrootcert"] = credentials.sslrootcert
if credentials.application_name:
kwargs["application_name"] = credentials.application_name
def connect():
handle = psycopg2.connect(
dbname=credentials.database,
user=credentials.user,
host=credentials.host,
password=credentials.password,
port=credentials.port,
connect_timeout=credentials.connect_timeout,
**kwargs,
)
if credentials.role:
handle.cursor().execute("set role {}".format(credentials.role))
return handle
retryable_exceptions = [
# OperationalError is subclassed by all psycopg2 Connection Exceptions and it's raised
# by generic connection timeouts without an error code. This is a limitation of
# psycopg2 which doesn't provide subclasses for errors without a SQLSTATE error code.
# The limitation has been known for a while and there are no efforts to tackle it.
# See: https://github.com/psycopg/psycopg2/issues/682
psycopg2.errors.OperationalError,
]
def exponential_backoff(attempt: int):
return attempt * attempt
return cls.retry_connection(
connection,
connect=connect,
logger=logger,
retry_limit=credentials.retries,
retry_timeout=exponential_backoff,
retryable_exceptions=retryable_exceptions,
)
def cancel(self, connection):
connection_name = connection.name
try:
pid = connection.handle.get_backend_pid()
except psycopg2.InterfaceError as exc:
# if the connection is already closed, not much to cancel!
if "already closed" in str(exc):
logger.debug(f"Connection {connection_name} was already closed")
return
# probably bad, re-raise it
raise
sql = "select pg_terminate_backend({})".format(pid)
logger.debug("Cancelling query '{}' ({})".format(connection_name, pid))
_, cursor = self.add_query(sql)
res = cursor.fetchone()
logger.debug("Cancel query '{}': {}".format(connection_name, res))
@classmethod
def get_credentials(cls, credentials):
return credentials
@classmethod
def get_response(cls, cursor) -> AdapterResponse:
message = str(cursor.statusmessage)
rows = cursor.rowcount
status_message_parts = message.split() if message is not None else []
status_messsage_strings = [part for part in status_message_parts if not part.isdigit()]
code = " ".join(status_messsage_strings)
return AdapterResponse(_message=message, code=code, rows_affected=rows)
@classmethod
def data_type_code_to_name(cls, type_code: Union[int, str]) -> str:
if type_code in psycopg2.extensions.string_types:
return psycopg2.extensions.string_types[type_code].name
else:
return f"unknown type_code {type_code}"