Merge pull request #340 from pyiron/sql
Keep the database connection alive after each transaction for a configurable number of seconds (default: 60). Adds a new config key `CONNECTION_TIMEOUT` that can be used to change the time before the connection is closed; set it to 0 to disable the keep-alive.
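
The same setting can also be toggled at runtime on the settings object, which is what the job wrapper below does on compute nodes. A minimal sketch, assuming `Settings` is importable from `pyiron_base/settings/generic.py` as shown in the diff below and that it returns the shared settings instance:

    from pyiron_base.settings.generic import Settings

    s = Settings()
    s.close_connection()      # drop the current database connection, if any
    s.connection_timeout = 0  # 0 disables the keep-alive: close after every transaction
    s.open_connection()       # reconnect using the new timeout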
pmrv authored Sep 8, 2021
2 parents 1cdee4a + 0caeaef commit 442e22b
Showing 5 changed files with 148 additions and 10 deletions.
docs/source/installation.rst: 3 additions & 0 deletions
@@ -152,6 +152,9 @@ options can be added to the :code:`~/.pyiron`:

* :code:`JOB_TABLE` the name of the database table. pyiron is commonly using one table per user.

* :code:`CONNECTION_TIMEOUT` the time in seconds before an idle connection to the database server is closed; set it to 0 to
  close the connection after every transaction. The default is 60 seconds.

A typical :code:`.pyiron` configuration with a `PostgreSQL <https://www.postgresql.org>`_ database might look like this:

.. code-block:: bash
pyiron_base/database/generic.py: 110 additions & 10 deletions
@@ -5,6 +5,8 @@
DatabaseAccess class deals with accessing the database
"""

import pyiron_base.settings.logger

import numpy as np
import re
import time
@@ -26,6 +28,8 @@
from sqlalchemy.pool import NullPool
from sqlalchemy.sql import select
from sqlalchemy.exc import OperationalError, DatabaseError
from threading import Thread, Lock
from queue import SimpleQueue, Empty as QueueEmpty

__author__ = "Murat Han Celik"
__copyright__ = (
@@ -39,19 +43,112 @@
__date__ = "Sep 1, 2017"


class ConnectionWatchDog(Thread):
    """
    Helper class that closes idle connections after a given timeout.

    Initialize it with the connection to watch and a lock that protects it. The lock prevents the watchdog from
    killing a connection that is currently in use. The timeout is given in seconds.

    >>> conn = SqlConnection(...)
    >>> lock = threading.Lock()
    >>> dog = ConnectionWatchDog(conn, lock, timeout=60)
    >>> dog.start()

    After it is started, :meth:`.kick()` the watchdog periodically before the timeout runs out. It is important to
    acquire the lock when using the connection object.

    >>> dog.kick()
    >>> with lock:
    ...     conn.execute(...)
    >>> dog.kick()

    Once you are finished with the connection, or want to make sure the watchdog has quit, call :meth:`.kill()` to
    shut it down. This also causes the watchdog to try to close the connection.

    >>> dog.kill()
    """

    def __init__(self, conn, lock, timeout=60):
        """
        Create new watchdog.

        Args:
            conn: any python object with a `close()` method.
            lock (:class:`threading.Lock`): lock to protect conn
            timeout (int): time in seconds before the watchdog closes the connection.
        """
        super().__init__()
        self._queue = SimpleQueue()
        self._conn = conn
        self._lock = lock
        self._timeout = timeout

    def run(self):
        """
        Main loop of the watchdog: close the connection once the timeout expires without a kick.
        """
        while True:
            try:
                kicked = self._queue.get(timeout=self._timeout)
            except QueueEmpty:
                kicked = False
            if not kicked:
                with self._lock:
                    try:
                        self._conn.close()
                    except Exception:
                        # the connection may already be gone; nothing left to do
                        pass
                break

    def kick(self):
        """
        Restarts the timeout.
        """
        self._queue.put(True)

    def kill(self):
        """
        Stop the watchdog and close the connection.
        """
        self._queue.put(False)
        self.join()
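
(Illustration only, not part of the diff: the kick/kill protocol with a stand-in object instead of a real SQL connection; `_DummyConn` is hypothetical and exists only for this sketch.)

    from threading import Lock

    class _DummyConn:
        closed = False

        def close(self):
            self.closed = True

    conn, lock = _DummyConn(), Lock()
    dog = ConnectionWatchDog(conn, lock, timeout=2)
    dog.start()    # the watchdog closes `conn` unless kicked again within 2 seconds
    dog.kick()     # reset the timer after each use of the connection
    with lock:     # hold the lock while the connection is in use
        pass       # ... conn.execute(...) would go here
    dog.kill()     # shut the watchdog down and close the connection
    assert conn.closed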


class AutorestoredConnection:
    def __init__(self, engine):
    def __init__(self, engine, timeout=60):
        self.engine = engine
        self._conn = None
        self._lock = Lock()
        self._watchdog = None
        self._logger = pyiron_base.settings.logger.get_logger()
        self._timeout = timeout

    def execute(self, *args, **kwargs):
        try:
            if self._conn is None or self._conn.closed:
                self._conn = self.engine.connect()
            result = self._conn.execute(*args, **kwargs)
        except OperationalError:
            time.sleep(5)
            result = self.execute(*args, **kwargs)
        while True:
            try:
                if self._conn is None or self._conn.closed:
                    if self._timeout > 0:
                        # only log reconnections when we keep the connection alive between requests,
                        # otherwise we would spam the log
                        if self._conn is None:
                            self._logger.info("Reconnecting to DB; connection not existing.")
                        else:
                            self._logger.info("Reconnecting to DB; connection closed.")
                    self._conn = self.engine.connect()
                    if self._watchdog is not None:
                        # if the connection is dead but the watchdog is still up, something else killed
                        # the connection; make the watchdog quit before creating a new one
                        self._watchdog.kill()
                    self._watchdog = ConnectionWatchDog(self._conn, self._lock, timeout=self._timeout)
                    self._watchdog.start()
                if self._timeout > 0:
                    self._watchdog.kick()
                with self._lock:
                    result = self._conn.execute(*args, **kwargs)
                break
            except OperationalError as e:
                print(f"Database connection failed with operational error {e}, waiting 5s, then re-trying.")
                time.sleep(5)
        return result

    def close(self):

@@ -73,7 +170,7 @@ class DatabaseAccess(object):
    Murat Han Celik
    """

    def __init__(self, connection_string, table_name):
    def __init__(self, connection_string, table_name, timeout=60):
        """
        Initialize the Database connection
@@ -82,9 +179,11 @@ def __init__(self, connection_string, table_name):
                typical form: dialect+driver://username:password@host:port/database
                example: 'postgresql://scott:[email protected]/mdb'
            table_name (str): database table name, a simple string like: 'simulation'
            timeout (int): time in seconds before unused database connections are closed
        """
        self.table_name = table_name
        self._keep_connection = False
        self._timeout = timeout
        self._sql_lite = "sqlite" in connection_string
        try:
            if not self._sql_lite:
@@ -93,7 +192,8 @@ def __init__(self, connection_string, table_name):
                    connect_args={"connect_timeout": 15},
                    poolclass=NullPool,
                )
                self.conn = AutorestoredConnection(self._engine)
                self.conn = AutorestoredConnection(self._engine, timeout=self._timeout)
                self._keep_connection = self._timeout > 0
            else:
                self._engine = create_engine(connection_string)
                self.conn = self._engine.connect()
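Taken together, `ConnectionWatchDog` and `AutorestoredConnection` give `DatabaseAccess` a connection that survives server-side disconnects and is reaped after a period of inactivity. A minimal sketch of how the new `timeout` argument is passed in (connection string and table name are placeholders; a reachable PostgreSQL server is assumed):

    from pyiron_base.database.generic import DatabaseAccess

    db = DatabaseAccess(
        "postgresql://user:password@localhost/pyiron_db",  # placeholder credentials
        "jobs_pyiron",
        timeout=120,  # keep the server connection open for 120 s between transactions
    )
    # timeout=0 restores the old behaviour of closing the connection after every transaction
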
pyiron_base/job/wrapper.py: 7 additions & 0 deletions
@@ -126,6 +126,13 @@ def job_wrapper_function(working_directory, job_id=None, file_path=None, submit_
        debug (bool): enable debug mode
        submit_on_remote (bool): submit to queuing system on remote host
    """

    # always close the database connection in calculations running on the cluster to avoid a high
    # number of concurrent connections
    s.close_connection()
    s.connection_timeout = 0
    s.open_connection()

    if job_id is not None:
        job = JobWrapper(
            working_directory=working_directory,
pyiron_base/settings/generic.py: 18 additions & 0 deletions
@@ -60,6 +60,7 @@ def __init__(self, config=None):
"user": "pyiron",
"resource_paths": [],
"project_paths": [],
"connection_timeout": 60,
"sql_connection_string": None,
"sql_table_name": "jobs_pyiron",
"sql_view_connection_string": None,
@@ -191,6 +192,20 @@ def resource_paths(self):
"""
return self._configuration["resource_paths"]

    @property
    def connection_timeout(self):
        """
        Get the connection timeout in seconds. Zero means the database connection is closed after every transaction.

        Returns:
            int: timeout in seconds
        """
        return self._configuration["connection_timeout"]

    @connection_timeout.setter
    def connection_timeout(self, val):
        self._configuration["connection_timeout"] = val

    def open_connection(self):
        """
        Internal function to open the connection to the database. Only after this function is called the database is
@@ -200,6 +215,7 @@ def open_connection(self):
            self._database = DatabaseAccess(
                self._configuration["sql_connection_string"],
                self._configuration["sql_table_name"],
                timeout=self._configuration["connection_timeout"],
            )

    def switch_to_local_database(self, file_name="pyiron.db", cwd=None):
@@ -416,10 +432,12 @@ def _config_parse_file(self, config_file):
self._configuration["sql_view_user_key"] = parser.get(
section, "VIEWERPASSWD"
)
self._configuration["connection_timeout"] = parser.get(section, "CONNECTION_TIMEOUT", fallback=60)
        elif self._configuration["sql_type"] == "SQLalchemy":
            self._configuration["sql_connection_string"] = parser.get(
                section, "CONNECTION"
            )
            self._configuration["connection_timeout"] = parser.getint(
                section, "CONNECTION_TIMEOUT", fallback=60
            )
        else:  # finally we assume an SQLite connection
            if parser.has_option(section, "FILE"):
                self._configuration["sql_file"] = parser.get(section, "FILE").replace(
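On the configuration side, `CONNECTION_TIMEOUT` is just another key in `~/.pyiron`, read with a fallback of 60. A small stand-alone sketch of the parsing step (plain `configparser`, analogous to `_config_parse_file` above; the file contents and key names other than `CONNECTION_TIMEOUT` are an invented example):

    import configparser

    parser = configparser.ConfigParser()
    parser.read_string(
        "[DEFAULT]\n"
        "TYPE = SQLalchemy\n"
        "CONNECTION = postgresql://user:password@localhost/pyiron_db\n"
        "CONNECTION_TIMEOUT = 120\n"
    )
    timeout = parser.getint("DEFAULT", "CONNECTION_TIMEOUT", fallback=60)
    print(timeout)  # 120; falls back to 60 when the key is missing
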
pyiron_base/settings/logger.py: 10 additions & 0 deletions
@@ -68,3 +68,13 @@ def setup_logger():
    logger.addHandler(fh)

    return logger

_logger = setup_logger()


def get_logger():
    """
    Return the global instance of the default logger, which writes to the log file `pyiron.log`.

    This exists only to circumvent recursive imports for modules that implement functionality for :class:`.Settings`;
    normal code should rely on the logger defined at :attr:`.Settings.logger`.
    """
    return _logger
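
The new helper mirrors how `pyiron_base/database/generic.py` above obtains its logger without importing `Settings`. A minimal usage sketch:

    import pyiron_base.settings.logger

    log = pyiron_base.settings.logger.get_logger()
    log.info("re-established database connection")  # ends up in pyiron.log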
