forked from elastic/connectors
-
Notifications
You must be signed in to change notification settings - Fork 0
/
generic_database.py
150 lines (114 loc) · 3.91 KB
/
generic_database.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License 2.0;
# you may not use this file except in compliance with the Elastic License 2.0.
#
import asyncio
from abc import ABC, abstractmethod
from asyncpg.exceptions._base import InternalClientError
from sqlalchemy.exc import ProgrammingError
from connectors.utils import RetryStrategy, retryable
# Table selector value that `is_wildcard` compares against (bare or as a one-element list).
WILDCARD = "*"

# Default `fetch_size` of `fetch`: number of rows requested per `cursor.fetchmany` call.
DEFAULT_FETCH_SIZE = 50
# Default `retry_count` of `fetch`; forwarded as `retries` to the `retryable` decorator.
DEFAULT_RETRY_COUNT = 3
# Forwarded as `interval` to `retryable` (exponential backoff strategy) inside `fetch`.
# NOTE(review): unit is presumably seconds — confirm against `connectors.utils.retryable`.
DEFAULT_WAIT_MULTIPLIER = 2
def configured_tables(tables):
    """Normalize the configured tables into a list of non-empty table names.

    A string is split on commas and every piece is stripped of surrounding
    whitespace; pieces that end up empty are dropped.
    Any other iterable is returned as a list without `None` and zero-length
    entries (the entries themselves are left untouched).
    `None` (nothing configured) yields an empty list.

    Arguments:
    - `tables`: string containing a comma-separated list of tables, a list of tables, or `None`

    Returns:
    - list of table names
    """
    # Nothing configured: return an empty list instead of failing on iteration.
    if tables is None:
        return []

    if isinstance(tables, str):
        tables = (table.strip() for table in tables.split(","))

    return [table for table in tables if table is not None and len(table) > 0]
def is_wildcard(tables):
    """Tell whether `tables` is the wildcard selector.

    Both the bare wildcard value and a list holding only the wildcard count.
    """
    wildcard_forms = (WILDCARD, [WILDCARD])
    return any(tables == form for form in wildcard_forms)
def map_column_names(column_names, schema=None, tables=None):
    """Prefix column names with the schema and table names and lowercase them.

    The prefix is built as `<schema>_<table1>_<table2>_...` with the table
    names in sorted order; parts that are missing or blank are left out.

    Arguments:
    - `column_names`: iterable of column names
    - `schema`: optional schema name; surrounding whitespace is stripped
    - `tables`: optional list of table names, or a single table name as a string

    Returns:
    - list of lowercased, prefixed column names, in the order of `column_names`
    """
    prefix = ""

    if schema and schema.strip():
        prefix += f"{schema.strip()}_"

    if tables:
        # A bare string is a single table name; sorting it directly would
        # sort its characters instead (same wrapping as done in `hash_id`).
        table_names = [tables] if isinstance(tables, str) else tables
        prefix += f"{'_'.join(sorted(table_names))}_"

    return [f"{prefix}{column}".lower() for column in column_names]
def hash_id(tables, row, primary_key_columns):
    """Build a document id from the table names and the row's primary key values.

    The sorted table names form the prefix, followed by the primary key values
    in the order of `primary_key_columns`, everything joined with underscores.
    Primary key values that are missing from the row or `None` are skipped.

    Example:
        tables: table1, table2
        primary key values: 1, 42
        table1_table2_1_42

    Arguments:
    - `tables`: list of table names, or a single table name
    - `row`: mapping supporting `.get(column)`
    - `primary_key_columns`: iterable of primary key column names
    """
    table_names = tables if isinstance(tables, list) else [tables]
    table_prefix = "_".join(sorted(table_names))

    key_parts = []
    for column in primary_key_columns:
        value = row.get(column)
        if value is not None:
            key_parts.append(str(value))

    return f"{table_prefix}_{'_'.join(key_parts)}"
async def fetch(
    cursor_func,
    fetch_columns=False,
    fetch_size=DEFAULT_FETCH_SIZE,
    retry_count=DEFAULT_RETRY_COUNT,
):
    """Stream the rows of a cursor in batches.

    Arguments:
    - `cursor_func`: awaitable callable returning a cursor that provides
      `keys()` and `fetchmany(size=...)`
    - `fetch_columns`: when truthy, the first yielded item is `cursor.keys()`
    - `fetch_size`: number of rows requested per `fetchmany` call
    - `retry_count`: forwarded as `retries` to the `retryable` decorator

    Yields:
    - optionally the column names, then every row of the cursor

    The inner generator is wrapped with `retryable` using exponential backoff.
    NOTE(review): `InternalClientError` and `ProgrammingError` are passed as
    `skipped_exceptions`, presumably meaning they are not retried — confirm
    against `connectors.utils.retryable`.
    """

    @retryable(
        retries=retry_count,
        interval=DEFAULT_WAIT_MULTIPLIER,
        strategy=RetryStrategy.EXPONENTIAL_BACKOFF,
        skipped_exceptions=[InternalClientError, ProgrammingError],
    )
    async def _execute():
        cursor = await cursor_func()

        # Column names go first, when the caller asked for them.
        if fetch_columns:
            yield cursor.keys()

        while True:
            batch = cursor.fetchmany(size=fetch_size)  # pyright: ignore
            fetched = len(batch)

            # An empty batch means the cursor is exhausted.
            if fetched == 0:
                break

            for record in batch:
                yield record

            # A short batch was the last one; no need to ask again.
            if fetched < fetch_size:
                break

            # Hand control back to the event loop between batches.
            await asyncio.sleep(0)

    async for item in _execute():
        yield item
class Queries(ABC):
    """Abstract collection of the queries a database source has to provide.

    Subclasses must implement every method below before they can be
    instantiated.
    """

    @abstractmethod
    def ping(self):
        """Query used to ping the source."""

    @abstractmethod
    def all_tables(self, **kwargs):
        """Query listing all tables."""

    @abstractmethod
    def table_primary_key(self, **kwargs):
        """Query returning the primary key of a table."""

    @abstractmethod
    def table_data(self, **kwargs):
        """Query returning the data of a table."""

    @abstractmethod
    def table_last_update_time(self, **kwargs):
        """Query returning the last update time of a table."""

    @abstractmethod
    def table_data_count(self, **kwargs):
        """Query returning the number of rows in a table."""

    @abstractmethod
    def all_schemas(self):
        """Query listing all schemas of the database."""