From af63584d3f422b382f549771bcde5fdba5a9da73 Mon Sep 17 00:00:00 2001 From: Jeff Zohrab Date: Sat, 2 Nov 2024 12:26:25 -0700 Subject: [PATCH] Pass session arg explicitly where needed. - datatables methods - lute/db data cleanup, demo, management - stats --- lute/app_factory.py | 20 ++++---- lute/book/datatables.py | 5 +- lute/book/routes.py | 10 ++-- lute/bookmarks/datatables.py | 5 +- lute/bookmarks/routes.py | 2 +- lute/db/data_cleanup.py | 13 +++-- lute/db/demo.py | 57 +++++++++++----------- lute/db/management.py | 9 ++-- lute/dev_api/routes.py | 12 ++--- lute/stats/routes.py | 5 +- lute/stats/service.py | 13 +++-- lute/term/datatables.py | 5 +- lute/term/routes.py | 6 +-- lute/termtag/datatables.py | 4 +- lute/termtag/routes.py | 2 +- lute/utils/formutils.py | 13 +++-- tests/conftest.py | 2 +- tests/unit/book/test_datatables.py | 12 ++--- tests/unit/db/test_demo.py | 32 ++++++------ tests/unit/db/test_management.py | 4 +- tests/unit/stats/test_service.py | 8 +-- tests/unit/term/test_datatables.py | 5 +- tests/unit/termtag/test_datatables.py | 3 +- tests/unit/textbookmark/test_datatables.py | 3 +- tests/unit/utils/test_formutils.py | 4 +- 25 files changed, 124 insertions(+), 130 deletions(-) diff --git a/lute/app_factory.py b/lute/app_factory.py index 892aca69..b3ee99ee 100644 --- a/lute/app_factory.py +++ b/lute/app_factory.py @@ -126,13 +126,15 @@ def inject_menu_bar_vars(): @app.route("/") def index(): - is_production = not lute.db.demo.contains_demo_data() + is_production = not lute.db.demo.contains_demo_data(db.session) bkp_settings = BackupSettings(db.session) have_books = len(db.session.query(Book).all()) > 0 have_languages = len(db.session.query(Language).all()) > 0 - language_choices = lute.utils.formutils.language_choices("(all languages)") - current_language_id = lute.utils.formutils.valid_current_language_id() + language_choices = lute.utils.formutils.language_choices( + db.session, "(all languages)" + ) + current_language_id = lute.utils.formutils.valid_current_language_id(db.session) bs = BackupService(db.session) should_run_auto_backup = bs.should_run_auto_backup(bkp_settings) @@ -154,7 +156,7 @@ def index(): hide_homelink=True, dbname=app_config.dbname, datapath=app_config.datapath, - tutorial_book_id=lute.db.demo.tutorial_book_id(), + tutorial_book_id=lute.db.demo.tutorial_book_id(db.session), have_books=have_books, have_languages=have_languages, language_choices=language_choices, @@ -176,8 +178,8 @@ def refresh_all_stats(): @app.route("/wipe_database") def wipe_db(): - if lute.db.demo.contains_demo_data(): - lute.db.demo.delete_demo_data() + if lute.db.demo.contains_demo_data(db.session): + lute.db.demo.delete_demo_data(db.session) msg = """ The database has been wiped clean. Have fun!

(Lute has automatically enabled backups -- @@ -188,8 +190,8 @@ def wipe_db(): @app.route("/remove_demo_flag") def remove_demo(): - if lute.db.demo.contains_demo_data(): - lute.db.demo.remove_flag() + if lute.db.demo.contains_demo_data(db.session): + lute.db.demo.remove_flag(db.session) msg = """ Demo mode deactivated. Have fun!

