-
Notifications
You must be signed in to change notification settings - Fork 0
/
database.py
287 lines (256 loc) · 11 KB
/
database.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
import logging
from typing import Callable, Optional, Sequence
from itertools import count
from functools import wraps
import datetime
import mysql.connector
from mysql.connector import MySQLConnection
from mysql.connector import Error, errorcode
from bookparse import Book
# MySQL error handling, applied as a decorator to `Database` methods.
def handle_mysql_errors(func: Callable) -> Callable:
    """Wrap a `Database` method so MySQL errors are logged instead of raised.

    Before delegating to `func`, the wrapper calls
    `db_obj._validate_connection()`. If that call or `func` raises a
    `mysql.connector.Error`, the error is logged together with the method
    name and its arguments, any open transaction is rolled back, and the
    wrapper returns `None`. Exceptions of other types propagate unchanged.
    """
    @wraps(func)
    def wrapper(db_obj, *args, **kwargs):
        try:
            db_obj._validate_connection()
            return func(db_obj, *args, **kwargs)
        except Error as err:
            if err.errno == errorcode.ER_PARSE_ERROR:
                logging.error('incorrect SQL syntax')
            elif err.errno == errorcode.ER_NO_REFERENCED_ROW_2:
                logging.error('foreign key constraint fails')
            else:
                logging.error(f'unexpected error: {err}')
            logging.error(f'\t in {func.__name__} args={args}, kwargs={kwargs}')
            # The connection may be unset/None if it never opened successfully.
            connection = getattr(db_obj, '_connection', None)
            if connection is not None and connection.is_connected():
                try:
                    connection.rollback()
                except Error as rollback_err:
                    # A failed rollback must not mask the original error or
                    # escape the wrapper; record it and carry on.
                    logging.error(f'rollback failed: {rollback_err}')
            return None
    return wrapper
