diff --git a/requirements.txt b/requirements.txt index 2f89a2502d..fdb4b6b8e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,5 @@ requests>=2.0.0,<3.0.0 dnspython<2.0; python_version >= '2.7' and python_version < '3.0' dnspython<1.16.0; python_version == '3.3' dnspython<3.0; python_version >= '3.4' +sqlalchemy<1.3; python_version == '3.3' +sqlalchemy<1.4; python_version != '3.3' diff --git a/sopel/config/core_section.py b/sopel/config/core_section.py index b08fdbcf13..fa8eb818ac 100644 --- a/sopel/config/core_section.py +++ b/sopel/config/core_section.py @@ -96,8 +96,41 @@ class CoreSection(StaticSection): channels = ListAttribute('channels') """List of channels for the bot to join when it connects""" + db_type = ChoiceAttribute('db_type', choices=[ + 'sqlite', 'mysql', 'postgres', 'mssql', 'oracle', 'firebird', 'sybase'], default='sqlite') + """The type of database to use for Sopel's database. + + mysql - pip install mysql-python (Python 2) or pip install mysqlclient (Python 3) + postgres - pip install psycopg2 + mssql - pip install pymssql + + See https://docs.sqlalchemy.org/en/latest/dialects/ for a full list of dialects + """ + db_filename = ValidatedAttribute('db_filename') - """The filename for Sopel's database.""" + """The filename for Sopel's database. (SQLite only)""" + + db_driver = ValidatedAttribute('db_driver') + """The driver for Sopel's database. + + This is optional, but can be specified if user wants to use a different driver + https://docs.sqlalchemy.org/en/latest/core/engines.html + """ + + db_user = ValidatedAttribute('db_user') + """The user for Sopel's database.""" + + db_pass = ValidatedAttribute('db_pass') + """The password for Sopel's database.""" + + db_host = ValidatedAttribute('db_host') + """The host for Sopel's database.""" + + db_port = ValidatedAttribute('db_port') + """The port for Sopel's database.""" + + db_name = ValidatedAttribute('db_name') + """The name of Sopel's database.""" default_time_format = ValidatedAttribute('default_time_format', default='%Y-%m-%d - %T%Z') diff --git a/sopel/db.py b/sopel/db.py index 3dcd2d7a0b..52f5f4a546 100644 --- a/sopel/db.py +++ b/sopel/db.py @@ -4,9 +4,14 @@ import json import os.path import sys -import sqlite3 -from sopel.tools import Identifier +from sopel import tools + +from sqlalchemy import create_engine, Column, ForeignKey, Integer, String +from sqlalchemy.engine.url import URL +from sqlalchemy.exc import OperationalError, SQLAlchemyError +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import scoped_session, sessionmaker if sys.version_info.major >= 3: unicode = str @@ -28,6 +33,39 @@ def _deserialize(value): return value +BASE = declarative_base() + + +class NickIDs(BASE): + """NickIDs SQLAlchemy Class""" + __tablename__ = 'nick_ids' + nick_id = Column(Integer, primary_key=True) + + +class Nicknames(BASE): + """Nicknames SQLAlchemy Class""" + __tablename__ = 'nicknames' + nick_id = Column(Integer, ForeignKey('nick_ids.nick_id'), primary_key=True) + slug = Column(String, primary_key=True) + canonical = Column(String) + + +class NickValues(BASE): + """NickValues SQLAlchemy Class""" + __tablename__ = 'nick_values' + nick_id = Column(Integer, ForeignKey('nick_ids.nick_id'), primary_key=True) + key = Column(String(255), primary_key=True) + value = Column(String(255)) + + +class ChannelValues(BASE): + """ChannelValues SQLAlchemy Class""" + __tablename__ = 'channel_values' + channel = Column(String(255), primary_key=True) + key = Column(String(255), primary_key=True) + value = Column(String(255)) + + class SopelDB(object): """*Availability: 5.0+* @@ -39,20 +77,70 @@ class SopelDB(object): directory as the config.""" def __init__(self, config): - path = config.core.db_filename - config_dir, config_file = os.path.split(config.filename) - config_name, _ = os.path.splitext(config_file) - if path is None: - path = os.path.join(config_dir, config_name + '.db') - path = os.path.expanduser(path) - if not os.path.isabs(path): - path = os.path.normpath(os.path.join(config_dir, path)) - self.filename = path - self._create() + # MySQL - mysql://username:password@localhost/db + # SQLite - sqlite:////home/sopel/.sopel/default.db + db_type = config.core.db_type + + # Handle SQLite explicitly as a default + if db_type == 'sqlite': + path = config.core.db_filename + config_dir, config_file = os.path.split(config.filename) + config_name, _ = os.path.splitext(config_file) + if path is None: + path = os.path.join(config_dir, config_name + '.db') + path = os.path.expanduser(path) + if not os.path.isabs(path): + path = os.path.normpath(os.path.join(config_dir, path)) + self.filename = path + self.url = 'sqlite:///%s' % path + # Otherwise, handle all other database engines + else: + if db_type == 'mysql': + drivername = config.core.db_driver or 'mysql' + elif db_type == 'postgres': + drivername = config.core.db_driver or 'postgresql' + elif db_type == 'oracle': + drivername = config.core.db_driver or 'oracle' + elif db_type == 'mssql': + drivername = config.core.db_driver or 'mssql+pymssql' + elif db_type == 'firebird': + drivername = config.core.db_driver or 'firebird+fdb' + elif db_type == 'sybase': + drivername = config.core.db_driver or 'sybase+pysybase' + else: + raise config.ConfigurationError('Unknown db_type') + + db_user = config.core.db_user + db_pass = config.core.db_pass + db_host = config.core.db_host + db_port = config.core.db_port # Optional + db_name = config.core.db_name # Optional, depending on DB + + # Ensure we have all our variables defined + if db_user is None or db_pass is None or db_host is None: + raise DatabaseConfigurationError('Please make sure the following core ' + 'configuration values are defined: ' + 'db_user, db_pass, db_host') + self.url = URL(drivername=drivername, username=db_user, password=db_pass, + host=db_host, port=db_port, database=db_name) + + self.engine = create_engine(self.url) + + # Catch any errors connecting to database + try: + self.engine.connect() + except OperationalError: + print("OperationalError: Unable to connect to database.") + raise + + # Create our tables + BASE.metadata.create_all(self.engine) + + self.session_factory = scoped_session(sessionmaker(bind=self.engine)) def connect(self): """Return a raw database connection object.""" - return sqlite3.connect(self.filename, timeout=10) + return self.engine.connect() def execute(self, *args, **kwargs): """Execute an arbitrary SQL query against the database. @@ -60,30 +148,7 @@ def execute(self, *args, **kwargs): Returns a cursor object, on which things like `.fetchall()` can be called per PEP 249.""" with self.connect() as conn: - cur = conn.cursor() - return cur.execute(*args, **kwargs) - - def _create(self): - """Create the basic database structure.""" - self.execute( - 'CREATE TABLE IF NOT EXISTS nick_ids (nick_id INTEGER PRIMARY KEY AUTOINCREMENT)' - ) - self.execute( - 'CREATE TABLE IF NOT EXISTS nicknames ' - '(nick_id INTEGER REFERENCES nick_ids, ' - 'slug STRING PRIMARY KEY, canonical string)' - ) - self.execute( - 'CREATE TABLE IF NOT EXISTS nick_values ' - '(nick_id INTEGER REFERENCES nick_ids(nick_id), ' - 'key STRING, value STRING, ' - 'PRIMARY KEY (nick_id, key))' - ) - self.execute( - 'CREATE TABLE IF NOT EXISTS channel_values ' - '(channel STRING, key STRING, value STRING, ' - 'PRIMARY KEY (channel, key))' - ) + return conn.execute(*args, **kwargs) def get_uri(self): """Returns a URL for the database, usable to connect with SQLAlchemy.""" @@ -97,60 +162,101 @@ def get_nick_id(self, nick, create=True): This identifier is unique to a user, and shared across all of that user's aliases. If create is True, a new ID will be created if one does not already exist""" + session = self.session_factory() slug = nick.lower() - nick_id = self.execute('SELECT nick_id from nicknames where slug = ?', - [slug]).fetchone() - if nick_id is None: - if not create: - raise ValueError('No ID exists for the given nick') - with self.connect() as conn: - cur = conn.cursor() - cur.execute('INSERT INTO nick_ids VALUES (NULL)') - nick_id = cur.execute('SELECT last_insert_rowid()').fetchone()[0] - cur.execute( - 'INSERT INTO nicknames (nick_id, slug, canonical) VALUES ' - '(?, ?, ?)', - [nick_id, slug, nick] - ) - nick_id = self.execute('SELECT nick_id from nicknames where slug = ?', - [slug]).fetchone() - return nick_id[0] + try: + nickname = session.query(Nicknames).filter( + Nicknames.slug == slug + ).one_or_none() + + if nickname is None: + if not create: + raise ValueError('No ID exists for the given nick') + # Generate a new ID + nick_id = NickIDs() + session.add(nick_id) + session.commit() + + # Create a new Nickname + nickname = Nicknames(nick_id=nick_id.nick_id, slug=slug, canonical=nick) + session.add(nickname) + session.commit() + return nickname.nick_id + except SQLAlchemyError: + session.rollback() + raise + finally: + session.close() def alias_nick(self, nick, alias): """Create an alias for a nick. Raises ValueError if the alias already exists. If nick does not already exist, it will be added along with the alias.""" - nick = Identifier(nick) - alias = Identifier(alias) + nick = tools.Identifier(nick) + alias = tools.Identifier(alias) nick_id = self.get_nick_id(nick) - sql = 'INSERT INTO nicknames (nick_id, slug, canonical) VALUES (?, ?, ?)' - values = [nick_id, alias.lower(), alias] + session = self.session_factory() try: - self.execute(sql, values) - except sqlite3.IntegrityError: - raise ValueError('Alias already exists.') + result = session.query(Nicknames).filter( + Nicknames.slug == alias.lower(), + Nicknames.canonical == alias + ).one_or_none() + if result: + raise ValueError('Given alias is the only entry in its group.') + nickname = Nicknames(nick_id=nick_id, slug=alias.lower(), canonical=alias) + session.add(nickname) + session.commit() + except SQLAlchemyError: + session.rollback() + raise + finally: + session.close() def set_nick_value(self, nick, key, value): """Sets the value for a given key to be associated with the nick.""" - nick = Identifier(nick) + nick = tools.Identifier(nick) value = json.dumps(value, ensure_ascii=False) nick_id = self.get_nick_id(nick) - self.execute('INSERT OR REPLACE INTO nick_values VALUES (?, ?, ?)', - [nick_id, key, value]) + session = self.session_factory() + try: + result = session.query(NickValues).filter( + NickValues.nick_id == nick_id, + NickValues.key == key + ).one_or_none() + # NickValue exists, update + if result: + result.value = value + session.commit() + # DNE - Insert + else: + new_nickvalue = NickValues(nick_id=nick_id, key=key, value=value) + session.add(new_nickvalue) + session.commit() + except SQLAlchemyError: + session.rollback() + raise + finally: + session.close() def get_nick_value(self, nick, key): """Retrieves the value for a given key associated with a nick.""" - nick = Identifier(nick) - result = self.execute( - 'SELECT value FROM nicknames JOIN nick_values ' - 'ON nicknames.nick_id = nick_values.nick_id ' - 'WHERE slug = ? AND key = ?', - [nick.lower(), key] - ).fetchone() - if result is not None: - result = result[0] - return _deserialize(result) + nick = tools.Identifier(nick) + session = self.session_factory() + try: + result = session.query(NickValues).filter( + Nicknames.nick_id == NickValues.nick_id, + Nicknames.slug == nick.lower(), + NickValues.key == key + ).one_or_none() + if result is not None: + result = result.value + return _deserialize(result) + except SQLAlchemyError: + session.rollback() + raise + finally: + session.close() def unalias_nick(self, alias): """Removes an alias. @@ -158,20 +264,42 @@ def unalias_nick(self, alias): Raises ValueError if there is not at least one other nick in the group. To delete an entire group, use `delete_group`. """ - alias = Identifier(alias) + alias = tools.Identifier(alias) nick_id = self.get_nick_id(alias, False) - count = self.execute('SELECT COUNT(*) FROM nicknames WHERE nick_id = ?', - [nick_id]).fetchone()[0] - if count <= 1: - raise ValueError('Given alias is the only entry in its group.') - self.execute('DELETE FROM nicknames WHERE slug = ?', [alias.lower()]) + session = self.session_factory() + try: + count = session.query(Nicknames).filter( + Nicknames.nick_id == nick_id + ).count() + if count <= 1: + raise ValueError('Given alias is the only entry in its group.') + session.query(Nicknames).filter(Nicknames.slug == alias.lower()).delete() + session.commit() + except SQLAlchemyError: + session.rollback() + raise + finally: + session.close() def delete_nick_group(self, nick): """Removes a nickname, and all associated aliases and settings.""" - nick = Identifier(nick) + nick = tools.Identifier(nick) nick_id = self.get_nick_id(nick, False) - self.execute('DELETE FROM nicknames WHERE nick_id = ?', [nick_id]) - self.execute('DELETE FROM nick_values WHERE nick_id = ?', [nick_id]) + session = self.session_factory() + try: + count = session.query(Nicknames).filter( + Nicknames.nick_id == nick_id + ).count() + if count <= 1: + raise ValueError('Given alias is the only entry in its group.') + session.query(Nicknames).filter(Nicknames.nick_id == nick_id).delete() + session.query(NickValues).filter(NickValues.nick_id == nick_id).delete() + session.commit() + except SQLAlchemyError: + session.rollback() + raise + finally: + session.close() def merge_nick_groups(self, first_nick, second_nick): """Merges the nick groups for the specified nicks. @@ -184,40 +312,81 @@ def merge_nick_groups(self, first_nick, second_nick): Note that merging of data only applies to the native key-value store. If modules define their own tables which rely on the nick table, they will need to have their merging done separately.""" - first_id = self.get_nick_id(Identifier(first_nick)) - second_id = self.get_nick_id(Identifier(second_nick)) - self.execute( - 'UPDATE OR IGNORE nick_values SET nick_id = ? WHERE nick_id = ?', - [first_id, second_id]) - self.execute('DELETE FROM nick_values WHERE nick_id = ?', [second_id]) - self.execute('UPDATE nicknames SET nick_id = ? WHERE nick_id = ?', - [first_id, second_id]) + first_id = self.get_nick_id(tools.Identifier(first_nick)) + second_id = self.get_nick_id(tools.Identifier(second_nick)) + session = self.session_factory() + try: + # Get second_id's values + res = session.query(NickValues).filter(NickValues.nick_id == second_id).all() + # Update first_id with second_id values if first_id doesn't have that key + for row in res: + first_res = session.query(NickValues).filter( + NickValues.nick_id == first_id, + NickValues.key == row.key + ).one_or_none() + if not first_res: + self.set_nick_value(first_nick, row.key, _deserialize(row.value)) + session.query(NickValues).filter(NickValues.nick_id == second_id).delete() + session.query(Nicknames) \ + .filter(Nicknames.nick_id == second_id) \ + .update({'nick_id': first_id}) + session.commit() + except SQLAlchemyError: + session.rollback() + raise + finally: + session.close() # CHANNEL FUNCTIONS def set_channel_value(self, channel, key, value): """Sets the value for a given key to be associated with the channel.""" - channel = Identifier(channel).lower() + channel = tools.Identifier(channel).lower() value = json.dumps(value, ensure_ascii=False) - self.execute('INSERT OR REPLACE INTO channel_values VALUES (?, ?, ?)', - [channel, key, value]) + session = self.session_factory() + try: + result = session.query(ChannelValues).filter( + ChannelValues.channel == channel, + ChannelValues.key == key + ).one_or_none() + # ChannelValue exists, update + if result: + result.value = value + session.commit() + # DNE - Insert + else: + new_channelvalue = ChannelValues(channel=channel, key=key, value=value) + session.add(new_channelvalue) + session.commit() + except SQLAlchemyError: + session.rollback() + raise + finally: + session.close() def get_channel_value(self, channel, key): """Retrieves the value for a given key associated with a channel.""" - channel = Identifier(channel).lower() - result = self.execute( - 'SELECT value FROM channel_values WHERE channel = ? AND key = ?', - [channel, key] - ).fetchone() - if result is not None: - result = result[0] - return _deserialize(result) + channel = tools.Identifier(channel).lower() + session = self.session_factory() + try: + result = session.query(ChannelValues).filter( + ChannelValues.channel == channel, + ChannelValues.key == key + ).one_or_none() + if result is not None: + result = result.value + return _deserialize(result) + except SQLAlchemyError: + session.rollback() + raise + finally: + session.close() # NICK AND CHANNEL FUNCTIONS def get_nick_or_channel_value(self, name, key): """Gets the value `key` associated to the nick or channel `name`.""" - name = Identifier(name) + name = tools.Identifier(name) if name.is_nick(): return self.get_nick_value(name, key) else: