diff --git a/datasette/database.py b/datasette/database.py index cb01301eee..ef516e9ed9 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -159,6 +159,22 @@ def count_params(params): kwargs["count"] = count return results + async def execute_isolated_fn(self, fn): + # Open a new connection just for the duration of this function + # blocking the write queue to avoid any writes occurring during it + if self.ds.executor is None: + # non-threaded mode + isolated_connection = self.connect(write=True) + try: + result = fn(isolated_connection) + finally: + isolated_connection.close() + self._all_file_connections.remove(isolated_connection) + return result + else: + # Threaded mode - send to write thread + return await self._send_to_write_thread(fn, isolated_connection=True) + async def execute_write_fn(self, fn, block=True): if self.ds.executor is None: # non-threaded mode @@ -166,9 +182,10 @@ async def execute_write_fn(self, fn, block=True): self._write_connection = self.connect(write=True) self.ds._prepare_connection(self._write_connection, self.name) return fn(self._write_connection) + else: + return await self._send_to_write_thread(fn, block) - # threaded mode - task_id = uuid.uuid5(uuid.NAMESPACE_DNS, "datasette.io") + async def _send_to_write_thread(self, fn, block=True, isolated_connection=False): if self._write_queue is None: self._write_queue = queue.Queue() if self._write_thread is None: @@ -176,8 +193,9 @@ async def execute_write_fn(self, fn, block=True): target=self._execute_writes, daemon=True ) self._write_thread.start() + task_id = uuid.uuid5(uuid.NAMESPACE_DNS, "datasette.io") reply_queue = janus.Queue() - self._write_queue.put(WriteTask(fn, task_id, reply_queue)) + self._write_queue.put(WriteTask(fn, task_id, reply_queue, isolated_connection)) if block: result = await reply_queue.async_q.get() if isinstance(result, Exception): @@ -202,12 +220,24 @@ def _execute_writes(self): if conn_exception is not None: result = conn_exception else: - try: - result = task.fn(conn) - except Exception as e: - sys.stderr.write("{}\n".format(e)) - sys.stderr.flush() - result = e + if task.isolated_connection: + isolated_connection = self.connect(write=True) + try: + result = task.fn(isolated_connection) + except Exception as e: + sys.stderr.write("{}\n".format(e)) + sys.stderr.flush() + result = e + finally: + isolated_connection.close() + self._all_file_connections.remove(isolated_connection) + else: + try: + result = task.fn(conn) + except Exception as e: + sys.stderr.write("{}\n".format(e)) + sys.stderr.flush() + result = e task.reply_queue.sync_q.put(result) async def execute_fn(self, fn): @@ -515,12 +545,13 @@ def __repr__(self): class WriteTask: - __slots__ = ("fn", "task_id", "reply_queue") + __slots__ = ("fn", "task_id", "reply_queue", "isolated_connection") - def __init__(self, fn, task_id, reply_queue): + def __init__(self, fn, task_id, reply_queue, isolated_connection): self.fn = fn self.task_id = task_id self.reply_queue = reply_queue + self.isolated_connection = isolated_connection class QueryInterrupted(Exception):