class Database:
    """Access layer between the bot and its MySQL database.

    Query methods are wrapped with `handle_mysql_errors`. Before each call
    the connection is validated. If a `mysql.connector.Error` is raised, it
    is logged, the transaction is rolled back and the method returns `None`
    instead of raising. Values are passed to MySQL as query parameters,
    except for `str(book)` in `insert_book`.
    """

    # Instance counter; the first `Database` created gets id 0.
    _ids = count(0)
    _connection: MySQLConnection
    _config: dict

    def __init__(self, config: dict):
        """Store `config` and open the first connection.

        `config` is passed as keyword arguments to `mysql.connector.connect`.
        """
        self.id = next(self._ids)
        # Ids start at 0, so any id above 0 means this is the 2nd+ instance.
        if self.id > 0:
            logging.warning(f"there are {self.id + 1} instances of `Database`")
        self._config = config
        self.connect()

    def connect(self):
        """Open a new autocommit connection from the stored config.

        Connection errors are logged, not raised. On failure the previous
        value of `self._connection` (if any) is left untouched.
        """
        try:
            self._connection = mysql.connector.connect(**self._config,
                                                       autocommit=True)
        except Error as err:
            if err.errno == errorcode.ER_ACCESS_DENIED_ERROR:
                logging.error('invalid login details')
            elif err.errno == errorcode.ER_BAD_DB_ERROR:
                logging.error('database does not exsist')
            else:
                logging.error(f'unexpected error: {err} in `Database.connect`')

    def reconnect(self):
        """Ask the existing connection object to reconnect."""
        self._connection.reconnect()

    def _validate_connection(self):
        """Ensure a live connection exists, reconnecting if necessary.

        Raises:
            mysql.connector.Error: if no usable connection can be opened.
                Raising a MySQL error, instead of failing later with an
                `AttributeError`, lets `handle_mysql_errors` log it and
                return `None`.
        """
        connection = getattr(self, '_connection', None)
        if connection is None or not connection.is_connected():
            self.connect()
            connection = getattr(self, '_connection', None)
            if connection is None or not connection.is_connected():
                raise Error(msg='no connection to the MySQL server')

    @handle_mysql_errors
    def insert_book(self, book: Book) -> None:
        """Insert data about the book and its pages into the database.

        The new book id is `MAX(id) + 1`, or 1 when the `book` table is
        empty. Each entry of `book.pages` becomes a `page` row numbered from 1.
        """
        self._connection.start_transaction()
        with self._connection.cursor() as cursor:
            cursor.execute("SELECT MAX(id) FROM book")
            row = cursor.fetchone()
            if row and row[0] is not None:
                new_book_id = int(row[0]) + 1
            else:
                new_book_id = 1
            # NOTE(review): `str(book)` is spliced directly into the SQL text.
            # This assumes bookparse.Book.__str__ yields a correctly escaped
            # value list for the remaining `book` columns. Confirm, or switch
            # to placeholders.
            cursor.execute(f"INSERT INTO book VALUES ('{new_book_id}', {str(book)})")
            marked_pages = [(new_book_id, num, page)
                            for num, page in enumerate(book.pages, start=1)]
            cursor.executemany(
                "INSERT INTO page (book_id, num, content) VALUES (%s, %s, %s)",
                marked_pages,
            )
        self._connection.commit()

    @handle_mysql_errors
    def check_for_admin(self, chat_id: int) -> bool:
        """Return `True` if `chat_id` has the 'admin' role."""
        with self._connection.cursor() as cursor:
            # LIMIT 1: only existence matters, and no unread rows are left
            # behind when the cursor is closed.
            cursor.execute(
                "SELECT role_name FROM chat_role_view "
                "WHERE chat_id = %s AND role_name = 'admin' LIMIT 1",
                (chat_id,),
            )
            return cursor.fetchone() is not None

    @handle_mysql_errors
    def users_counts(self) -> tuple[int, int, int]:
        """
        Counts bot users
        Returns results in form `(total_count, banned_users, admins)`
        """
        def first_or_zero(row):
            return 0 if row is None else row[0]

        with self._connection.cursor() as cursor:
            cursor.execute("SELECT COUNT(*) FROM chat")
            total = first_or_zero(cursor.fetchone())
            cursor.execute("SELECT COUNT(*) FROM chat_role_view "
                           "WHERE role_name = 'banned'")
            banned = first_or_zero(cursor.fetchone())
            cursor.execute("SELECT COUNT(*) FROM chat_role_view "
                           "WHERE role_name = 'admin'")
            admins = first_or_zero(cursor.fetchone())
        return (total, banned, admins)

    @handle_mysql_errors
    def search_admins(self) -> list[int]:
        """Return the chat ids of all chats with the 'admin' role."""
        with self._connection.cursor() as cursor:
            cursor.execute("SELECT chat_id FROM chat_role_view "
                           "WHERE role_name = 'admin'")
            return [row[0] for row in cursor.fetchall()]

    @handle_mysql_errors
    def new_admin(self, chat_id, expire_date: Optional[datetime.date] = None):
        """
        Adds administrator role to given `chat_id` with given `expire_date`.
        If the user is already an admin, expire date is updated.
        `expire_date=None` is stored as NULL.
        """
        self._connection.start_transaction()
        with self._connection.cursor() as cursor:
            cursor.execute(
                "SELECT COUNT(chat_id) FROM chat_role_view "
                "WHERE role_name = 'admin' AND chat_id = %s",
                (chat_id,),
            )
            row = cursor.fetchone()
            if row is None or row[0] == 0:
                statement = """INSERT INTO chat_role VALUES (%s,
                    (SELECT id FROM role WHERE name = 'admin'),
                    CURDATE(), %s)"""
                data = (chat_id, expire_date)
            else:
                statement = """UPDATE chat_role SET expire_date = %s
                    WHERE chat_id = %s AND
                    role_id = (SELECT id FROM role WHERE name = 'admin')"""
                data = (expire_date, chat_id)
            cursor.execute(statement, data)
        self._connection.commit()

    @handle_mysql_errors
    def check_user_exist(self, chat_id: int) -> bool:
        """
        Search user with given `chat_id` in database
        Returns `True` if user exists
        """
        with self._connection.cursor() as cursor:
            cursor.execute("SELECT COUNT(*) FROM chat WHERE id = %s",
                           (chat_id,))
            row = cursor.fetchone()
        return row is not None and row[0] == 1

    @handle_mysql_errors
    def get_banned_users(self) -> Sequence[int]:
        """Return the chat ids of all chats with the 'banned' role."""
        with self._connection.cursor() as cursor:
            cursor.execute("SELECT chat_id FROM chat_role_view "
                           "WHERE role_name = 'banned'")
            return [row[0] for row in cursor.fetchall()]

    @handle_mysql_errors
    def record_new_chat(self, chat_id: int) -> None:
        """
        Insert new chat into database and give it the 'user' role.
        When chat with given id already exists a MySQL exception occurs.
        Exception handled by internal function `database.handle_mysql_errors`
        (as other `Database` methods)
        and returns `None`
        """
        self._connection.start_transaction()
        with self._connection.cursor() as cursor:
            cursor.execute("INSERT INTO chat (id) VALUES (%s)", (chat_id,))
            cursor.execute(
                "INSERT INTO chat_role VALUES (%s, "
                "(SELECT id FROM role WHERE name = 'user'), CURDATE(), NULL)",
                (chat_id,),
            )
        self._connection.commit()

    @handle_mysql_errors
    def ban_users(self, chat_ids: Sequence[int]) -> bool | None:
        """
        Add role 'banned' (no expire date) to every chat in `chat_ids`.
        Returns `True` on success. If any insert fails (for example a
        duplicate or a foreign-key violation, depending on the schema), the
        error is handled by `database.handle_mysql_errors`: the transaction
        is rolled back and `None` is returned.
        """
        self._connection.start_transaction()
        with self._connection.cursor() as cursor:
            statement = ("INSERT INTO chat_role VALUES "
                         "(%s, (SELECT id FROM role WHERE name = 'banned'), "
                         "CURDATE(), NULL)")
            # executemany expects one parameter sequence per row
            rows = [[chat] for chat in chat_ids]
            cursor.executemany(statement, rows)
        self._connection.commit()
        return True

    @handle_mysql_errors
    def unban_users(self, chat_ids: Sequence[int]) -> bool | None:
        """
        Remove 'banned' role from given `chat_ids`.
        Does the opposite operation of the `ban_users` method.
        Returns `True` on success, `None` if a MySQL error was handled.
        """
        self._connection.start_transaction()
        with self._connection.cursor() as cursor:
            statement = ("DELETE FROM chat_role WHERE "
                         "role_id = (SELECT id FROM role WHERE name = 'banned') "
                         "AND chat_id = %s")
            rows = [[chat] for chat in chat_ids]
            cursor.executemany(statement, rows)
        self._connection.commit()
        return True

    @handle_mysql_errors
    def search_book(self, rows_count: int, offset: int = 0):
        """
        Get a page of available books as `(id, title, author)` rows.

        Rows are ordered by id; `offset` rows are skipped and at most
        `rows_count` rows are returned.
        """
        with self._connection.cursor() as cursor:
            # int() keeps LIMIT arguments numeric so they bind unquoted.
            cursor.execute(
                "SELECT id, title, author FROM book ORDER BY id LIMIT %s, %s",
                (int(offset), int(rows_count)),
            )
            return cursor.fetchall()

    @handle_mysql_errors
    def search_max_page(self, chat_id: int):
        """
        Returns max page number of the book currently selected by the chat
        (`None` when there is no such book or it has no pages).
        """
        with self._connection.cursor() as cursor:
            cursor.execute(
                "SELECT MAX(num) FROM page WHERE book_id = "
                "(SELECT book_id FROM chat WHERE id = %s)",
                (chat_id,),
            )
            row = cursor.fetchone()
        return row[0] if row is not None else None

    @handle_mysql_errors
    def update_chat_book(self, chat_id: int, book_id):
        """Set the book selected by chat `chat_id` (`None` stores NULL)."""
        self._connection.start_transaction()
        with self._connection.cursor() as cursor:
            cursor.execute("UPDATE chat SET book_id = %s WHERE id = %s",
                           (book_id, chat_id))
        self._connection.commit()

    @handle_mysql_errors
    def book_metadata(self, book_id: int):
        """
        Get book metadata as a `(title, author, info)` row, or `None`
        when no book has id `book_id`.
        """
        with self._connection.cursor() as cursor:
            cursor.execute("SELECT title, author, info FROM book WHERE id = %s",
                           (book_id,))
            return cursor.fetchone()

    @handle_mysql_errors
    def page_content(self, chat_id: int, page_num: int):
        """
        Text of page with number=`page_num` from the book selected by the
        chat with id=`chat_id`, or `None` if there is no such page.
        """
        with self._connection.cursor() as cursor:
            cursor.execute(
                "SELECT content FROM page WHERE num = %s AND book_id = "
                "(SELECT book_id FROM chat WHERE id = %s)",
                (page_num, chat_id),
            )
            row = cursor.fetchone()
        return row[0] if row is not None else None
if __name__ == '__main__':
    # This module is only a database helper; running it directly just
    # points the user at the real entry point.
    logging.warning(
        "To run the bot, use a different .py file. "
        "This class is needed only to communicate with the database."
    )