diff --git a/CHANGELOG.md b/CHANGELOG.md index b7c0104fa..014de0b65 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ - #710, #561 Implement `except*` syntax (@lieryan) - #711 allow building documentation without having rope module installed (@kloczek) +- #720 create one sqlite3.Connection per thread using a thread local (@tkrabel) # Release 1.10.0 diff --git a/rope/contrib/autoimport/sqlite.py b/rope/contrib/autoimport/sqlite.py index eb7c27ded..9cc9647b5 100644 --- a/rope/contrib/autoimport/sqlite.py +++ b/rope/contrib/autoimport/sqlite.py @@ -11,6 +11,7 @@ from datetime import datetime from itertools import chain from pathlib import Path +from threading import local from typing import Generator, Iterable, Iterator, List, Optional, Set, Tuple from rope.base import exceptions, libutils, resourceobserver, taskhandle, versioning @@ -78,9 +79,10 @@ class AutoImport: """ connection: sqlite3.Connection - underlined: bool + memory: bool project: Project project_package: Package + underlined: bool def __init__( self, @@ -114,8 +116,9 @@ def __init__( assert project_package.path is not None self.project_package = project_package self.underlined = underlined + self.memory = memory if memory is _deprecated_default: - memory = True + self.memory = True warnings.warn( "The default value for `AutoImport(memory)` argument will " "change to use an on-disk database by default in the future. " @@ -123,6 +126,7 @@ def __init__( "`AutoImport(memory=True)` explicitly.", DeprecationWarning, ) + self.thread_local = local() self.connection = self.create_database_connection( project=project, memory=memory, @@ -158,6 +162,24 @@ def create_database_connection( db_path = str(Path(project.ropefolder.real_path) / "autoimport.db") return sqlite3.connect(db_path) + @property + def connection(self): + """ + Creates a new connection if called from a new thread. + + This makes sure AutoImport can be shared across threads. + """ + if not hasattr(self.thread_local, "connection"): + self.thread_local.connection = self.create_database_connection( + project=self.project, + memory=self.memory, + ) + return self.thread_local.connection + + @connection.setter + def connection(self, value: sqlite3.Connection): + self.thread_local.connection = value + def _setup_db(self): models.Metadata.create_table(self.connection) version_hash = list( diff --git a/ropetest/conftest.py b/ropetest/conftest.py index d2efc68c6..32a35aaa6 100644 --- a/ropetest/conftest.py +++ b/ropetest/conftest.py @@ -18,6 +18,13 @@ def project_path(project): yield pathlib.Path(project.address) +@pytest.fixture +def project2(): + project = testutils.sample_project("sample_project2") + yield project + testutils.remove_project(project) + + """ Standard project structure for pytest fixtures /mod1.py -- mod1 diff --git a/ropetest/contrib/autoimport/autoimporttest.py b/ropetest/contrib/autoimport/autoimporttest.py index 4b0430119..d65d8b2bf 100644 --- a/ropetest/contrib/autoimport/autoimporttest.py +++ b/ropetest/contrib/autoimport/autoimporttest.py @@ -1,3 +1,4 @@ +from concurrent.futures import ThreadPoolExecutor from contextlib import closing, contextmanager from textwrap import dedent from unittest.mock import ANY, patch @@ -85,6 +86,37 @@ def foo(): assert [("from pkg1 import foo", "foo")] == results +def test_multithreading( + autoimport: AutoImport, + project: Project, + pkg1: Folder, + mod1: File, +): + mod1_init = pkg1.get_child("__init__.py") + mod1_init.write(dedent("""\ + def foo(): + pass + """)) + mod1.write(dedent("""\ + foo + """)) + autoimport = AutoImport(project, memory=False) + autoimport.generate_cache([mod1_init]) + + tp = ThreadPoolExecutor(1) + results = tp.submit(autoimport.search, "foo", True).result() + assert [("from pkg1 import foo", "foo")] == results + + +def test_connection(project: Project, project2: Project): + ai1 = AutoImport(project) + ai2 = AutoImport(project) + ai3 = AutoImport(project2) + + assert ai1.connection is not ai2.connection + assert ai1.connection is not ai3.connection + + @contextmanager def assert_database_is_reset(conn): conn.execute("ALTER TABLE names ADD COLUMN deprecated_column")