(Lute has automatically enabled backups -- @@ -310,7 +312,7 @@ def _pragmas_on_connect(dbapi_con, con_record): # pylint: disable=unused-argume db.create_all() load_settings(db.session, app_config.default_user_backup_path) # TODO valid parsers: do parser check, mark valid as active, invalid as inactive. - clean_data() + clean_data(db.session) app.db = db _add_base_routes(app, app_config) diff --git a/lute/book/datatables.py b/lute/book/datatables.py index a4b799f1..726ad1be 100644 --- a/lute/book/datatables.py +++ b/lute/book/datatables.py @@ -2,11 +2,10 @@ Show books in datatables. """ -from lute.db import db from lute.utils.data_tables import DataTablesSqliteQuery, supported_parser_type_criteria -def get_data_tables_list(parameters, is_archived): +def get_data_tables_list(parameters, is_archived, session): "Book json data for datatables." archived = "true" if is_archived else "false" @@ -70,7 +69,5 @@ def get_data_tables_list(parameters, is_archived): if language_id != 0: base_sql += f" and LgID = {language_id}" - session = db.session connection = session.connection() - return DataTablesSqliteQuery.get_data(base_sql, parameters, connection) diff --git a/lute/book/routes.py b/lute/book/routes.py index 62a69367..c2802ff5 100644 --- a/lute/book/routes.py +++ b/lute/book/routes.py @@ -46,7 +46,7 @@ def datatables_source(is_archived): # (currently unused) parameters = DataTablesFlaskParamParser.parse_params(request.form) _load_term_custom_filters(request.form, parameters) - data = get_data_tables_list(parameters, is_archived) + data = get_data_tables_list(parameters, is_archived, db.session) return jsonify(data) @@ -59,8 +59,10 @@ def datatables_active_source(): @bp.route("/archived", methods=["GET"]) def archived(): "List archived books." - language_choices = lute.utils.formutils.language_choices("(all languages)") - current_language_id = lute.utils.formutils.valid_current_language_id() + language_choices = lute.utils.formutils.language_choices( + db.session, "(all languages)" + ) + current_language_id = lute.utils.formutils.valid_current_language_id(db.session) return render_template( "book/index.html", @@ -108,7 +110,7 @@ def new(): b = _book_from_url(import_url) form = NewBookForm(obj=b) - form.language_id.choices = lute.utils.formutils.language_choices() + form.language_id.choices = lute.utils.formutils.language_choices(db.session) repo = Repository(db.session) if form.validate_on_submit(): diff --git a/lute/bookmarks/datatables.py b/lute/bookmarks/datatables.py index f70c8fcc..6dcdeac2 100644 --- a/lute/bookmarks/datatables.py +++ b/lute/bookmarks/datatables.py @@ -2,11 +2,10 @@ Show bookmarks in datatables. """ -from lute.db import db from lute.utils.data_tables import DataTablesSqliteQuery -def get_data_tables_list(parameters, book_id): +def get_data_tables_list(parameters, book_id, session): "Bookmark json data for datatables." base_sql = f""" @@ -16,7 +15,5 @@ def get_data_tables_list(parameters, book_id): WHERE tx.TxBkID = { book_id } """ - session = db.session connection = session.connection() - return DataTablesSqliteQuery.get_data(base_sql, parameters, connection) diff --git a/lute/bookmarks/routes.py b/lute/bookmarks/routes.py index 90691da0..3157afb1 100644 --- a/lute/bookmarks/routes.py +++ b/lute/bookmarks/routes.py @@ -15,7 +15,7 @@ def datatables_bookmarks(bookid): "Get datatables json for bookmarks." parameters = DataTablesFlaskParamParser.parse_params(request.form) - data = get_data_tables_list(parameters, bookid) + data = get_data_tables_list(parameters, bookid, db.session) return jsonify(data) diff --git a/lute/db/data_cleanup.py b/lute/db/data_cleanup.py index 2c7aef95..91f1c169 100644 --- a/lute/db/data_cleanup.py +++ b/lute/db/data_cleanup.py @@ -5,11 +5,10 @@ These cleanup routines will be called by the app_factory. """ -from lute.db import db from lute.models.book import Text -def _set_texts_word_count(): +def _set_texts_word_count(session): """ texts.TxWordCount should be set for all texts. @@ -18,7 +17,7 @@ def _set_texts_word_count(): Ref https://github.com/jzohrab/lute-v3/issues/95 """ - calc_counts = db.session.query(Text).filter(Text.word_count.is_(None)).all() + calc_counts = session.query(Text).filter(Text.word_count.is_(None)).all() # Don't recalc with invalid parsers!!!! recalc = [t for t in calc_counts if t.book.language.is_supported] @@ -31,10 +30,10 @@ def _set_texts_word_count(): pt = t.book.language.get_parsed_tokens(t.text) words = [w for w in pt if w.is_word] t.word_count = len(words) - db.session.add(t) - db.session.commit() + session.add(t) + session.commit() -def clean_data(): +def clean_data(session): "Clean all data as required." - _set_texts_word_count() + _set_texts_word_count(session) diff --git a/lute/db/demo.py b/lute/db/demo.py index cb530e52..6e65652c 100644 --- a/lute/db/demo.py +++ b/lute/db/demo.py @@ -13,7 +13,6 @@ from lute.book.model import Repository from lute.book.stats import Service as StatsService from lute.models.setting import SystemSettingRepository -from lute.db import db import lute.db.management @@ -39,100 +38,100 @@ def _demo_languages(): ] -def contains_demo_data(): +def contains_demo_data(session): """ True if IsDemoData setting is present. """ - repo = SystemSettingRepository(db.session) + repo = SystemSettingRepository(session) ss = repo.get_value("IsDemoData") if ss is None: return False return True -def remove_flag(): +def remove_flag(session): """ Remove IsDemoData setting. """ - if not contains_demo_data(): + if not contains_demo_data(session): raise RuntimeError("Can't delete non-demo data.") - repo = SystemSettingRepository(db.session) + repo = SystemSettingRepository(session) repo.delete_key("IsDemoData") - db.session.commit() + session.commit() -def tutorial_book_id(): +def tutorial_book_id(session): """ Return the book id of the tutorial. """ - if not contains_demo_data(): + if not contains_demo_data(session): return None sql = """select BkID from books inner join languages on LgID = BkLgID where LgName = 'English' and BkTitle = 'Tutorial' """ - r = db.session.execute(text(sql)).first() + r = session.execute(text(sql)).first() if r is None: return None return int(r[0]) -def delete_demo_data(): +def delete_demo_data(session): """ If this is a demo, wipe everything. """ - if not contains_demo_data(): + if not contains_demo_data(session): raise RuntimeError("Can't delete non-demo data.") - remove_flag() - lute.db.management.delete_all_data() + remove_flag(session) + lute.db.management.delete_all_data(session) # Loading demo data. -def load_demo_languages(): +def load_demo_languages(session): """ Load selected predefined languages. Assume everything is supported. This method will also be called during acceptance tests, so it's public. """ demo_langs = _demo_languages() - service = Service(db.session) + service = Service(session) langs = [service.get_language_def(langname)["language"] for langname in demo_langs] supported = [lang for lang in langs if lang.is_supported] for lang in supported: - db.session.add(lang) - db.session.commit() + session.add(lang) + session.commit() -def load_demo_stories(): +def load_demo_stories(session): "Load the stories." demo_langs = _demo_languages() - service = Service(db.session) + service = Service(session) langdefs = [service.get_language_def(langname) for langname in demo_langs] langdefs = [d for d in langdefs if d["language"].is_supported] - r = Repository(db.session) + r = Repository(session) for d in langdefs: for b in d["books"]: r.add(b) r.commit() - repo = SystemSettingRepository(db.session) + repo = SystemSettingRepository(session) repo.set_value("IsDemoData", True) - db.session.commit() + session.commit() - svc = StatsService(db.session) + svc = StatsService(session) svc.refresh_stats() -def load_demo_data(): +def load_demo_data(session): """ Load the data. """ - load_demo_languages() - load_demo_stories() - repo = SystemSettingRepository(db.session) + load_demo_languages(session) + load_demo_stories(session) + repo = SystemSettingRepository(session) repo.set_value("IsDemoData", True) - db.session.commit() + session.commit() diff --git a/lute/db/management.py b/lute/db/management.py index bed25058..6ed0c51a 100644 --- a/lute/db/management.py +++ b/lute/db/management.py @@ -4,11 +4,10 @@ from sqlalchemy import text from flask import current_app -from lute.db import db from lute.settings.current import load -def delete_all_data(): +def delete_all_data(session): """ DANGEROUS! Delete everything, restore user settings, clear sys settings. @@ -24,6 +23,6 @@ def delete_all_data(): "delete from settings", ] for s in statements: - db.session.execute(text(s)) - db.session.commit() - load(db.session, current_app.env_config.default_user_backup_path) + session.execute(text(s)) + session.commit() + load(session, current_app.env_config.default_user_backup_path) diff --git a/lute/dev_api/routes.py b/lute/dev_api/routes.py index 88e9f1b2..fdc3f5e3 100644 --- a/lute/dev_api/routes.py +++ b/lute/dev_api/routes.py @@ -38,7 +38,7 @@ def _ensure_is_test_db(): @bp.route("/wipe_db", methods=["GET"]) def wipe_db(): "Clean it all." - lute.db.management.delete_all_data() + lute.db.management.delete_all_data(db.session) flash("db wiped") return redirect("/", 302) @@ -46,8 +46,8 @@ def wipe_db(): @bp.route("/load_demo", methods=["GET"]) def load_demo(): "Clean out everything, and load the demo." - lute.db.management.delete_all_data() - lute.db.demo.load_demo_data() + lute.db.management.delete_all_data(db.session) + lute.db.demo.load_demo_data(db.session) flash("demo loaded") return redirect("/", 302) @@ -55,8 +55,8 @@ def load_demo(): @bp.route("/load_demo_languages", methods=["GET"]) def load_demo_languages(): "Clean out everything, and load the demo langs with dummy dictionaries." - lute.db.management.delete_all_data() - lute.db.demo.load_demo_languages() + lute.db.management.delete_all_data(db.session) + lute.db.demo.load_demo_languages(db.session) langs = db.session.query(Language).all() for lang in langs: d = lang.dictionaries[0] @@ -70,7 +70,7 @@ def load_demo_languages(): @bp.route("/load_demo_stories", methods=["GET"]) def load_demo_stories(): "Stories only. No db wipe." - lute.db.demo.load_demo_stories() + lute.db.demo.load_demo_stories(db.session) flash("stories loaded") return redirect("/", 302) diff --git a/lute/stats/routes.py b/lute/stats/routes.py index d2ecfe8f..25f577d0 100644 --- a/lute/stats/routes.py +++ b/lute/stats/routes.py @@ -4,6 +4,7 @@ from flask import Blueprint, render_template, jsonify from lute.stats.service import get_chart_data, get_table_data +from lute.db import db bp = Blueprint("stats", __name__, url_prefix="/stats") @@ -11,12 +12,12 @@ @bp.route("/") def index(): "Main page." - read_table_data = get_table_data() + read_table_data = get_table_data(db.session) return render_template("stats/index.html", read_table_data=read_table_data) @bp.route("/data") def get_data(): "Ajax call." - chartdata = get_chart_data() + chartdata = get_chart_data(db.session) return jsonify(chartdata) diff --git a/lute/stats/service.py b/lute/stats/service.py index 9613796b..146475f7 100644 --- a/lute/stats/service.py +++ b/lute/stats/service.py @@ -4,10 +4,9 @@ from datetime import datetime, timedelta from sqlalchemy import text -from lute.db import db -def _get_data_per_lang(): +def _get_data_per_lang(session): "Return dict of lang name to dict[date_yyyymmdd}: count" ret = {} sql = """ @@ -22,7 +21,7 @@ def _get_data_per_lang(): ) raw group by lang, dt """ - result = db.session.execute(text(sql)).all() + result = session.execute(text(sql)).all() for row in result: langname = row[0] if langname not in ret: @@ -53,9 +52,9 @@ def _charting_data(readbydate): return data -def get_chart_data(): +def get_chart_data(session): "Get data for chart for each language." - raw_data = _get_data_per_lang() + raw_data = _get_data_per_lang(session) chartdata = {} for k, v in raw_data.items(): chartdata[k] = _charting_data(v) @@ -90,9 +89,9 @@ def _in_range(i): } -def get_table_data(): +def get_table_data(session): "Wordcounts by lang in time intervals." - raw_data = _get_data_per_lang() + raw_data = _get_data_per_lang(session) ret = [] for langname, readbydate in raw_data.items(): diff --git a/lute/term/datatables.py b/lute/term/datatables.py index a07bfba6..36b82161 100644 --- a/lute/term/datatables.py +++ b/lute/term/datatables.py @@ -2,11 +2,10 @@ Show terms in datatables. """ -from lute.db import db from lute.utils.data_tables import DataTablesSqliteQuery, supported_parser_type_criteria -def get_data_tables_list(parameters): +def get_data_tables_list(parameters, session): "Term json data for datatables." base_sql = """SELECT @@ -93,5 +92,5 @@ def get_data_tables_list(parameters): # Phew. return DataTablesSqliteQuery.get_data( - base_sql + " WHERE " + " AND ".join(wheres), parameters, db.session.connection() + base_sql + " WHERE " + " AND ".join(wheres), parameters, session.connection() ) diff --git a/lute/term/routes.py b/lute/term/routes.py index 45063b97..db446dac 100644 --- a/lute/term/routes.py +++ b/lute/term/routes.py @@ -65,7 +65,7 @@ def datatables_active_source(): "Datatables data for terms." parameters = DataTablesFlaskParamParser.parse_params(request.form) _load_term_custom_filters(request.form, parameters) - data = get_data_tables_list(parameters) + data = get_data_tables_list(parameters, db.session) return jsonify(data) @@ -76,7 +76,7 @@ def export_terms(): _load_term_custom_filters(request.form, parameters) parameters["length"] = 1000000 outfile = os.path.join(current_app.env_config.temppath, "export_terms.csv") - data = get_data_tables_list(parameters) + data = get_data_tables_list(parameters, db.session) term_data = data["data"] # Term data is an array of dicts, with the sql field name as dict @@ -127,7 +127,7 @@ def handle_term_form( # The user opening the form is treated as an acknowledgement. term.flash_message = None - form.language_id.choices = lute.utils.formutils.language_choices() + form.language_id.choices = lute.utils.formutils.language_choices(session) if form.validate_on_submit(): form.populate_obj(term) diff --git a/lute/termtag/datatables.py b/lute/termtag/datatables.py index 9aba0b4f..353f20cc 100644 --- a/lute/termtag/datatables.py +++ b/lute/termtag/datatables.py @@ -2,11 +2,10 @@ Show terms in datatables. """ -from lute.db import db from lute.utils.data_tables import DataTablesSqliteQuery -def get_data_tables_list(parameters): +def get_data_tables_list(parameters, session): "json data for datatables." base_sql = """SELECT TgID, @@ -21,6 +20,5 @@ def get_data_tables_list(parameters): group by WtTgID ) src on src.WtTgID = TgID """ - session = db.session connection = session.connection() return DataTablesSqliteQuery.get_data(base_sql, parameters, connection) diff --git a/lute/termtag/routes.py b/lute/termtag/routes.py index caf04f97..6040bb4d 100644 --- a/lute/termtag/routes.py +++ b/lute/termtag/routes.py @@ -24,7 +24,7 @@ def index(search): def datatables_active_source(): "Datatables data for terms." parameters = DataTablesFlaskParamParser.parse_params(request.form) - data = get_data_tables_list(parameters) + data = get_data_tables_list(parameters, db.session) return jsonify(data) diff --git a/lute/utils/formutils.py b/lute/utils/formutils.py index 827ddf82..34002822 100644 --- a/lute/utils/formutils.py +++ b/lute/utils/formutils.py @@ -4,17 +4,16 @@ from lute.models.language import Language from lute.models.setting import UserSettingRepository -from lute.db import db -def language_choices(dummy_entry_placeholder="-"): +def language_choices(session, dummy_entry_placeholder="-"): """ Return the list of languages for select boxes. If only one lang exists, only return that, otherwise add a '-' dummy entry at the top. """ - langs = db.session.query(Language).order_by(Language.name).all() + langs = session.query(Language).order_by(Language.name).all() supported = [lang for lang in langs if lang.is_supported] lang_choices = [(s.id, s.name) for s in supported] # Add a dummy placeholder even if there are no languages. @@ -23,20 +22,20 @@ def language_choices(dummy_entry_placeholder="-"): return lang_choices -def valid_current_language_id(): +def valid_current_language_id(session): """ Get the current language id from UserSetting, ensuring it's still valid. If not, change it. """ - repo = UserSettingRepository(db.session) + repo = UserSettingRepository(session) current_language_id = repo.get_value("current_language_id") current_language_id = int(current_language_id) - valid_language_ids = [int(p[0]) for p in language_choices()] + valid_language_ids = [int(p[0]) for p in language_choices(session)] if current_language_id in valid_language_ids: return current_language_id current_language_id = valid_language_ids[0] repo.set_value("current_language_id", current_language_id) - db.session.commit() + session.commit() return current_language_id diff --git a/tests/conftest.py b/tests/conftest.py index dcb0574d..0b36e28e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -88,7 +88,7 @@ def fixture_empty_db(app_context): """ Wipe the db. """ - lute.db.management.delete_all_data() + lute.db.management.delete_all_data(db.session) @pytest.fixture(name="client") diff --git a/tests/unit/book/test_datatables.py b/tests/unit/book/test_datatables.py index 35bad76d..e93d07fd 100644 --- a/tests/unit/book/test_datatables.py +++ b/tests/unit/book/test_datatables.py @@ -35,8 +35,8 @@ def test_smoke_book_datatables_query_runs(app_context, _dt_params): """ Smoke test only, ensure query runs. """ - load_demo_stories() - get_data_tables_list(_dt_params, False) + load_demo_stories(db.session) + get_data_tables_list(_dt_params, False, db.session) # print(d['data']) a = 1 assert a == 1, "dummy check" @@ -46,12 +46,12 @@ def test_book_query_only_returns_supported_language_books(app_context, _dt_param """ Smoke test only, ensure query runs. """ - load_demo_stories() + load_demo_stories(db.session) for lang in db.session.query(Language).all(): lang.parser_type = "unknown" db.session.add(lang) db.session.commit() - d = get_data_tables_list(_dt_params, False) + d = get_data_tables_list(_dt_params, False, db.session) assert len(d["data"]) == 0, "no books should be active" @@ -63,7 +63,7 @@ def test_book_data_says_completed_if_last_page_has_been_read( db.session.add(b) db.session.commit() _dt_params["search"] = {"value": "title", "regex": False} - d = get_data_tables_list(_dt_params, False) + d = get_data_tables_list(_dt_params, False, db.session) actual = d["data"][0] assert actual["BkID"] == b.id, "correct book" assert actual["IsCompleted"] == 0, "not completed" @@ -71,7 +71,7 @@ def test_book_data_says_completed_if_last_page_has_been_read( t.read_date = datetime.now() db.session.add(t) db.session.commit() - d = get_data_tables_list(_dt_params, False) + d = get_data_tables_list(_dt_params, False, db.session) actual = d["data"][0] assert actual["BkID"] == b.id, "correct book" assert actual["IsCompleted"] == 1, "completed" diff --git a/tests/unit/db/test_demo.py b/tests/unit/db/test_demo.py index 45749823..f511fdce 100644 --- a/tests/unit/db/test_demo.py +++ b/tests/unit/db/test_demo.py @@ -18,44 +18,44 @@ def test_new_db_is_demo(app_context): "New db created from the baseline has the demo flag set." - assert contains_demo_data() is True, "new db contains demo." + assert contains_demo_data(db.session) is True, "new db contains demo." def test_removing_flag_means_not_demo(app_context): "Unsetting the flag means the db is not a demo." - remove_flag() - assert contains_demo_data() is False, "not a demo." + remove_flag(db.session) + assert contains_demo_data(db.session) is False, "not a demo." def test_wiping_db_clears_flag(app_context): "No longer a demo if the demo is wiped out!" - delete_demo_data() - assert contains_demo_data() is False, "not a demo." + delete_demo_data(db.session) + assert contains_demo_data(db.session) is False, "not a demo." def test_wipe_db_only_works_if_flag_is_set(app_context): "Can only wipe a demo db!" - remove_flag() + remove_flag(db.session) with pytest.raises(Exception): - delete_demo_data() + delete_demo_data(db.session) def test_tutorial_id_returned_if_present(app_context): "Sanity check." - assert tutorial_book_id() > 0, "have tutorial" + assert tutorial_book_id(db.session) > 0, "have tutorial" sql = 'update books set bktitle = "xxTutorial" where bktitle = "Tutorial"' db.session.execute(text(sql)) db.session.commit() - assert tutorial_book_id() is None, "no tutorial" + assert tutorial_book_id(db.session) is None, "no tutorial" sql = 'update books set bktitle = "Tutorial" where bktitle = "xxTutorial"' db.session.execute(text(sql)) db.session.commit() - assert tutorial_book_id() > 0, "have tutorial again" + assert tutorial_book_id(db.session) > 0, "have tutorial again" - delete_demo_data() - assert tutorial_book_id() is None, "no tutorial" + delete_demo_data(db.session) + assert tutorial_book_id(db.session) is None, "no tutorial" # Loading. @@ -69,12 +69,12 @@ def test_load_demo_loads_language_yaml_files(app_context): This test is also used from "inv db.reset" in tasks.py (see .pytest.ini). """ - delete_demo_data() - assert contains_demo_data() is False, "not a demo." + delete_demo_data(db.session) + assert contains_demo_data(db.session) is False, "not a demo." assert_record_count_equals("languages", 0, "wiped out") - load_demo_data() - assert contains_demo_data() is True, "demo loaded" + load_demo_data(db.session) + assert contains_demo_data(db.session) is True, "demo loaded" checks = [ "select * from languages where LgName = 'English'", "select * from books where BkTitle = 'Tutorial'", diff --git a/tests/unit/db/test_management.py b/tests/unit/db/test_management.py index d64cf857..12f4f736 100644 --- a/tests/unit/db/test_management.py +++ b/tests/unit/db/test_management.py @@ -20,7 +20,7 @@ def test_wiping_db_clears_out_all_tables(app_context): """ old_user_settings = db.session.query(UserSetting).all() - delete_all_data() + delete_all_data(db.session) tables = [ "books", "bookstats", @@ -47,7 +47,7 @@ def test_wiping_db_clears_out_all_tables(app_context): def test_can_get_backup_settings_when_db_is_wiped(app_context): "The backupsettings struct assumes certain things about the data." - delete_all_data() + delete_all_data(db.session) bs = BackupSettings(db.session) assert bs.backup_enabled, "backup is back to being enabled" assert bs.backup_dir is not None, "default restored" diff --git a/tests/unit/stats/test_service.py b/tests/unit/stats/test_service.py index ae0bf1c5..c306689f 100644 --- a/tests/unit/stats/test_service.py +++ b/tests/unit/stats/test_service.py @@ -48,7 +48,7 @@ def test_get_chart_data(spanish, english, app_context): {"readdate": today.strftime("%Y-%m-%d"), "wordcount": 2, "runningTotal": 2}, ], } - assert get_chart_data() == expected + assert get_chart_data(db.session) == expected def test_get_table_data(spanish, english, app_context): @@ -71,11 +71,11 @@ def test_get_table_data(spanish, english, app_context): "counts": {"day": 4, "week": 7, "month": 7, "year": 7, "total": 7}, }, ] - actual = get_table_data() + actual = get_table_data(db.session) assert actual == expected def test_get_data_works_when_nothing_read(app_context): "Nothing read should still be ok, empty chart." - assert not get_chart_data(), "nothing present" - assert not get_table_data(), "nothing" + assert not get_chart_data(db.session), "nothing present" + assert not get_table_data(db.session), "nothing" diff --git a/tests/unit/term/test_datatables.py b/tests/unit/term/test_datatables.py index 9836fd4b..ea8b07d5 100644 --- a/tests/unit/term/test_datatables.py +++ b/tests/unit/term/test_datatables.py @@ -4,6 +4,7 @@ import pytest from lute.term.datatables import get_data_tables_list +from lute.db import db @pytest.fixture(name="_dt_params") @@ -38,7 +39,7 @@ def test_smoke_term_datatables_query_runs(app_context, _dt_params): """ Smoke test only, ensure query runs. """ - get_data_tables_list(_dt_params) + get_data_tables_list(_dt_params, db.session) # print(d['data']) a = 1 assert a == 1, "dummy check" @@ -53,4 +54,4 @@ def test_smoke_query_with_filter_params_runs(app_context, _dt_params): _dt_params["filtStatusMin"] = "2" _dt_params["filtStatusMax"] = "4" _dt_params["filtIncludeIgnored"] = "true" - get_data_tables_list(_dt_params) + get_data_tables_list(_dt_params, db.session) diff --git a/tests/unit/termtag/test_datatables.py b/tests/unit/termtag/test_datatables.py index 521e1f03..03ec9d6f 100644 --- a/tests/unit/termtag/test_datatables.py +++ b/tests/unit/termtag/test_datatables.py @@ -3,6 +3,7 @@ """ from lute.termtag.datatables import get_data_tables_list +from lute.db import db def test_smoke_datatables_query_runs(app_context): @@ -22,7 +23,7 @@ def test_smoke_datatables_query_runs(app_context): "search": {"value": "", "regex": False}, } - d = get_data_tables_list(params) + d = get_data_tables_list(params, db.session) print(d) a = 1 assert a == 1, "dummy check" diff --git a/tests/unit/textbookmark/test_datatables.py b/tests/unit/textbookmark/test_datatables.py index 0e1655c4..c4d9650b 100644 --- a/tests/unit/textbookmark/test_datatables.py +++ b/tests/unit/textbookmark/test_datatables.py @@ -4,6 +4,7 @@ import pytest from lute.bookmarks.datatables import get_data_tables_list +from lute.db import db @pytest.fixture(name="_dt_params") @@ -28,5 +29,5 @@ def test_smoke_term_datatables_query_runs(app_context, _dt_params): """ Smoke test only, ensure query runs. """ - data = get_data_tables_list(_dt_params, 1) + data = get_data_tables_list(_dt_params, 1, db.session) assert data is not None diff --git a/tests/unit/utils/test_formutils.py b/tests/unit/utils/test_formutils.py index 5bf44876..8a0b4a89 100644 --- a/tests/unit/utils/test_formutils.py +++ b/tests/unit/utils/test_formutils.py @@ -9,7 +9,7 @@ def test_language_choices(app_context): "Gets all languages." - choices = language_choices() + choices = language_choices(db.session) assert choices[0][1] == "-", "- at the top" langnames = [c[1] for c in choices] assert "Spanish" in langnames, "have Spanish" @@ -20,6 +20,6 @@ def test_valid_current_language_id(app_context): repo = UserSettingRepository(db.session) repo.set_value("current_language_id", 9999) assert int(repo.get_value("current_language_id")) == 9999, "pre-check" - curr_lang_id = int(valid_current_language_id()) + curr_lang_id = int(valid_current_language_id(db.session)) assert curr_lang_id == 0, "set back to 0" assert int(repo.get_value("current_language_id")) == 0, "re-set to 0"