diff --git a/.github/scripts/build_docs.sh b/.github/scripts/build_docs.sh index 6a50c20..70540ed 100755 --- a/.github/scripts/build_docs.sh +++ b/.github/scripts/build_docs.sh @@ -4,30 +4,32 @@ # Rebuild Sphinx docs from scratch and check generated files match those checked # in -set -eux -o pipefail +set -euo pipefail + +THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +PROJECT_ROOT_DIR=${THIS_DIR}/../.. +DOCS_DIR=${PROJECT_ROOT_DIR}/docs +VENV_DIR=${HOME}/venv + +cd "${PROJECT_ROOT_DIR}" sudo apt-get install texlive-latex-extra dvipng -python -m venv "${HOME}/venv" -source "${HOME}/venv/bin/activate" -python -VV -python -m site -python -m pip install -U pip -echo installing pip packages -python -m pip install -r docs/docs_requirements.txt -python -m pip install -e . +PYTHON=${VENV_DIR}/bin/python +echo Installing pip packages for docs +${PYTHON} -m pip install -r "${DOCS_DIR}/docs_requirements.txt" ######################################################################################## -cd "${GITHUB_WORKSPACE}/docs" echo Creating autodocs -python ./create_all_autodocs.py --make --destroy_first -cd "${GITHUB_WORKSPACE}" +${PYTHON} "${DOCS_DIR}/create_all_autodocs.py" --make --destroy_first echo Checking if files generated by create_all_autodocs need to be checked in git diff git update-index --refresh git diff-index --quiet HEAD -- test -z "$(git ls-files --exclude-standard --others)" -cd docs echo Rebuilding docs -python ./rebuild_docs.py --warnings_as_errors + +# Have to be in the virtualenv for sphinx-build to be picked up. 
+source "${VENV_DIR}"/bin/activate +python "${DOCS_DIR}/rebuild_docs.py" --warnings_as_errors echo Checking if files generated by rebuild_docs need to be checked in git diff git update-index --refresh diff --git a/.github/scripts/create_virtualenv.sh b/.github/scripts/create_virtualenv.sh new file mode 100755 index 0000000..6d7a45e --- /dev/null +++ b/.github/scripts/create_virtualenv.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SYSTEM_PYTHON=python3 + +if [ $# -eq 1 ]; then + # Script needs at least Python 3.10 for docs. Allow this to be specified + SYSTEM_PYTHON=$1 +fi + +VENV_DIR=${HOME}/venv + +${SYSTEM_PYTHON} -m venv "${VENV_DIR}" +PYTHON=${VENV_DIR}/bin/python +${PYTHON} -VV +${PYTHON} -m site +${PYTHON} -m pip install -U pip setuptools +echo Dumping pre-installed packages +${PYTHON} -m pip freeze diff --git a/.github/scripts/install_base_python_packages.sh b/.github/scripts/install_base_python_packages.sh new file mode 100755 index 0000000..3a17cca --- /dev/null +++ b/.github/scripts/install_base_python_packages.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash + +set -euo pipefail + +THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +PROJECT_ROOT_DIR=${THIS_DIR}/../.. 
+VENV_DIR=${HOME}/venv + +PYTHON=${VENV_DIR}/bin/python +echo Installing pip packages +${PYTHON} -m pip install -e "${PROJECT_ROOT_DIR}" diff --git a/.github/scripts/install_test_python_packages.sh b/.github/scripts/install_test_python_packages.sh new file mode 100755 index 0000000..12eb372 --- /dev/null +++ b/.github/scripts/install_test_python_packages.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +set -euo pipefail + +VENV_DIR=${HOME}/venv + +PYTHON=${VENV_DIR}/bin/python +${PYTHON} -m pip install "numpy<1.23" # 1.23 incompatible with numba +${PYTHON} -m pip install xlrd +${PYTHON} -m pip install dogpile.cache==0.9.2 # Later versions incompatible +${PYTHON} -m pip install pytest +${PYTHON} -m pip install xhtml2pdf weasyprint pdfkit # For PDF tests diff --git a/.github/scripts/python_checks.sh b/.github/scripts/python_checks.sh index 0f13bf9..fc017e8 100755 --- a/.github/scripts/python_checks.sh +++ b/.github/scripts/python_checks.sh @@ -2,16 +2,19 @@ # Run from .github/workflows/python_checks.yml -set -eux -o pipefail +set -euo pipefail -python3 -m venv "${HOME}/venv" -source "${HOME}/venv/bin/activate" -python -m site -python -m pip install -U pip setuptools +THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +PROJECT_ROOT_DIR=${THIS_DIR}/../.. +VENV_DIR=${HOME}/venv +PYTHON=${VENV_DIR}/bin/python +PRECOMMIT=${VENV_DIR}/bin/pre-commit -echo installing pip packages -python -m pip install -e . -python -m pip install pre-commit +cd "${PROJECT_ROOT_DIR}" -echo running pre-commit checks -pre-commit run --all-files +echo Installing pre-commit +${PYTHON} -m pip install -e . 
+${PYTHON} -m pip install pre-commit + +echo Running pre-commit checks +${PRECOMMIT} run --all-files diff --git a/.github/scripts/run_tests.sh b/.github/scripts/run_tests.sh index f0bb8e1..cf16803 100755 --- a/.github/scripts/run_tests.sh +++ b/.github/scripts/run_tests.sh @@ -2,22 +2,16 @@ # Run from .github/workflows/tests.yml -set -eux -o pipefail +set -euo pipefail -sudo apt-get install wkhtmltopdf +THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +PROJECT_ROOT_DIR=${THIS_DIR}/../.. +VENV_DIR=${HOME}/venv +PYTEST=${VENV_DIR}/bin/pytest -python -m venv "${HOME}/venv" -source "${HOME}/venv/bin/activate" -python -m site -python -m pip install -U pip -echo installing pip packages +sudo apt-get install wkhtmltopdf -python -m pip install "numpy<1.23" # 1.23 incompatible with numba -python -m pip install xlrd -python -m pip install dogpile.cache==0.9.2 # Later versions incompatible -python -m pip install pytest -python -m pip install xhtml2pdf weasyprint pdfkit # For PDF tests -python -m pip install -e . 
+cd "${PROJECT_ROOT_DIR}" # pytest --log-cli-level=INFO -pytest +${PYTEST} diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index c2d2c23..a4a8415 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -6,8 +6,27 @@ on: push jobs: build-docs: - runs-on: ubuntu-latest + strategy: + matrix: + include: + - name: ubuntu-22.04 + os: ubuntu-22.04 + python-version: "3.10" + runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Create virtualenv + run: ${GITHUB_WORKSPACE}/.github/scripts/create_virtualenv.sh + - name: Install base Python packages + run: ${GITHUB_WORKSPACE}/.github/scripts/install_base_python_packages.sh - name: Build docs run: ${GITHUB_WORKSPACE}/.github/scripts/build_docs.sh + - name: Dump stuff on failure + if: failure() + run: | + set -euxo pipefail + ls -l ${HOME}/venv/bin + ${HOME}/venv/bin/python -m pip freeze diff --git a/.github/workflows/python_checks.yml b/.github/workflows/python_checks.yml index a7e9cf0..d3f8316 100644 --- a/.github/workflows/python_checks.yml +++ b/.github/workflows/python_checks.yml @@ -6,18 +6,24 @@ on: push: paths: - '**.py' - - .github/workflows/python-checks.yml + - .github/scripts/create_virtualenv.sh + - .github/scripts/install_base_python_packages.sh - .github/scripts/python_checks.sh + - .github/workflows/python-checks.yml jobs: python-checks: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.9", "3.10"] steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + - name: Create virtualenv + run: ${GITHUB_WORKSPACE}/.github/scripts/create_virtualenv.sh + - name: Install base Python packages + run: ${GITHUB_WORKSPACE}/.github/scripts/install_base_python_packages.sh 
- name: Python checks run: ${GITHUB_WORKSPACE}/.github/scripts/python_checks.sh diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 4d561e9..2cc89ef 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -1,28 +1,35 @@ --- # yamllint disable rule:line-length -name: Tests +name: Run tests # yamllint disable-line rule:truthy on: push: paths: - '**.py' - .github/scripts/change_apt_mirror.sh - - .github/workflows/run_tests.yml + - .github/scripts/create_virtualenv.sh + - .github/scripts/install_base_python_packages.sh + - .github/scripts/install_test_python_packages.sh - .github/scripts/run_tests.sh + - .github/workflows/run_tests.yml jobs: - pip-install-and-tests: + run-tests: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.9", "3.10"] steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Change apt mirror - run: | - set -euxo pipefail - ${GITHUB_WORKSPACE}/.github/scripts/change_apt_mirror.sh + run: ${GITHUB_WORKSPACE}/.github/scripts/change_apt_mirror.sh + - name: Create virtualenv + run: ${GITHUB_WORKSPACE}/.github/scripts/create_virtualenv.sh + - name: Install test Python packages + run: ${GITHUB_WORKSPACE}/.github/scripts/install_test_python_packages.sh + - name: Install base Python packages + run: ${GITHUB_WORKSPACE}/.github/scripts/install_base_python_packages.sh - name: Run tests run: ${GITHUB_WORKSPACE}/.github/scripts/run_tests.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fc00319..2c39bb5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,3 +18,11 @@ repos: rev: 5.0.4 hooks: - id: flake8 +- repo: https://github.com/asottile/yesqa + rev: v1.5.0 + hooks: + - id: yesqa +- repo: https://github.com/pre-commit/pygrep-hooks + rev: v1.9.0 + hooks: + - id: 
python-check-blanket-noqa diff --git a/cardinal_pythonlib/argparse_func.py b/cardinal_pythonlib/argparse_func.py index cbc9f0c..e1d1dd1 100644 --- a/cardinal_pythonlib/argparse_func.py +++ b/cardinal_pythonlib/argparse_func.py @@ -55,7 +55,7 @@ class ShowAllSubparserHelpAction(_HelpAction): shows help for all subparsers. As per https://stackoverflow.com/questions/20094215/argparse-subparser-monolithic-help-output - """ # noqa: E501 + """ def __call__( self, @@ -147,7 +147,7 @@ def str2bool(v: str) -> bool: default=NICE, # if the argument is entirely absent help="Activate nice mode.") - """ # noqa: E501 + """ lv = v.lower() if lv in ("yes", "true", "t", "y", "1"): return True diff --git a/cardinal_pythonlib/athena_ohdsi.py b/cardinal_pythonlib/athena_ohdsi.py index 4906ae5..8480213 100644 --- a/cardinal_pythonlib/athena_ohdsi.py +++ b/cardinal_pythonlib/athena_ohdsi.py @@ -366,7 +366,7 @@ def get_athena_concepts( timeit.timeit(concept_testcode, number=1, globals=globals()) # After speedup: 3.9 s for 1.1m rows. - """ # noqa + """ # noqa: E501 assert bool(tsv_filename) != bool( cached_concepts ), "Specify either tsv_filename or cached_concepts" @@ -461,7 +461,7 @@ def get_athena_concept_relationships( tsv_filename: str = "", cached_concept_relationships: Iterable[ AthenaConceptRelationshipRow - ] = None, # noqa + ] = None, concept_id_1_values: Collection[int] = None, concept_id_2_values: Collection[int] = None, relationship_id_values: Collection[str] = None, diff --git a/cardinal_pythonlib/betweendict.py b/cardinal_pythonlib/betweendict.py index b7e762e..ad97f18 100644 --- a/cardinal_pythonlib/betweendict.py +++ b/cardinal_pythonlib/betweendict.py @@ -72,7 +72,7 @@ class BetweenDict(dict): ... 
NB has initialization default argument bug - https://pypi.python.org/pypi/rangedict/0.1.5 - https://stackoverflow.com/questions/30254739/is-there-a-library-implemented-rangedict-in-python - """ # noqa + """ # noqa: E501 INVALID_MSG_TYPE = "Key must be an iterable with length 2" INVALID_MSG_VALUE = "First element of key must be less than second element" diff --git a/cardinal_pythonlib/bulk_email/constants.py b/cardinal_pythonlib/bulk_email/constants.py index 94a100a..cdd8235 100644 --- a/cardinal_pythonlib/bulk_email/constants.py +++ b/cardinal_pythonlib/bulk_email/constants.py @@ -31,7 +31,7 @@ CONTENT_TYPE_MAX_LENGTH = 255 # Can be quite long; see cardinal_pythonlib.httpconst.MimeType # 255 is the formal limit: -# https://stackoverflow.com/questions/643690/maximum-mimetype-length-when-storing-type-in-db # noqa +# https://stackoverflow.com/questions/643690/maximum-mimetype-length-when-storing-type-in-db # noqa: E501 DEFAULT_TIME_BETWEEN_EMAILS_S = 0.5 diff --git a/cardinal_pythonlib/bulk_email/main.py b/cardinal_pythonlib/bulk_email/main.py index f4a1699..7757574 100644 --- a/cardinal_pythonlib/bulk_email/main.py +++ b/cardinal_pythonlib/bulk_email/main.py @@ -531,9 +531,9 @@ def main() -> None: f"via the environment variable {DB_URL_ENVVAR}" ) sys.exit(EXIT_FAILURE) - engine = create_engine(db_url, echo=args.echo) + engine = create_engine(db_url, echo=args.echo, future=True) log.info(f"Using database: {get_safe_url_from_engine(engine)}") - session = Session(engine) + session = Session(engine, future=True) # ------------------------------------------------------------------------- # Launch subcommand diff --git a/cardinal_pythonlib/bulk_email/models.py b/cardinal_pythonlib/bulk_email/models.py index 51f38ce..c6d210f 100644 --- a/cardinal_pythonlib/bulk_email/models.py +++ b/cardinal_pythonlib/bulk_email/models.py @@ -113,7 +113,7 @@ def make_table_args(*args, **kwargs) -> Tuple[Any]: # noinspection PyUnusedLocal @listens_for(Session, "before_flush") def 
before_flush(session: Session, flush_context, instances) -> None: - # https://docs.sqlalchemy.org/en/14/orm/events.html#sqlalchemy.orm.SessionEvents.before_flush # noqa + # https://docs.sqlalchemy.org/en/14/orm/events.html#sqlalchemy.orm.SessionEvents.before_flush # noqa: E501 for instance in session.dirty: if not isinstance(instance, Config): continue @@ -136,7 +136,7 @@ class Config(Base): - https://docs.sqlalchemy.org/en/14/_modules/examples/versioned_rows/versioned_rows.html - https://docs.sqlalchemy.org/en/14/orm/examples.html#module-examples.versioned_rows - """ # noqa + """ # noqa: E501 __tablename__ = "config" __table_args__ = make_table_args(comment="Stores configuration records.") @@ -148,9 +148,9 @@ class Config(Base): config_id = Column( Integer, # ... may not be supported by all databases; see - # https://docs.sqlalchemy.org/en/14/core/constraints.html#check-constraint # noqa + # https://docs.sqlalchemy.org/en/14/core/constraints.html#check-constraint # noqa: E501 # ... though MySQL has recently added this: - # https://dev.mysql.com/doc/refman/8.0/en/create-table-check-constraints.html # noqa + # https://dev.mysql.com/doc/refman/8.0/en/create-table-check-constraints.html # noqa: E501 primary_key=True, autoincrement=True, comment="Primary key.", @@ -457,7 +457,7 @@ def _pending_job_query(session: Session) -> CountStarSpecializedQuery: .where( and_( SendAttempt.job_id == Job.job_id, - SendAttempt.success == True, # noqa + SendAttempt.success == True, # noqa: E712 ) ) ) @@ -474,7 +474,7 @@ def n_completed_jobs(session: Session) -> int: query = ( CountStarSpecializedQuery([Job], session=session) .join(SendAttempt) - .filter(SendAttempt.success == True) # noqa + .filter(SendAttempt.success == True) # noqa: E712 ) return query.count_star() diff --git a/cardinal_pythonlib/chebi.py b/cardinal_pythonlib/chebi.py index 892124f..764c96c 100644 --- a/cardinal_pythonlib/chebi.py +++ b/cardinal_pythonlib/chebi.py @@ -105,7 +105,7 @@ agomelatine, antidepressant -""" 
# noqa +""" # noqa: E501 import argparse import csv @@ -911,7 +911,7 @@ def add_entities(p: argparse.ArgumentParser) -> None: add_entities(parser_describe) add_exact(parser_describe) parser_describe.set_defaults( - func=lambda args: search_and_describe_multiple( # noqa + func=lambda args: search_and_describe_multiple( search_terms=args.entity, exact_search=args.exact_search, exact_match=args.exact_match, diff --git a/cardinal_pythonlib/classes.py b/cardinal_pythonlib/classes.py index e5bc1ad..61b006b 100644 --- a/cardinal_pythonlib/classes.py +++ b/cardinal_pythonlib/classes.py @@ -105,7 +105,7 @@ def derived_class_implements_method( # ============================================================================= # Subclasses # ============================================================================= -# https://stackoverflow.com/questions/3862310/how-can-i-find-all-subclasses-of-a-class-given-its-name # noqa +# https://stackoverflow.com/questions/3862310/how-can-i-find-all-subclasses-of-a-class-given-its-name # noqa: E501 def gen_all_subclasses(cls: Type) -> Generator[Type, None, None]: diff --git a/cardinal_pythonlib/cmdline.py b/cardinal_pythonlib/cmdline.py index 2de729e..e62b1fa 100644 --- a/cardinal_pythonlib/cmdline.py +++ b/cardinal_pythonlib/cmdline.py @@ -48,7 +48,7 @@ def cmdline_split(s: str, platform: Union[int, str] = "this") -> List[str]: - ``1`` = POSIX; - ``0`` = Windows/CMD - (other values reserved) - """ # noqa: E501 + """ if platform == "this": platform = sys.platform != "win32" # RNC: includes 64-bit Windows @@ -78,7 +78,7 @@ def cmdline_split(s: str, platform: Union[int, str] = "this") -> List[str]: elif qs: word = qs.replace(r"\"", '"').replace(r"\\", "\\") # ... 
raw strings can't end in single backslashes; - # https://stackoverflow.com/questions/647769/why-cant-pythons-raw-string-literals-end-with-a-single-backslash # noqa + # https://stackoverflow.com/questions/647769/why-cant-pythons-raw-string-literals-end-with-a-single-backslash # noqa: E501 if platform == 0: word = word.replace('""', '"') else: diff --git a/cardinal_pythonlib/compression.py b/cardinal_pythonlib/compression.py index 05ffac4..bbd4a77 100644 --- a/cardinal_pythonlib/compression.py +++ b/cardinal_pythonlib/compression.py @@ -88,7 +88,7 @@ def gzip_string(text: str, encoding: str = "utf-8") -> bytes: print(gz1 == gz2) # False # ... but the difference is probably in the timestamp bytes! - """ # noqa: E501 + """ data = text.encode(encoding) return gzip.compress(data) diff --git a/cardinal_pythonlib/datetimefunc.py b/cardinal_pythonlib/datetimefunc.py index 134845e..2e5db36 100644 --- a/cardinal_pythonlib/datetimefunc.py +++ b/cardinal_pythonlib/datetimefunc.py @@ -307,7 +307,7 @@ def strfdelta( Modified from https://stackoverflow.com/questions/538666/python-format-timedelta-to-string - """ # noqa + """ # Convert tdelta to integer seconds. if inputtype == "timedelta": @@ -639,7 +639,7 @@ def duration_to_iso( realistic (negative, 1000 years, 11 months, and the maximum length for seconds/microseconds). - """ # noqa + """ prefix = "" negative = d < Duration() if negative and minus_sign_at_front: diff --git a/cardinal_pythonlib/dicts.py b/cardinal_pythonlib/dicts.py index 02816f2..3beeefb 100644 --- a/cardinal_pythonlib/dicts.py +++ b/cardinal_pythonlib/dicts.py @@ -120,7 +120,7 @@ def rename_keys_in_dict(d: Dict[str, Any], renames: Dict[str, str]) -> None: See https://stackoverflow.com/questions/4406501/change-the-name-of-a-key-in-dictionary. 
- """ # noqa + """ for old_key, new_key in renames.items(): if new_key == old_key: continue @@ -257,7 +257,7 @@ class LazyDict(dict): The ``*args``/``**kwargs`` parts are useful, but we don't want to have to name 'thunk' explicitly. - """ # noqa + """ def get( self, key: Hashable, thunk: Any = None, *args: Any, **kwargs: Any @@ -292,7 +292,7 @@ class LazyButHonestDict(dict): Compared to the StackOverflow version: no obvious need to have a default returning ``None``, when we're implementing this as a special function. In contrast, helpful to have ``*args``/``**kwargs`` options. - """ # noqa + """ def lazyget( self, key: Hashable, thunk: Callable, *args: Any, **kwargs: Any @@ -365,7 +365,7 @@ class CaseInsensitiveDict(dict): d1.update([('K', 11), ('L', 12)]) d1 # {'e': 5, 'f': 6, 'g': 7, 'j': 10, 'k': 11, 'l': 12} - """ # noqa + """ # noqa: E501 @classmethod def _k(cls, key: Any) -> Any: diff --git a/cardinal_pythonlib/django/admin.py b/cardinal_pythonlib/django/admin.py index c21db94..88fd6d2 100644 --- a/cardinal_pythonlib/django/admin.py +++ b/cardinal_pythonlib/django/admin.py @@ -41,7 +41,7 @@ # ============================================================================= # Disable boolean icons for a ModelAdmin field # ============================================================================= -# https://stackoverflow.com/questions/13990846/disable-on-off-icon-for-boolean-field-in-django # noqa +# https://stackoverflow.com/questions/13990846/disable-on-off-icon-for-boolean-field-in-django # noqa: E501 # ... 
extended to use closures @@ -114,7 +114,7 @@ def admin_view_fk_link( app_name = linked_obj._meta.app_label.lower() model_name = linked_obj._meta.object_name.lower() viewname = f"admin:{app_name}_{model_name}_{view_type}" - # https://docs.djangoproject.com/en/dev/ref/contrib/admin/#reversing-admin-urls # noqa + # https://docs.djangoproject.com/en/dev/ref/contrib/admin/#reversing-admin-urls # noqa: E501 if current_app is None: current_app = modeladmin.admin_site.name # ... plus a bit of home-grown magic; see Django source diff --git a/cardinal_pythonlib/django/django_constants.py b/cardinal_pythonlib/django/django_constants.py index 8a15c82..b0a5002 100644 --- a/cardinal_pythonlib/django/django_constants.py +++ b/cardinal_pythonlib/django/django_constants.py @@ -36,11 +36,11 @@ class ConnectionVendors(object): ORACLE = "oracle" # built in; [1] POSTGRESQL = "postgresql" # built in; [1] SQLITE = "sqlite" # built in; [1] - # [1] https://docs.djangoproject.com/en/1.10/howto/custom-lookups/#writing-alternative-implementations-for-existing-lookups # noqa + # [1] https://docs.djangoproject.com/en/1.10/howto/custom-lookups/#writing-alternative-implementations-for-existing-lookups # noqa: E501 # I think this is HYPOTHETICAL: SQLSERVER = 'sqlserver' # [2] # [2] https://docs.djangoproject.com/en/1.11/ref/models/expressions/ MICROSOFT = "microsoft" # [3] # [3] "pip install django-mssql" = sqlserver_ado; - # https://bitbucket.org/Manfre/django-mssql/src/d44721ba17acf95da89f06bd7270dabc1cd33deb/sqlserver_ado/base.py?at=master&fileviewer=file-view-default # noqa + # https://bitbucket.org/Manfre/django-mssql/src/d44721ba17acf95da89f06bd7270dabc1cd33deb/sqlserver_ado/base.py?at=master&fileviewer=file-view-default # noqa: E501 diff --git a/cardinal_pythonlib/django/fields/isodatetimetz.py b/cardinal_pythonlib/django/fields/isodatetimetz.py index 7c17bba..fbb2b65 100644 --- a/cardinal_pythonlib/django/fields/isodatetimetz.py +++ b/cardinal_pythonlib/django/fields/isodatetimetz.py @@ 
-185,9 +185,9 @@ class IsoDateTimeTzField(models.CharField): https://docs.djangoproject.com/en/1.8/ref/databases/#fractional-seconds-support-for-time-and-datetime-fields - """ # noqa + """ # noqa: E501 - # https://docs.djangoproject.com/en/1.8/ref/models/fields/#field-api-reference # noqa + # https://docs.djangoproject.com/en/1.8/ref/models/fields/#field-api-reference # noqa: E501 description = "ISO-8601 date/time field with timezone, stored as text" diff --git a/cardinal_pythonlib/django/fields/jsonclassfield.py b/cardinal_pythonlib/django/fields/jsonclassfield.py index 062d99f..43fbaf5 100644 --- a/cardinal_pythonlib/django/fields/jsonclassfield.py +++ b/cardinal_pythonlib/django/fields/jsonclassfield.py @@ -122,7 +122,7 @@ def my_decoder_hook(d: Dict) -> Any: print(repr(x)) print(repr(x2)) -""" # noqa +""" # noqa: E501 # noinspection PyUnresolvedReferences from django.core.exceptions import ValidationError diff --git a/cardinal_pythonlib/django/forms.py b/cardinal_pythonlib/django/forms.py index 196b524..4a66972 100644 --- a/cardinal_pythonlib/django/forms.py +++ b/cardinal_pythonlib/django/forms.py @@ -69,7 +69,7 @@ class MultipleIntAreaField(forms.Field): Django ``forms.Field`` to capture multiple integers. 
""" - # See also https://stackoverflow.com/questions/29303902/django-form-with-list-of-integers # noqa + # See also https://stackoverflow.com/questions/29303902/django-form-with-list-of-integers # noqa: E501 widget = forms.Textarea def clean(self, value) -> List[int]: diff --git a/cardinal_pythonlib/django/function_cache.py b/cardinal_pythonlib/django/function_cache.py index a994ecb..11127b1 100644 --- a/cardinal_pythonlib/django/function_cache.py +++ b/cardinal_pythonlib/django/function_cache.py @@ -135,7 +135,7 @@ def decorator(fn): def wrapper(*args, **kwargs): # - NOTE that Django returns None from cache.get() for "not in # cache", so can't cache a None value; - # https://docs.djangoproject.com/en/1.10/topics/cache/#basic-usage # noqa + # https://docs.djangoproject.com/en/1.10/topics/cache/#basic-usage # noqa: E501 # - We need to store a bit more than just the function result # anyway, to detect hash collisions when the user doesn't specify # the cache_key, so we may as well use that format even if the diff --git a/cardinal_pythonlib/django/middleware.py b/cardinal_pythonlib/django/middleware.py index 4820608..b201a08 100644 --- a/cardinal_pythonlib/django/middleware.py +++ b/cardinal_pythonlib/django/middleware.py @@ -234,7 +234,7 @@ def process_view( # ============================================================================= # DisableClientSideCachingMiddleware # ============================================================================= -# https://stackoverflow.com/questions/2095520/fighting-client-side-caching-in-django # noqa +# https://stackoverflow.com/questions/2095520/fighting-client-side-caching-in-django # noqa: E501 class DisableClientSideCachingMiddleware(MiddlewareMixin): diff --git a/cardinal_pythonlib/django/serve.py b/cardinal_pythonlib/django/serve.py index 7988b81..c368cf3 100644 --- a/cardinal_pythonlib/django/serve.py +++ b/cardinal_pythonlib/django/serve.py @@ -157,8 +157,8 @@ def serve_file( HTTP content type to use as default, if 
``content_type`` is ``None`` """ - # https://stackoverflow.com/questions/1156246/having-django-serve-downloadable-files # noqa - # https://docs.djangoproject.com/en/dev/ref/request-response/#telling-the-browser-to-treat-the-response-as-a-file-attachment # noqa + # https://stackoverflow.com/questions/1156246/having-django-serve-downloadable-files # noqa: E501 + # https://docs.djangoproject.com/en/dev/ref/request-response/#telling-the-browser-to-treat-the-response-as-a-file-attachment # noqa: E501 # https://djangosnippets.org/snippets/365/ if offered_filename is None: offered_filename = os.path.basename(path_to_file) or "" @@ -182,7 +182,7 @@ def serve_file( # Note for debugging: Chrome may request a file more than once (e.g. with a # GET request that's then marked 'canceled' in the Network tab of the # developer console); this is normal: - # https://stackoverflow.com/questions/4460661/what-to-do-with-chrome-sending-extra-requests # noqa + # https://stackoverflow.com/questions/4460661/what-to-do-with-chrome-sending-extra-requests # noqa: E501 def serve_buffer( diff --git a/cardinal_pythonlib/docker.py b/cardinal_pythonlib/docker.py index e637100..5d58640 100644 --- a/cardinal_pythonlib/docker.py +++ b/cardinal_pythonlib/docker.py @@ -35,7 +35,7 @@ def running_under_docker() -> bool: As per https://stackoverflow.com/questions/43878953/how-does-one-detect-if-one-is-running-within-a-docker-container-within-python ... but without leaving a file open. - """ # noqa: E501 + """ # 1. Does /.dockerenv exist? if os.path.exists("/.dockerenv"): return True diff --git a/cardinal_pythonlib/dogpile_cache.py b/cardinal_pythonlib/dogpile_cache.py index 6f7cd1d..0f8248d 100644 --- a/cardinal_pythonlib/dogpile_cache.py +++ b/cardinal_pythonlib/dogpile_cache.py @@ -128,7 +128,7 @@ def get_namespace(fn: Callable, namespace: Optional[str]) -> str: normally a ``str``; if not ``None``, ``str(namespace)`` will be added to the result. 
See https://dogpilecache.readthedocs.io/en/latest/api.html#dogpile.cache.region.CacheRegion.cache_on_arguments - """ # noqa: E501 + """ # See hidden attributes with dir(fn) # noinspection PyUnresolvedReferences return "{module}:{name}{extra}".format( diff --git a/cardinal_pythonlib/dsp.py b/cardinal_pythonlib/dsp.py index 0047ecd..f017d4d 100644 --- a/cardinal_pythonlib/dsp.py +++ b/cardinal_pythonlib/dsp.py @@ -65,7 +65,7 @@ def normalized_frequency(f: float, sampling_freq: float) -> float: - e.g. see https://en.wikipedia.org/wiki/Nyquist_frequency, https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirnotch.html - """ # noqa: E501 + """ return f / (sampling_freq / 2.0) diff --git a/cardinal_pythonlib/email/sendmail.py b/cardinal_pythonlib/email/sendmail.py index d90c73c..a286fb8 100755 --- a/cardinal_pythonlib/email/sendmail.py +++ b/cardinal_pythonlib/email/sendmail.py @@ -179,7 +179,7 @@ def _assert_nocomma(x: Union[str, List[str]]) -> None: attachment_binaries = attachment_binaries or [] # type: List[bytes] attachment_binary_filenames = ( attachment_binary_filenames or [] - ) # type: List[str] # noqa + ) # type: List[str] assert len(attachment_binaries) == len(attachment_binary_filenames), ( "If you specify attachment_binaries or attachment_binary_filenames, " "they must be iterables of the same length." @@ -462,7 +462,7 @@ def send_email( https://docs.djangoproject.com/en/2.1/ref/settings/#email-use-ssl). We don't support that here. - """ # noqa + """ if isinstance(to, str): to = [to] if isinstance(cc, str): @@ -533,7 +533,7 @@ def is_email_valid(email_: str) -> bool: See https://stackoverflow.com/questions/8022530/how-to-check-for-valid-email-address. - """ # noqa + """ # Very basic checks! 
if not email_: return False diff --git a/cardinal_pythonlib/enumlike.py b/cardinal_pythonlib/enumlike.py index 80078f9..2b3bdff 100644 --- a/cardinal_pythonlib/enumlike.py +++ b/cardinal_pythonlib/enumlike.py @@ -321,7 +321,7 @@ class Animal(AutoStrEnum): https://stackoverflow.com/questions/32214614/automatically-setting-an-enum-members-value-to-its-name/32215467 and then inherit from :class:`StrEnum` rather than :class:`Enum`. - """ # noqa: E501 + """ pass @@ -437,7 +437,7 @@ def __new__(mcs, name, bases, classdict): # mcs: was cls '__module__': 0 } ) - """ # noqa + """ # noqa: E501 # print("__new__: name={}, bases={}, classdict={}".format( # repr(name), repr(bases), repr(classdict))) cls = type.__new__(mcs, name, bases, dict(classdict)) @@ -689,7 +689,7 @@ class TestEnum(Enum, metaclass=CaseInsensitiveEnumMeta): TestEnum["PineApple"] # TestEnum["PineApplE"] # - """ # noqa + """ def __getitem__(self, item: Any) -> Any: if isinstance(item, str): diff --git a/cardinal_pythonlib/exceptions.py b/cardinal_pythonlib/exceptions.py index 47cb539..856107e 100644 --- a/cardinal_pythonlib/exceptions.py +++ b/cardinal_pythonlib/exceptions.py @@ -50,7 +50,7 @@ def add_info_to_exception(err: Exception, info: Dict) -> None: Args: err: the exception to be modified info: the information to add - """ # noqa: E501 + """ if not err.args: err.args = ("",) err.args += (info,) @@ -106,7 +106,7 @@ def fail(): echo $? # show exit code - """ # noqa: E501 + """ if exc: lines = traceback.format_exception( None, exc, exc.__traceback__ # etype: ignored diff --git a/cardinal_pythonlib/extract_text.py b/cardinal_pythonlib/extract_text.py index 854bcea..8cf4f78 100755 --- a/cardinal_pythonlib/extract_text.py +++ b/cardinal_pythonlib/extract_text.py @@ -613,7 +613,7 @@ def availability_pdf() -> bool: # ----------------------------------------------------------------------------- # In a D.I.Y. 
fashion # ----------------------------------------------------------------------------- -# DOCX specification: http://www.ecma-international.org/news/TC45_current_work/TC45_available_docs.htm # noqa +# DOCX specification: http://www.ecma-international.org/news/TC45_current_work/TC45_available_docs.htm # noqa: E501 DOCX_HEADER_FILE_REGEX = re.compile("word/header[0-9]*.xml") DOCX_DOC_FILE = "word/document.xml" @@ -630,7 +630,7 @@ def docx_qn(tagroot): DOCX_TEXT = docx_qn("t") DOCX_TABLE = docx_qn( "tbl" -) # https://github.com/python-openxml/python-docx/blob/master/docx/table.py # noqa +) # https://github.com/python-openxml/python-docx/blob/master/docx/table.py DOCX_TAB = docx_qn("tab") DOCX_NEWLINES = [docx_qn("br"), docx_qn("cr")] DOCX_NEWPARA = docx_qn("p") @@ -1315,7 +1315,7 @@ def convert_rtf_to_text( else: return get_cmd_output_from_stdin(blob, *args) elif pyth: # Very memory-consuming: - # https://github.com/brendonh/pyth/blob/master/pyth/plugins/rtf15/reader.py # noqa + # https://github.com/brendonh/pyth/blob/master/pyth/plugins/rtf15/reader.py # noqa: E501 with get_filelikeobject(filename, blob) as fp: doc = pyth.plugins.rtf15.reader.Rtf15Reader.read(fp) return pyth.plugins.plaintext.writer.PlaintextWriter.write( diff --git a/cardinal_pythonlib/file_io.py b/cardinal_pythonlib/file_io.py index e10bb52..ea077b7 100644 --- a/cardinal_pythonlib/file_io.py +++ b/cardinal_pythonlib/file_io.py @@ -123,7 +123,7 @@ def writelines_nl(fileobj: TextIO, lines: Iterable[str]) -> None: (Since :func:`fileobj.writelines` doesn't add newlines... 
https://stackoverflow.com/questions/13730107/writelines-writes-lines-without-newline-just-fills-the-file) - """ # noqa: E501 + """ fileobj.write("\n".join(lines) + "\n") @@ -365,7 +365,7 @@ def gen_part_from_iterables( """ # RST: make part of word bold/italic: - # https://stackoverflow.com/questions/12771480/part-of-a-word-bold-in-restructuredtext # noqa + # https://stackoverflow.com/questions/12771480/part-of-a-word-bold-in-restructuredtext # noqa: E501 for iterable in iterables: yield iterable[part_index] diff --git a/cardinal_pythonlib/fileops.py b/cardinal_pythonlib/fileops.py index 5a37a14..eed108c 100644 --- a/cardinal_pythonlib/fileops.py +++ b/cardinal_pythonlib/fileops.py @@ -353,7 +353,7 @@ def shutil_rmtree_onerror( See https://stackoverflow.com/questions/2656322/shutil-rmtree-fails-on-windows-with-access-is-denied - """ # noqa + """ if not os.access(path, os.W_OK): # Is the error an access error ? os.chmod(path, stat.S_IWUSR) @@ -539,7 +539,7 @@ def get_directory_contents_size(directory: str = ".") -> int: Returns: int: size in bytes - """ # noqa + """ total_size = 0 for dirpath, dirnames, filenames in os.walk(directory): for f in filenames: diff --git a/cardinal_pythonlib/hash.py b/cardinal_pythonlib/hash.py index 47b7b37..49e373b 100644 --- a/cardinal_pythonlib/hash.py +++ b/cardinal_pythonlib/hash.py @@ -414,7 +414,7 @@ def murmur3_x86_32(data: Union[bytes, bytearray], seed: int = 0) -> int: Returns: integer hash - """ # noqa + """ c1 = 0xCC9E2D51 c2 = 0x1B873593 @@ -481,7 +481,7 @@ def murmur3_64(data: Union[bytes, bytearray], seed: int = 19820125) -> int: Returns: integer hash - """ # noqa + """ m = 0xC6A4A7935BD1E995 r = 47 @@ -1070,7 +1070,7 @@ def main() -> None: print(twos_comp_to_signed(2 ** 32 - 1, n_bits=32)) # -1 print(signed_to_twos_comp(-1, n_bits=32)) # 4294967295 = 2 ** 32 - 1 print(signed_to_twos_comp(-(2 ** 31), n_bits=32)) # 2147483648 = 2 ** 31 - 1 - """ # noqa + """ # noqa: E501 testdata = ["hello", 1, ["bongos", "today"]] for data 
in testdata: compare_python_to_reference_murmur3_32(data, seed=0) diff --git a/cardinal_pythonlib/httpconst.py b/cardinal_pythonlib/httpconst.py index 99f0d95..1100026 100644 --- a/cardinal_pythonlib/httpconst.py +++ b/cardinal_pythonlib/httpconst.py @@ -49,7 +49,7 @@ >>> print(mimetypes.guess_type("thing.xlsx")) ('application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', None) -""" # noqa +""" # noqa: E501 # ============================================================================= @@ -215,14 +215,14 @@ class MimeType(object): - https://www.openoffice.org/framework/documentation/mimetypes/mimetypes.html - https://stackoverflow.com/questions/31489757/what-is-correct-mimetype-with-apache-openoffice-files-like-odt-ods-odp - """ # noqa + """ # noqa: E501 BINARY = "application/octet-stream" CSV = "text/csv" DOC = "application/msword" - DOCX = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" # noqa + DOCX = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" # noqa: E501 DOT = DOC - DOTX = "application/vnd.openxmlformats-officedocument.wordprocessingml.template" # noqa + DOTX = "application/vnd.openxmlformats-officedocument.wordprocessingml.template" # noqa: E501 FORCE_DOWNLOAD = "application/force-download" HTML = "text/html" JSON = "application/json" diff --git a/cardinal_pythonlib/iterhelp.py b/cardinal_pythonlib/iterhelp.py index 23be5e4..c43db1e 100644 --- a/cardinal_pythonlib/iterhelp.py +++ b/cardinal_pythonlib/iterhelp.py @@ -50,7 +50,7 @@ def product_dict(**kwargs: Iterable) -> Iterable[Dict]: >>> product_dict(a="x", b=range(3)) - """ # noqa + """ # noqa: E501 keys = kwargs.keys() vals = kwargs.values() for instance in product(*vals): diff --git a/cardinal_pythonlib/json/serialize.py b/cardinal_pythonlib/json/serialize.py index 0e44f49..f54c2e2 100644 --- a/cardinal_pythonlib/json/serialize.py +++ b/cardinal_pythonlib/json/serialize.py @@ -500,7 +500,7 @@ def __init__(self, db: str = '', schema: str = 
'', print(f"register_for_json: args = {args!r}") print(f"register_for_json: kwargs = {kwargs!r}") - # https://stackoverflow.com/questions/653368/how-to-create-a-python-decorator-that-can-be-used-either-with-or-without-paramet # noqa + # https://stackoverflow.com/questions/653368/how-to-create-a-python-decorator-that-can-be-used-either-with-or-without-paramet # noqa: E501 # In brief, # @decorator # x @@ -540,13 +540,13 @@ def __init__(self, db: str = '', schema: str = '', method = kwargs.pop("method", METHOD_SIMPLE) # type: str obj_to_dict_fn = kwargs.pop( "obj_to_dict_fn", None - ) # type: InstanceToDictFnType # noqa + ) # type: InstanceToDictFnType dict_to_obj_fn = kwargs.pop( "dict_to_obj_fn", initdict_to_instance - ) # type: DictToInstanceFnType # noqa + ) # type: DictToInstanceFnType default_factory = kwargs.pop( "default_factory", None - ) # type: DefaultFactoryFnType # noqa + ) # type: DefaultFactoryFnType check_result = kwargs.pop("check_results", True) # type: bool def register_json_class(cls_: ClassType) -> ClassType: diff --git a/cardinal_pythonlib/lists.py b/cardinal_pythonlib/lists.py index 7d308e1..6b81f64 100644 --- a/cardinal_pythonlib/lists.py +++ b/cardinal_pythonlib/lists.py @@ -118,7 +118,7 @@ def flatten_list(x: List[Any]) -> List[Any]: As per https://stackoverflow.com/questions/952914/making-a-flat-list-out-of-list-of-lists-in-python - """ # noqa: E501 + """ return [item for sublist in x for item in sublist] @@ -135,7 +135,7 @@ def unique_list(seq: Iterable[Any]) -> List[Any]: As per https://stackoverflow.com/questions/480214/how-do-you-remove-duplicates-from-a-list-in-whilst-preserving-order - """ # noqa: E501 + """ seen = set() seen_add = seen.add return [x for x in seq if not (x in seen or seen_add(x))] diff --git a/cardinal_pythonlib/logs.py b/cardinal_pythonlib/logs.py index ae8fde3..bf479cd 100644 --- a/cardinal_pythonlib/logs.py +++ b/cardinal_pythonlib/logs.py @@ -622,7 +622,7 @@ def __init__( ) -> None: # This version uses args and 
kwargs, not *args and **kwargs, for # performance reasons: - # https://stackoverflow.com/questions/31992424/performance-implications-of-unpacking-dictionaries-in-python # noqa + # https://stackoverflow.com/questions/31992424/performance-implications-of-unpacking-dictionaries-in-python # noqa: E501 # ... and since we control creation entirely, we may as well go fast self.fmt = fmt self.args = args @@ -671,7 +671,7 @@ def __init__( log.info("Hello {}, {title} {surname}!", "world", title="Mr", surname="Smith") # 2018-09-17 16:13:50.404 __main__:INFO: Hello world, Mr Smith! - """ # noqa + """ # noqa: E501 # noinspection PyTypeChecker super().__init__(logger=logger, extra=None) self.pass_special_logger_args = pass_special_logger_args @@ -686,7 +686,7 @@ def __init__( # ... defaults = tuple of default argument values, or None # signature() returns a Signature object: # ... parameters: ordered mapping of name -> Parameter - # ... ... https://docs.python.org/3/library/inspect.html#inspect.Parameter # noqa + # ... ... 
https://docs.python.org/3/library/inspect.html#inspect.Parameter # noqa: E501 # Direct equivalence: # https://github.com/praw-dev/praw/issues/541 # So, old: diff --git a/cardinal_pythonlib/maths_py.py b/cardinal_pythonlib/maths_py.py index 70a9b4a..9eccb70 100644 --- a/cardinal_pythonlib/maths_py.py +++ b/cardinal_pythonlib/maths_py.py @@ -115,7 +115,7 @@ def normal_round_float(x: float, dp: int = 0) -> float: Note that round() implements "banker's rounding", which is never what we want: - - https://stackoverflow.com/questions/33019698/how-to-properly-round-up-half-float-numbers-in-python # noqa + - https://stackoverflow.com/questions/33019698/how-to-properly-round-up-half-float-numbers-in-python # noqa: E501 """ if not math.isfinite(x): return x @@ -171,7 +171,7 @@ def round_sf(x: float, n: int = 2) -> float: round_sf(1234567890000, 3) # 1230000000000 round_sf(9876543210000, 3) # 9880000000000 - """ # noqa: E501 + """ y = abs(x) if y <= sys.float_info.min: return 0.0 diff --git a/cardinal_pythonlib/metaclasses.py b/cardinal_pythonlib/metaclasses.py index a57f39b..b36e966 100644 --- a/cardinal_pythonlib/metaclasses.py +++ b/cardinal_pythonlib/metaclasses.py @@ -53,7 +53,7 @@ class CooperativeMeta(type): See also https://blog.ionelmc.ro/2015/02/09/understanding-python-metaclasses/. - """ # noqa + """ def __new__( mcs: Type, name: str, bases: Tuple[Type, ...], members: Dict[str, Any] diff --git a/cardinal_pythonlib/modules.py b/cardinal_pythonlib/modules.py index d81b65c..fe34aa6 100644 --- a/cardinal_pythonlib/modules.py +++ b/cardinal_pythonlib/modules.py @@ -136,7 +136,7 @@ def is_c_extension(module: ModuleType) -> bool: is_c_extension(et) # False on my system (Python 3.5.6). True in the original example. is_c_extension(numpy_multiarray) # True - """ # noqa + """ # noqa: E501 assert inspect.ismodule(module), f'"{module}" not a module.' 
# If this module was loaded by a PEP 302-compliant CPython-specific loader @@ -220,7 +220,7 @@ def contains_c_extension( contains_c_extension(django) - """ # noqa + """ # noqa: E501 assert inspect.ismodule(module), f'"{module}" not a module.' if seen is None: # only true for the top-level call @@ -275,7 +275,7 @@ def contains_c_extension( # Recurse: if contains_c_extension( module=candidate, - import_all_submodules=False, # only done at the top level, below # noqa + import_all_submodules=False, # only done at the top level, below include_external_imports=include_external_imports, seen=seen, ): @@ -288,7 +288,7 @@ def contains_c_extension( # Otherwise, for things like Django, we need to recurse in a different # way to scan everything. - # See https://stackoverflow.com/questions/3365740/how-to-import-all-submodules. # noqa + # See https://stackoverflow.com/questions/3365740/how-to-import-all-submodules. # noqa: E501 log.debug("Walking path: {!r}", top_path) # noinspection PyBroadException try: @@ -312,7 +312,7 @@ def contains_c_extension( continue if contains_c_extension( module=candidate, - import_all_submodules=False, # only done at the top level # noqa + import_all_submodules=False, # only done at the top level include_external_imports=include_external_imports, seen=seen, ): diff --git a/cardinal_pythonlib/nhs.py b/cardinal_pythonlib/nhs.py index 199228b..aa32ec8 100644 --- a/cardinal_pythonlib/nhs.py +++ b/cardinal_pythonlib/nhs.py @@ -92,7 +92,7 @@ def is_valid_nhs_number(n: int) -> bool: Checksum details are at https://web.archive.org/web/20180311083424/https://www.datadictionary.nhs.uk/version2/data_dictionary/data_field_notes/n/nhs_number_de.asp; https://web.archive.org/web/20220503215904/https://www.datadictionary.nhs.uk/attributes/nhs_number.html - """ # noqa: E501 + """ if not isinstance(n, int): log.debug("is_valid_nhs_number: parameter was not of integer type") return False @@ -127,7 +127,7 @@ def generate_random_nhs_number(official_test_range: bool = 
True) -> int: https://digital.nhs.uk/services/e-referral-service/document-library/synthetic-data-in-live-environments, saved at https://web.archive.org/web/20210116183039/https://digital.nhs.uk/services/e-referral-service/document-library/synthetic-data-in-live-environments. - """ # noqa + """ check_digit = 10 # NHS numbers with this check digit are all invalid while check_digit == 10: if official_test_range: @@ -218,7 +218,7 @@ def nhs_number_from_text_or_none(s: str) -> Optional[int]: NHS number rules: https://www.datadictionary.nhs.uk/version2/data_dictionary/data_field_notes/n/nhs_number_de.asp?shownav=0 - """ # noqa + """ # None in, None out. funcname = "nhs_number_from_text_or_none: " if not s: diff --git a/cardinal_pythonlib/openxml/find_bad_openxml.py b/cardinal_pythonlib/openxml/find_bad_openxml.py index 916a280..561be2e 100644 --- a/cardinal_pythonlib/openxml/find_bad_openxml.py +++ b/cardinal_pythonlib/openxml/find_bad_openxml.py @@ -305,7 +305,7 @@ def main() -> None: # result.get() # will re-raise any child exceptions # ... but it waits for the process to complete! That's no help. # log.critical("next") - # ... https://stackoverflow.com/questions/22094852/how-to-catch-exceptions-in-workers-in-multiprocessing # noqa + # ... 
https://stackoverflow.com/questions/22094852/how-to-catch-exceptions-in-workers-in-multiprocessing # noqa: E501 pool.close() pool.join() diff --git a/cardinal_pythonlib/openxml/find_recovered_openxml.py b/cardinal_pythonlib/openxml/find_recovered_openxml.py index 274ab7a..e730a4e 100644 --- a/cardinal_pythonlib/openxml/find_recovered_openxml.py +++ b/cardinal_pythonlib/openxml/find_recovered_openxml.py @@ -364,8 +364,8 @@ def process_file( ) raise # See also good advice, not implemented here, at - # https://stackoverflow.com/questions/19924104/python-multiprocessing-handling-child-errors-in-parent # noqa - # https://stackoverflow.com/questions/6126007/python-getting-a-traceback-from-a-multiprocessing-process/26096355#26096355 # noqa + # https://stackoverflow.com/questions/19924104/python-multiprocessing-handling-child-errors-in-parent # noqa: E501 + # https://stackoverflow.com/questions/6126007/python-getting-a-traceback-from-a-multiprocessing-process/26096355#26096355 # noqa: E501 # log.critical("process_file: end") @@ -560,7 +560,7 @@ def main() -> None: # result.get() # will re-raise any child exceptions # ... but it waits for the process to complete! That's no help. # log.critical("next") - # ... https://stackoverflow.com/questions/22094852/how-to-catch-exceptions-in-workers-in-multiprocessing # noqa + # ... 
https://stackoverflow.com/questions/22094852/how-to-catch-exceptions-in-workers-in-multiprocessing # noqa: E501 pool.close() pool.join() diff --git a/cardinal_pythonlib/openxml/grep_in_openxml.py b/cardinal_pythonlib/openxml/grep_in_openxml.py index 318b9e0..2d1c0ff 100644 --- a/cardinal_pythonlib/openxml/grep_in_openxml.py +++ b/cardinal_pythonlib/openxml/grep_in_openxml.py @@ -209,7 +209,7 @@ def main() -> None: exe_name = os.path.basename(argv[0]) or "grep_in_openxml" parser = ArgumentParser( formatter_class=RawDescriptionRichHelpFormatter, - description=f""" + description=rf""" Performs a grep (global-regular-expression-print) search of files in OpenXML format, which is to say inside ZIP files. @@ -228,7 +228,7 @@ def main() -> None: "Hardy" in DOC/DOCX documents, in case-insensitive fashion: find . -type f -name "*.doc*" -exec {exe_name} -l -i "laurel" {{}} \; | {exe_name} -x -l -i "hardy" -""", # noqa +""", # noqa: E501 ) parser.add_argument("pattern", help="Regular expression pattern to apply.") parser.add_argument( diff --git a/cardinal_pythonlib/pdf.py b/cardinal_pythonlib/pdf.py index 41725cd..2e844fd 100644 --- a/cardinal_pythonlib/pdf.py +++ b/cardinal_pythonlib/pdf.py @@ -621,11 +621,11 @@ def append_pdf(input_pdf: bytes, output_writer: PdfWriter): # serve the result (e.g. in one go), then delete the temporary file. # This may be more memory-efficient. 
# However, there can be problems: -# https://stackoverflow.com/questions/7543452/how-to-launch-a-pdftk-subprocess-while-in-wsgi # noqa +# https://stackoverflow.com/questions/7543452/how-to-launch-a-pdftk-subprocess-while-in-wsgi # noqa: E501 # Others' examples: # https://gist.github.com/zyegfryed/918403 # https://gist.github.com/grantmcconnaughey/ce90a689050c07c61c96 -# https://stackoverflow.com/questions/3582414/removing-tmp-file-after-return-httpresponse-in-django # noqa +# https://stackoverflow.com/questions/3582414/removing-tmp-file-after-return-httpresponse-in-django # noqa: E501 def get_concatenated_pdf_from_disk( @@ -642,7 +642,7 @@ def get_concatenated_pdf_from_disk( concatenated PDF, as ``bytes`` """ - # https://stackoverflow.com/questions/17104926/pypdf-merging-multiple-pdf-files-into-one-pdf # noqa + # https://stackoverflow.com/questions/17104926/pypdf-merging-multiple-pdf-files-into-one-pdf # noqa: E501 # https://en.wikipedia.org/wiki/Recto_and_verso # PdfMerger deprecated as of pypdf==5.0.0; use PdfWriter instead. 
# - https://pypdf.readthedocs.io/en/stable/modules/PdfMerger.html diff --git a/cardinal_pythonlib/probability.py b/cardinal_pythonlib/probability.py index 8b95030..61409ff 100644 --- a/cardinal_pythonlib/probability.py +++ b/cardinal_pythonlib/probability.py @@ -92,7 +92,7 @@ def ln(x: float) -> float: timeit.timeit('(ln(x) for x in range(1, 100))', number=10000) # 0.007783170789480209 - """ # noqa: E501 + """ # return math_ln(x) if x != 0 else MINUS_INFINITY # slower, less helpful try: return math_ln(x) diff --git a/cardinal_pythonlib/process.py b/cardinal_pythonlib/process.py index e18d11c..c85f797 100644 --- a/cardinal_pythonlib/process.py +++ b/cardinal_pythonlib/process.py @@ -167,7 +167,7 @@ def kill_proc_tree( tuple: ``(gone, still_alive)``, where both are sets of :class:`psutil.Process` objects - """ # noqa: E501 + """ parent = psutil.Process(pid) to_kill = parent.children(recursive=True) # type: List[psutil.Process] if including_parent: @@ -192,7 +192,7 @@ def nice_call( Modified from https://stackoverflow.com/questions/34458583/python-subprocess-call-doesnt-handle-signal-correctly - """ # noqa + """ with subprocess.Popen(*popenargs, **kwargs) as p: try: return p.wait(timeout=timeout) diff --git a/cardinal_pythonlib/psychiatry/drugs.py b/cardinal_pythonlib/psychiatry/drugs.py index d279148..d7e859c 100644 --- a/cardinal_pythonlib/psychiatry/drugs.py +++ b/cardinal_pythonlib/psychiatry/drugs.py @@ -215,7 +215,7 @@ # HOWEVER, NOTE THAT LITHIUM IS CURRENTLY OVER-INCLUSIVE and will include # lithium chloride for LiDCO measurement. 
-""" # noqa +""" # noqa: E501 import re from typing import List, Optional, Pattern, Union @@ -256,7 +256,7 @@ def __init__( add_preceding_word_boundary: bool = True, add_following_wildcards: bool = True, # Psychiatry - psychotropic: bool = None, # special; can be used as override if False # noqa + psychotropic: bool = None, # special; can be used as override if False antidepressant: bool = False, conventional_antidepressant: bool = False, ssri: bool = False, @@ -603,7 +603,7 @@ def deduplicate_wildcards(text: str) -> str: option_groups = bracketed.split("|") options = [c for group in option_groups for c in group] split_and_append(options) - working = working[close_bracket + 1 :] # noqa: E203 + working = working[close_bracket + 1 :] elif len(working) > 1 and working[1] == "?": # e.g. "r?azole" split_and_append(["", working[0]]) @@ -1033,13 +1033,13 @@ def sql_column_like_drug(self, column_name: str) -> str: ["phenylethylhydrazine", "Alazin", "Nardil"], monoamine_oxidase_inhibitor=True, slam_antidepressant_finder=True - # - SLAM code (see e-mail to self 2016-12-02) also has %Alazin%; not sure # noqa + # - SLAM code (see e-mail to self 2016-12-02) also has %Alazin%; not sure # noqa: E501 # that's right; see also # http://www.druglib.com/activeingredient/phenelzine/ # - oh, yes, it is right: - # https://www.pharmacompass.com/active-pharmaceutical-ingredients/alazin # noqa + # https://www.pharmacompass.com/active-pharmaceutical-ingredients/alazin # noqa: E501 # - phenylethylhydrazine is a synonym; see - # http://www.minclinic.ru/drugs/drugs_eng/B/Beta-phenylethylhydrazine.html # noqa + # http://www.minclinic.ru/drugs/drugs_eng/B/Beta-phenylethylhydrazine.html # noqa: E501 ), # not included: pheniprazine Drug( diff --git a/cardinal_pythonlib/psychiatry/mk_r_druglists.py b/cardinal_pythonlib/psychiatry/mk_r_druglists.py index 31d761c..4891dfe 100755 --- a/cardinal_pythonlib/psychiatry/mk_r_druglists.py +++ b/cardinal_pythonlib/psychiatry/mk_r_druglists.py @@ -88,7 +88,7 
@@ def rscript() -> str: now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - # https://www.nhs.uk/mental-health/talking-therapies-medicine-treatments/medicines-and-psychiatry/antidepressants/overview/ # noqa + # https://www.nhs.uk/mental-health/talking-therapies-medicine-treatments/medicines-and-psychiatry/antidepressants/overview/ # noqa: E501 ssri = converter(all_drugs_where(ssri=True, mixture=False)) snri = converter(all_drugs_where(snri=True, mixture=False)) mirtazapine = converter([get_drug("mirtazapine")]) @@ -181,7 +181,7 @@ def rscript() -> str: ANTIDEPRESSANTS_BROAD_INC_ANTIDEP_SGA <- {antidepressants_broad_inc_antidep_sga} -""" # noqa +""" # noqa: E501 # ============================================================================= diff --git a/cardinal_pythonlib/psychiatry/rfunc.py b/cardinal_pythonlib/psychiatry/rfunc.py index bb1bb34..02452a4 100644 --- a/cardinal_pythonlib/psychiatry/rfunc.py +++ b/cardinal_pythonlib/psychiatry/rfunc.py @@ -64,7 +64,7 @@ repl_python() # start an interactive Python session -""" # noqa +""" import sys from typing import Any, Dict diff --git a/cardinal_pythonlib/psychiatry/simhelpers.py b/cardinal_pythonlib/psychiatry/simhelpers.py index 46ade62..ba07554 100644 --- a/cardinal_pythonlib/psychiatry/simhelpers.py +++ b/cardinal_pythonlib/psychiatry/simhelpers.py @@ -142,7 +142,7 @@ def gen_params_around_centre( from cardinal_pythonlib.psychiatry.simhelpers import gen_params_around_centre list(gen_params_around_centre(a=[1, 2, 3], b=[4, 5, 6], c=[7, 8, 9], d=[10])) - """ # noqa + """ # noqa: E501 param_order = param_order or [] # type: Sequence[str] def _sorter(x: str) -> Tuple[bool, Union[int, str]]: diff --git a/cardinal_pythonlib/psychiatry/treatment_resistant_depression.py b/cardinal_pythonlib/psychiatry/treatment_resistant_depression.py index 0a58508..0ebfc49 100755 --- a/cardinal_pythonlib/psychiatry/treatment_resistant_depression.py +++ b/cardinal_pythonlib/psychiatry/treatment_resistant_depression.py @@ -29,7 +29,7 
@@ - 200 test patients; baseline about 7.65-8.57 seconds (25 Hz). - From https://stackoverflow.com/questions/19237878/ to - https://stackoverflow.com/questions/17071871/select-rows-from-a-dataframe-based-on-values-in-a-column-in-pandas # noqa + https://stackoverflow.com/questions/17071871/select-rows-from-a-dataframe-based-on-values-in-a-column-in-pandas # noqa: E501 - Change from parallel to single-threading: down to 4.38 s (!). - Avoid a couple of slices: down to 3.85 s for 200 patients. - Add test patient E; up to 4.63 s for 250 patients (54 Hz). @@ -55,7 +55,7 @@ - Profiler off: 2.38s for 300 patients, or 126 Hz. Let's call that a day; we've achieved a 5-fold speedup. -""" # noqa +""" import cProfile from concurrent.futures import ThreadPoolExecutor @@ -80,8 +80,8 @@ DTYPE_STRING = " Optional[DataFrame]: """ @@ -349,8 +349,8 @@ def two_antidepressant_episodes_single_patient( # OK; here we have found a combination that we like. # Add it to the results. # --------------------------------------------------------------------- - # https://stackoverflow.com/questions/19365513/how-to-add-an-extra-row-to-a-pandas-dataframe/19368360 # noqa - # http://pandas.pydata.org/pandas-docs/stable/indexing.html#setting-with-enlargement # noqa + # https://stackoverflow.com/questions/19365513/how-to-add-an-extra-row-to-a-pandas-dataframe/19368360 # noqa: E501 + # http://pandas.pydata.org/pandas-docs/stable/indexing.html#setting-with-enlargement # noqa: E501 expect_response_by_date = ( antidepressant_b_first_mention @@ -509,7 +509,7 @@ def _make_example(suffixes: Iterable[Any] = None) -> DataFrame: (bob_s, venla, "2018-04-01"), (bob_s, sert, "2018-05-01"), (bob_s, sert, "2018-06-01"), - # Alice: two consecutive switches; should pick the first, c -> f # noqa + # Alice: two consecutive switches; should pick the first, c -> f # noqa: E501 # ... 
goes second in the data; should be sorted to first (alice_s, cital, "2018-01-01"), (alice_s, cital, "2018-02-01"), diff --git a/cardinal_pythonlib/pyramid/compression.py b/cardinal_pythonlib/pyramid/compression.py index c0a0c44..3b4edd0 100644 --- a/cardinal_pythonlib/pyramid/compression.py +++ b/cardinal_pythonlib/pyramid/compression.py @@ -62,7 +62,7 @@ class CompressionTweenFactory(object): - https://docs.pylonsproject.org/projects/pyramid/en/latest/narr/hooks.html - https://docs.pylonsproject.org/projects/pyramid/en/latest/api/request.html - https://docs.pylonsproject.org/projects/pyramid/en/latest/api/response.html - """ # noqa + """ # noqa: E501 def __init__( self, handler: PyramidHandlerType, registry: Registry diff --git a/cardinal_pythonlib/pyramid/requests.py b/cardinal_pythonlib/pyramid/requests.py index d444d3c..669fae9 100644 --- a/cardinal_pythonlib/pyramid/requests.py +++ b/cardinal_pythonlib/pyramid/requests.py @@ -125,7 +125,7 @@ def request_accepts_gzip(request: Request) -> bool: - So we'll do a case-sensitive check for "gzip". - But there is also a bit of other syntax possible; see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding. - """ # noqa + """ headers = request.headers # type: EnvironHeaders if HTTP_ACCEPT_ENCODING not in headers: return False @@ -185,4 +185,4 @@ def decompress_request(request: Request) -> None: else: raise ValueError(f"Unknown Content-Encoding: {encoding}") # ... e.g. 
"compress"; LZW; patent expired; see - # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding # noqa + # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding # noqa: E501 diff --git a/cardinal_pythonlib/regexfunc.py b/cardinal_pythonlib/regexfunc.py index d8bcf03..0690fb7 100644 --- a/cardinal_pythonlib/regexfunc.py +++ b/cardinal_pythonlib/regexfunc.py @@ -46,7 +46,7 @@ class CompiledRegexMemory(object): Based on https://stackoverflow.com/questions/597476/how-to-concisely-cascade-through-multiple-regex-statements-in-python. - """ # noqa: E501 + """ def __init__(self) -> None: self.last_match = None # type: Optional[Match] diff --git a/cardinal_pythonlib/rounding.py b/cardinal_pythonlib/rounding.py index 9c29902..b7b855e 100644 --- a/cardinal_pythonlib/rounding.py +++ b/cardinal_pythonlib/rounding.py @@ -142,7 +142,7 @@ def range_roundable_up_to( Note that ``dp`` can be negative, as in other Python functions. - """ # noqa: E501 + """ y = Decimal(y) assert num_dp_from_decimal(y, with_negative_dp=True) <= dp, ( f"Number {y} is not rounded to {dp} dp as claimed; it has " diff --git a/cardinal_pythonlib/rpm.py b/cardinal_pythonlib/rpm.py index 4737b0e..ae69fd1 100644 --- a/cardinal_pythonlib/rpm.py +++ b/cardinal_pythonlib/rpm.py @@ -439,7 +439,7 @@ def dbinom_raw_log(x: float, n: float, p: float, q: float) -> float: * Do this in the calling function. */ - """ # noqa + """ # noqa: E501 log_0 = -inf log_1 = 0 @@ -537,7 +537,7 @@ def beta_pdf_fast(x: float, a: float, b: float) -> float: * term is large. We use Loader's code only if both a and b > 2. */ - """ # noqa + """ # noqa: E501 # logger.critical(f"beta_pdf_fast(x={x}, a={a}, b={b})") if a < 0 or b < 0: return NaN @@ -721,7 +721,7 @@ def rpm_probabilities_successes_failures_twochoice_fast( Massively tedious optimization (translation from R's C code to Python) but it works very well. 
- """ # noqa + """ # noqa: E501 args = (n_success_this, n_failure_this, n_success_other, n_failure_other) # ... tuple, not numpy array, or we get "TypeError: only size-1 arrays can # be converted to Python scalars" @@ -748,8 +748,8 @@ def rpm_integrand_n_choice(args: np.ndarray) -> float: x = args[0] k = int(args[1]) # k is the number of actions current_action = int(args[2]) # zero-based index - n_successes_plus_one = args[3 : k + 3] # noqa: E203 - n_failures_plus_one = args[k + 3 :] # noqa: E203 + n_successes_plus_one = args[3 : k + 3] + n_failures_plus_one = args[k + 3 :] r = beta_pdf_fast( x, diff --git a/cardinal_pythonlib/sizeformatter.py b/cardinal_pythonlib/sizeformatter.py index dbf2eea..10169fe 100644 --- a/cardinal_pythonlib/sizeformatter.py +++ b/cardinal_pythonlib/sizeformatter.py @@ -163,7 +163,7 @@ def human2bytes(s: str) -> int: Traceback (most recent call last): ... ValueError: can't interpret '12 foo' - """ # noqa: E501 + """ if not s: raise ValueError(f"Can't interpret {s!r} as integer") try: diff --git a/cardinal_pythonlib/snomed.py b/cardinal_pythonlib/snomed.py index 74691cd..e92cd8d 100644 --- a/cardinal_pythonlib/snomed.py +++ b/cardinal_pythonlib/snomed.py @@ -177,9 +177,9 @@ def test(s): test("ab'c\"d") test('ab"cd') - """ # noqa + """ # noqa: E501 # For efficiency, we use a list: - # https://stackoverflow.com/questions/3055477/how-slow-is-pythons-string-concatenation-vs-str-join # noqa + # https://stackoverflow.com/questions/3055477/how-slow-is-pythons-string-concatenation-vs-str-join # noqa: E501 # https://waymoot.org/home/python_string/ dquote = '"' ret = [dquote] # type: List[str] diff --git a/cardinal_pythonlib/source_reformatting.py b/cardinal_pythonlib/source_reformatting.py index 66045b1..96b4af3 100644 --- a/cardinal_pythonlib/source_reformatting.py +++ b/cardinal_pythonlib/source_reformatting.py @@ -171,7 +171,7 @@ def _create_dest(self) -> None: docstring_done = True in_body = True # ... 
and keep dl, so we write the end of the - # docstring, potentially with e.g. "# noqa" on the end + # docstring, potentially with e.g. a noqa on the end elif not docstring_done: # docstring starting in_docstring = True # self._critical("adding our new docstring") diff --git a/cardinal_pythonlib/sphinxtools.py b/cardinal_pythonlib/sphinxtools.py index 2f3bfbd..c0487fc 100644 --- a/cardinal_pythonlib/sphinxtools.py +++ b/cardinal_pythonlib/sphinxtools.py @@ -214,7 +214,7 @@ class FileToAutodocument(object): print(f.rst_content(prefix=".. Hello!", method=AutodocMethod.CONTENTS)) f.write_rst(prefix=".. Hello!") - """ # noqa: E501 + """ def __init__( self, @@ -260,7 +260,7 @@ def __init__( ) self.pygments_language_override = ( pygments_language_override or {} - ) # type: Dict[str, str] # noqa + ) # type: Dict[str, str] assert isfile( self.source_filename ), f"Not a file: source_filename={self.source_filename!r}" @@ -554,7 +554,7 @@ class AutodocIndex(object): print(flatidx.index_content()) flatidx.write_index_and_rst_files(overwrite=True, mock=True) - """ # noqa + """ def __init__( self, @@ -688,7 +688,7 @@ def __init__( self.source_rst_title_style_python = source_rst_title_style_python self.pygments_language_override = ( pygments_language_override or {} - ) # type: Dict[str, str] # noqa + ) # type: Dict[str, str] assert isdir( self.project_root_dir @@ -716,7 +716,7 @@ def __init__( self.files_to_index = ( [] - ) # type: List[Union[FileToAutodocument, AutodocIndex]] # noqa + ) # type: List[Union[FileToAutodocument, AutodocIndex]] if source_filenames_or_globs: self.add_source_files(source_filenames_or_globs) @@ -883,7 +883,7 @@ def write_index_and_rst_files( f.write_rst( prefix=self.rst_prefix, suffix=self.rst_suffix, - heading_underline_char=self.source_rst_heading_underline_char, # noqa + heading_underline_char=self.source_rst_heading_underline_char, # noqa: E501 overwrite=overwrite, mock=mock, ) diff --git a/cardinal_pythonlib/spreadsheets.py 
b/cardinal_pythonlib/spreadsheets.py index 9629623..1b089c2 100644 --- a/cardinal_pythonlib/spreadsheets.py +++ b/cardinal_pythonlib/spreadsheets.py @@ -72,7 +72,7 @@ def all_same(items: Iterable[Any]) -> bool: https://stackoverflow.com/questions/3787908/python-determine-if-all-items-of-a-list-are-the-same-item ... though we will also allow "no items" to pass the test. - """ # noqa: E501 + """ return len(set(items)) <= 1 @@ -564,7 +564,7 @@ def read_datetime( Reads a datetime from an Excel spreadsheet via xlrd. https://stackoverflow.com/questions/32430679/how-to-read-dates-using-xlrd - """ # noqa: E501 + """ v = self.read_value(row, col, check_header=check_header) if none_or_blank_string(v): return default @@ -891,9 +891,8 @@ def import_referrals(book: Book) -> None: def __init__(self, sheetholder: SheetHolder, row: int) -> None: self.sheetholder = sheetholder self.row = row # zero-based index of our row - self._next_col = ( - 0 # zero-based column index of the next column to read # noqa - ) + self._next_col = 0 + # ... zero-based column index of the next column to read # ------------------------------------------------------------------------- # Information diff --git a/cardinal_pythonlib/sql/sql_grammar.py b/cardinal_pythonlib/sql/sql_grammar.py index 154b59b..3730a8c 100644 --- a/cardinal_pythonlib/sql/sql_grammar.py +++ b/cardinal_pythonlib/sql/sql_grammar.py @@ -61,7 +61,7 @@ - https://www.contrib.andrew.cmu.edu/~shadow/sql/sql1992.txt ... 
particularly the formal specifications in chapter 5 on words -""" # noqa +""" # noqa: E501 import re from typing import List, Union @@ -115,7 +115,7 @@ def delim_list( WORD_BOUNDARY = r"\b" # The meaning of \b: -# https://stackoverflow.com/questions/4213800/is-there-something-like-a-counter-variable-in-regular-expression-replace/4214173#4214173 # noqa +# https://stackoverflow.com/questions/4213800/is-there-something-like-a-counter-variable-in-regular-expression-replace/4214173#4214173 # noqa: E501 def word_regex_element(word: str) -> str: @@ -403,7 +403,7 @@ def single_quote(fragment: str) -> str: | datetime_literal ).setName("literal_value") -# https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add # noqa +# https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add # noqa: E501 time_unit = make_words_regex( "MICROSECOND SECOND MINUTE HOUR DAY WEEK MONTH QUARTER YEAR" " SECOND_MICROSECOND" diff --git a/cardinal_pythonlib/sql/sql_grammar_mssql.py b/cardinal_pythonlib/sql/sql_grammar_mssql.py index c7f8b44..35c2cc3 100644 --- a/cardinal_pythonlib/sql/sql_grammar_mssql.py +++ b/cardinal_pythonlib/sql/sql_grammar_mssql.py @@ -119,11 +119,11 @@ # Not in SQL Server (though in MySQL): # -# don't think so: BINARY; http://gilfster.blogspot.co.uk/2005/08/case-sensitivity-in-mysql.html # noqa -# DISTINCTROW: no; https://stackoverflow.com/questions/8562136/distinctrow-equivalent-in-sql-server # noqa -# DIV/MOD: not in SQL Server; use / and % respectively; https://msdn.microsoft.com/en-us/library/ms190279.aspx # noqa -# PARTITION: not in SELECT? 
- https://msdn.microsoft.com/en-us/library/ms187802.aspx # noqa -# XOR: use ^ instead; https://stackoverflow.com/questions/5411619/t-sql-xor-operator # noqa +# don't think so: BINARY; http://gilfster.blogspot.co.uk/2005/08/case-sensitivity-in-mysql.html # noqa: E501 +# DISTINCTROW: no; https://stackoverflow.com/questions/8562136/distinctrow-equivalent-in-sql-server # noqa: E501 +# DIV/MOD: not in SQL Server; use / and % respectively; https://msdn.microsoft.com/en-us/library/ms190279.aspx # noqa: E501 +# PARTITION: not in SELECT? - https://msdn.microsoft.com/en-us/library/ms187802.aspx # noqa: E501 +# XOR: use ^ instead; https://stackoverflow.com/questions/5411619/t-sql-xor-operator # noqa: E501 # Definitely part of SQL Server: CHECKSUM_AGG = sql_keyword("CHECKSUM_AGG") @@ -231,7 +231,7 @@ class SqlGrammarMSSQLServer(SqlGrammar): # ... who thought "END-EXEC" was a good one? # Then some more: - # - WITH ROLLUP: https://technet.microsoft.com/en-us/library/ms189305(v=sql.90).aspx # noqa + # - WITH ROLLUP: https://technet.microsoft.com/en-us/library/ms189305(v=sql.90).aspx # noqa: E501 # - SOUNDEX: https://msdn.microsoft.com/en-us/library/ms187384.aspx rnc_extra_sql_server_keywords = """ ROLLUP @@ -305,7 +305,7 @@ class SqlGrammarMSSQLServer(SqlGrammar): ).setName("column_spec") # I'm unsure if SQL Server allows keywords in the parts after dots, like # MySQL does. - # - https://stackoverflow.com/questions/285775/how-to-deal-with-sql-column-names-that-look-like-sql-keywords # noqa + # - https://stackoverflow.com/questions/285775/how-to-deal-with-sql-column-names-that-look-like-sql-keywords # noqa: E501 bind_parameter = Literal("?") @@ -317,7 +317,7 @@ class SqlGrammarMSSQLServer(SqlGrammar): function_call = Combine(function_name + LPAR) + argument_list + RPAR # Not supported: index hints - # ... https://stackoverflow.com/questions/11016935/how-can-i-force-a-query-to-not-use-a-index-on-a-given-table # noqa + # ... 
https://stackoverflow.com/questions/11016935/how-can-i-force-a-query-to-not-use-a-index-on-a-given-table # noqa: E501 # ----------------------------------------------------------------------------- # CASE @@ -446,9 +446,7 @@ class SqlGrammarMSSQLServer(SqlGrammar): + DISTINCT + expr + RPAR - | Combine( # special aggregate function # noqa - aggregate_function + LPAR - ) + | Combine(aggregate_function + LPAR) # special aggregate function + expr + RPAR | expr @@ -641,13 +639,13 @@ def pyparsing_bugtest_delimited_list_combine(fix_problem: bool = True) -> None: word_list_combine = delimitedList(word, combine=True) print( word_list_no_combine.parseString("one, two", parseAll=True) - ) # ['one', 'two'] # noqa + ) # ['one', 'two'] print( word_list_no_combine.parseString("one,two", parseAll=True) - ) # ['one', 'two'] # noqa + ) # ['one', 'two'] print( word_list_combine.parseString("one, two", parseAll=True) - ) # ['one']: ODD ONE OUT # noqa + ) # ['one']: ODD ONE OUT print( word_list_combine.parseString("one,two", parseAll=True) - ) # ['one,two'] # noqa + ) # ['one,two'] diff --git a/cardinal_pythonlib/sql/sql_grammar_mysql.py b/cardinal_pythonlib/sql/sql_grammar_mysql.py index 8d9e1c0..cc2b9da 100644 --- a/cardinal_pythonlib/sql/sql_grammar_mysql.py +++ b/cardinal_pythonlib/sql/sql_grammar_mysql.py @@ -427,7 +427,7 @@ class SqlGrammarMySQL(SqlGrammar): # CASE # ----------------------------------------------------------------------------- # NOT THIS: https://dev.mysql.com/doc/refman/5.7/en/case.html - # THIS: https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#operator_case # noqa + # THIS: https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#operator_case # noqa: E501 case_expr = ( ( CASE @@ -477,7 +477,7 @@ class SqlGrammarMySQL(SqlGrammar): expr_term = ( INTERVAL + expr + time_unit | - # "{" + identifier + expr + "}" | # see MySQL notes; antique ODBC syntax # noqa + # "{" + identifier + expr + "}" | # see MySQL notes; antique ODBC syntax # 
noqa: E501 Optional(EXISTS) + LPAR + select_statement + RPAR | # ... e.g. mycol = EXISTS(SELECT ...) @@ -598,9 +598,7 @@ class SqlGrammarMySQL(SqlGrammar): + DISTINCT + expr + RPAR - | Combine( # special aggregate function # noqa - aggregate_function + LPAR - ) + | Combine(aggregate_function + LPAR) # special aggregate function + expr + RPAR | expr diff --git a/cardinal_pythonlib/sql/validation.py b/cardinal_pythonlib/sql/validation.py index 8699739..1fb040d 100644 --- a/cardinal_pythonlib/sql/validation.py +++ b/cardinal_pythonlib/sql/validation.py @@ -24,8 +24,6 @@ **Functions to check table/column names etc. for validity in SQL.** -This is a slight - """ import re @@ -41,13 +39,29 @@ # ... SQL Server is very liberal! -# - ANSI: http://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#predefined-type # noqa +# - ANSI: +# - http://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#predefined-type # noqa: E501 +# # - SQL Server: -# - https://support.microsoft.com/en-us/office/equivalent-ansi-sql-data-types-7a0a6bef-ef25-45f9-8a9a-3c5f21b5c65d # noqa -# - https://docs.microsoft.com/en-us/sql/t-sql/data-types/data-types-transact-sql?view=sql-server-ver15 # noqa +# - https://support.microsoft.com/en-us/office/equivalent-ansi-sql-data-types-7a0a6bef-ef25-45f9-8a9a-3c5f21b5c65d # noqa: E501 +# - https://docs.microsoft.com/en-us/sql/t-sql/data-types/data-types-transact-sql?view=sql-server-ver15 # noqa: E501 +# - https://learn.microsoft.com/en-us/sql/t-sql/data-types/data-types-transact-sql?view=sql-server-ver16 # noqa: E501 # - Note that ANSI "BIT" is SQL Server "BINARY". 
-# - MySQL: https://dev.mysql.com/doc/refman/8.0/en/data-types.html -# - PostgreSQL: https://www.postgresql.org/docs/9.5/datatype.html +# +# - MySQL: +# - https://dev.mysql.com/doc/refman/8.0/en/data-types.html +# - https://dev.mysql.com/doc/refman/9.1/en/data-types.html +# +# - PostgreSQL: +# - https://www.postgresql.org/docs/9.5/datatype.html +# +# - SQLite: +# - https://www.sqlite.org/datatype3.html +# +# - Databricks: +# - https://github.com/databricks/databricks-sqlalchemy + +SQLTYPE_DATE = "DATE" # ANSI SQLTYPES_INTEGER = ( "BIGINT", # ANSI @@ -71,6 +85,12 @@ "SMALLSERIAL", # PostgreSQL "TINYINT", # SQL Server, MySQL ) +SQLTYPES_BIT = ( + "BIT VARYING", # ANSI + "BIT", # ANSI + "BOOL", # MySQL synonym for BOOLEAN or TINYINT(1) + "BOOLEAN", # ANSI +) SQLTYPES_FLOAT = ( "DOUBLE PRECISION", # ANSI (8 bytes) "DOUBLE", # SQL Server, MySQL; synonym for DOUBLE PRECISION @@ -84,16 +104,13 @@ "SINGLE", # SQL Server ) SQLTYPES_OTHER_NUMERIC = ( - "BIT VARYING", # ANSI - "BIT", # ANSI - "BOOL", # MySQL synonym for BOOLEAN or TINYINT(1) - "BOOLEAN", # ANSI "DEC", # ANSI; synonym for DECIMAL "DECIMAL", # ANSI "FIXED", # MySQL; synonym for DECIMAL "LOGICAL", # SQL Server "LOGICAL1", # SQL Server "NUMERIC", # ANSI; synonym for DECIMAL + "SMALLMONEY", # SQL Server "ROWVERSION", # SQL Server "VARBIT", # PostgreSQL synonym for BIT VARYING "YESNO", # SQL Server @@ -125,8 +142,8 @@ "NTEXT", # SQL Server "NVARCHAR", # SQL Server "SET", # MySQL - "STRING", # SQL Server - "TEXT", # SQL Server, MySQL + "STRING", # SQL Server, Databricks + "TEXT", # SQL Server, MySQL, SQLite "TINYTEXT", # MySQL "VARCHAR", # ANSI ) @@ -146,12 +163,13 @@ "VARBINARY", # ANSI ) SQLTYPES_WITH_DATE = ( - "DATE", # ANSI - "DATETIME", # SQL Server, MySQL + SQLTYPE_DATE, # ANSI + "DATETIME", # SQL Server, MySQL, most "DATETIME2", # SQL Server "DATETIMEOFFSET", # SQL Server (date + time + time zone) "SMALLDATETIME", # SQL Server "TIMESTAMP", # ANSI + "TIMESTAMP_NTZ", # Databricks ) SQLTYPES_DATETIME_OTHER = ( 
"INTERVAL", # ANSI (not always supported); PostgreSQL diff --git a/cardinal_pythonlib/sqlalchemy/alembic_func.py b/cardinal_pythonlib/sqlalchemy/alembic_func.py index 7d54dac..9906626 100644 --- a/cardinal_pythonlib/sqlalchemy/alembic_func.py +++ b/cardinal_pythonlib/sqlalchemy/alembic_func.py @@ -48,7 +48,7 @@ # Constants for Alembic # ============================================================================= # https://alembic.readthedocs.org/en/latest/naming.html -# http://docs.sqlalchemy.org/en/latest/core/constraints.html#configuring-constraint-naming-conventions # noqa +# http://docs.sqlalchemy.org/en/latest/core/constraints.html#configuring-constraint-naming-conventions # noqa: E501 ALEMBIC_NAMING_CONVENTION = { "ix": "ix_%(column_0_label)s", @@ -66,7 +66,7 @@ # ============================================================================= # Alembic revision/migration system # ============================================================================= -# https://stackoverflow.com/questions/24622170/using-alembic-api-from-inside-application-code # noqa +# https://stackoverflow.com/questions/24622170/using-alembic-api-from-inside-application-code # noqa: E501 def get_head_revision_from_alembic( @@ -103,7 +103,7 @@ def get_current_revision( database_url: SQLAlchemy URL for the database version_table: table name for Alembic versions """ - engine = create_engine(database_url) + engine = create_engine(database_url, future=True) conn = engine.connect() opts = {"version_table": version_table} mig_context = MigrationContext.configure(conn, opts=opts) @@ -334,7 +334,7 @@ def create_database_migration_numbered_style( message: message to be associated with this revision n_sequence_chars: number of numerical sequence characters to use in the filename/revision (see above). 
- """ # noqa + """ # noqa: E501 file_regex = r"\d{" + str(n_sequence_chars) + r"}_\S*\.py$" _, _, existing_version_filenames = next( @@ -403,7 +403,7 @@ def stamp_allowing_unusual_version_table( This function is a clone of ``alembic.command.stamp()``, but allowing ``version_table`` to change. See https://alembic.zzzcomputing.com/en/latest/api/commands.html#alembic.command.stamp - """ # noqa + """ script = ScriptDirectory.from_config(config) diff --git a/cardinal_pythonlib/sqlalchemy/alembic_ops.py b/cardinal_pythonlib/sqlalchemy/alembic_ops.py index a023dd8..94ea22d 100644 --- a/cardinal_pythonlib/sqlalchemy/alembic_ops.py +++ b/cardinal_pythonlib/sqlalchemy/alembic_ops.py @@ -62,7 +62,7 @@ def __init__(self, name, sqltext): config C INNER JOIN session S ON S.config_id = C.config_id - """ # noqa + """ # noqa: E501 self.name = name self.sqltext = sqltext diff --git a/cardinal_pythonlib/sqlalchemy/core_query.py b/cardinal_pythonlib/sqlalchemy/core_query.py index adeec37..8b37dc8 100644 --- a/cardinal_pythonlib/sqlalchemy/core_query.py +++ b/cardinal_pythonlib/sqlalchemy/core_query.py @@ -26,25 +26,25 @@ """ -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Tuple, Union from sqlalchemy.engine.base import Connection, Engine -from sqlalchemy.engine import CursorResult +from sqlalchemy.engine.row import Row from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.orm.session import Session from sqlalchemy.sql.expression import ( + case, column, exists, func, - literal, select, table, + text, ) from sqlalchemy.sql.schema import Table -from sqlalchemy.sql.selectable import Select +from sqlalchemy.sql.selectable import Select, TableClause from cardinal_pythonlib.logs import get_brace_style_log_with_null_handler -from cardinal_pythonlib.sqlalchemy.dialect import SqlaDialectName log = get_brace_style_log_with_null_handler(__name__) @@ -56,26 +56,111 @@ def get_rows_fieldnames_from_raw_sql( session: Union[Session, 
Engine, Connection], sql: str -) -> Tuple[Sequence[Sequence[Any]], Sequence[str]]: +) -> Tuple[List[Row], List[str]]: """ Returns results and column names from a query. Args: - session: SQLAlchemy :class:`Session`, :class:`Engine`, or - :class:`Connection` object - sql: raw SQL to execure + session: + SQLAlchemy :class:`Session`, :class:`Engine` (SQL Alchemy 1.4 + only), or :class:`Connection` object + sql: + raw SQL to execure Returns: ``(rows, fieldnames)`` where ``rows`` is the usual set of results and ``fieldnames`` are the name of the result columns/fields. """ - result = session.execute(sql) # type: CursorResult + if not isinstance(sql, str): + raise ValueError("sql argument must be a string") + result = session.execute(text(sql)) fieldnames = result.keys() rows = result.fetchall() return rows, fieldnames +def get_rows_fieldnames_from_select( + session: Union[Session, Engine, Connection], select_query: Select +) -> Tuple[List[Row], List[str]]: + """ + Returns results and column names from a query. + + Args: + session: + SQLAlchemy :class:`Session`, :class:`Engine` (SQL Alchemy 1.4 + only), or :class:`Connection` object + select_query: + select() statement, i.e. instance of + :class:`sqlalchemy.sql.selectable.Select` + + Returns: + ``(rows, fieldnames)`` where ``rows`` is the usual set of results and + ``fieldnames`` are the name of the result columns/fields. + + """ + if not isinstance(select_query, Select): + raise ValueError("select_query argument must be a select() statement") + + # Check that the user is not querying an ORM *class* rather than columns. + # It doesn't make much sense to use this function in that case. 
+ # If Pet is an ORM class (see unit tests!), then: + # + # - select(Pet).column_descriptions: + # + # [{'name': 'Pet', 'type': , 'aliased': + # False, 'expr': , 'entity': }] + # + # "entity" matches "type"; this is the one we want to disallow + # + # - select(Pet.id, Pet.name).column_descriptions: + # + # [{'name': 'id', 'type': Integer(), 'aliased': False, 'expr': + # , 'entity': }, {'name': + # 'name', 'type': String(length=50), 'aliased': False, 'expr': + # , 'entity': }] + # + # ... "entity" differs from "type" + # + # - select(sometable.a, sometable.b).column_descriptions: + # + # [{'name': 'a', 'type': INTEGER(), 'expr': Column('a', INTEGER(), + # table=, primary_key=True)}, {'name': 'b', 'type': INTEGER(), 'expr': + # Column('b', INTEGER(), table=)}] + # + # ... no "entity" key. + # + # Therefore: + for cd in select_query.column_descriptions: + # https://docs.sqlalchemy.org/en/20/core/selectable.html#sqlalchemy.sql.expression.Select.column_descriptions # noqa: E501 + # For + if "entity" not in cd: + continue + if cd["type"] == cd["entity"]: + raise ValueError( + "It looks like your select() query is querying whole ORM " + "object classes, not just columns or column-like " + "expressions. Its column_descriptions are: " + f"{select_query.column_descriptions}" + ) + + result = session.execute(select_query) + + fieldnames_rmkview = result.keys() + # ... of type RMKeyView, e.g. RMKeyView(['a', 'b']) + fieldnames = [x for x in fieldnames_rmkview] + + rows = result.fetchall() + + # I don't know how to differentiate select(Pet), selecting an ORM class, + # from select(Pet.name), selecting a column. + + return rows, fieldnames + + # ============================================================================= # SELECT COUNT(*) (SQLAlchemy Core) # ============================================================================= @@ -90,8 +175,9 @@ def count_star( additional ``WHERE`` criteria if desired). 
Args: - session: SQLAlchemy :class:`Session`, :class:`Engine`, or - :class:`Connection` object + session: + SQLAlchemy :class:`Session`, :class:`Engine` (SQL Alchemy 1.4 + only), or :class:`Connection` object tablename: name of the table criteria: optional SQLAlchemy "where" criteria @@ -100,7 +186,7 @@ def count_star( """ # works if you pass a connection or a session or an engine; all have # the execute() method - query = select([func.count()]).select_from(table(tablename)) + query = select(func.count()).select_from(table(tablename)) for criterion in criteria: query = query.where(criterion) return session.execute(query).scalar() @@ -115,13 +201,14 @@ def count_star_and_max( session: Union[Session, Engine, Connection], tablename: str, maxfield: str, - *criteria: Any + *criteria: Any, ) -> Tuple[int, Optional[int]]: """ Args: - session: SQLAlchemy :class:`Session`, :class:`Engine`, or - :class:`Connection` object + session: + SQLAlchemy :class:`Session`, :class:`Engine` (SQL Alchemy 1.4 + only), or :class:`Connection` object tablename: name of the table maxfield: name of column (field) to take the ``MAX()`` of criteria: optional SQLAlchemy "where" criteria @@ -130,13 +217,14 @@ def count_star_and_max( a tuple: ``(count, maximum)`` """ - query = select([func.count(), func.max(column(maxfield))]).select_from( + query = select(func.count(), func.max(column(maxfield))).select_from( table(tablename) ) for criterion in criteria: query = query.where(criterion) result = session.execute(query) - return result.fetchone() # count, maximum + count, maximum = result.fetchone() + return count, maximum # ============================================================================= @@ -146,15 +234,18 @@ def count_star_and_max( # http://docs.sqlalchemy.org/en/latest/orm/query.html -def exists_in_table(session: Session, table_: Table, *criteria: Any) -> bool: +def exists_in_table( + session: Session, table_: Union[Table, TableClause], *criteria: Any +) -> bool: """ Implements an 
efficient way of detecting if a record or records exist; should be faster than ``COUNT(*)`` in some circumstances. Args: - session: SQLAlchemy :class:`Session`, :class:`Engine`, or - :class:`Connection` object - table_: SQLAlchemy :class:`Table` object + session: + SQLAlchemy :class:`Session`, :class:`Engine` (SQL Alchemy 1.4 + only), or :class:`Connection` object + table_: SQLAlchemy :class:`Table` object or table clause criteria: optional SQLAlchemy "where" criteria Returns: @@ -175,15 +266,70 @@ def exists_in_table(session: Session, table_: Table, *criteria: Any) -> bool: exists_clause = exists_clause.where(criterion) # ... EXISTS (SELECT * FROM tablename WHERE ...) - if session.get_bind().dialect.name == SqlaDialectName.MSSQL: - query = select([literal(True)]).where(exists_clause) - # ... SELECT 1 WHERE EXISTS (SELECT * FROM tablename WHERE ...) - else: - query = select([exists_clause]) - # ... SELECT EXISTS (SELECT * FROM tablename WHERE ...) - + # Methods as follows. + # SQL validation: http://developer.mimer.com/validator/ + # Standard syntax: https://en.wikipedia.org/wiki/SQL_syntax + # We can make it conditional on dialect via + # session.get_bind().dialect.name + # but it would be better not to need to. + # + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # SELECT 1 FROM mytable WHERE EXISTS (SELECT * FROM mytable WHERE ...) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # - Produces multiple results (a 1 for each row). + # + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # SELECT 1 WHERE EXISTS (SELECT * FROM tablename WHERE ...) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # - Produces either 1 or NULL (no rows). 
+ # - Implementation: + # + # query = select(literal(True)).where(exists_clause) + # result = session.execute(query).scalar() + # return bool(result) # None/0 become False; 1 becomes True + # + # - However, may be non-standard: no FROM clause. + # - Works on SQL Server (empirically). + # + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # SELECT EXISTS (SELECT * FROM tablename WHERE ...) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # - Produces 0 or 1. + # - Implementation: + # + # query = select(exists_clause) + # result = session.execute(query).scalar() + # return bool(result) + # + # - But it may not be standard. + # + # - Supported by MySQL: + # - https://dev.mysql.com/doc/refman/8.4/en/exists-and-not-exists-subqueries.html # noqa: E501 + # - and an empirical test + # + # Suported by SQLite: + # - https://www.sqlite.org/lang_expr.html#the_exists_operator + # - and an empirical test + # + # Possibly not SQL Server. + # + # Possibly not Databricks. + # - https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select.html # noqa: E501 + # - https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-where.html # noqa: E501 + # + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # SELECT CASE WHEN EXISTS(SELECT * FROM tablename WHERE...) THEN 0 ELSE 1 END # noqa: E501 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # - ANSI standard. + # - https://stackoverflow.com/questions/17284688/how-to-efficiently-check-if-a-table-is-empty # noqa: E501 + # - Returns 0 or 1. + # - May be possible to use "SELECT 1 FROM tablename" also, but unclear + # what's faster, and likely EXISTS() should optimise. + # - Implementation as below. 
+ + query = select(case((exists_clause, 1), else_=0)) result = session.execute(query).scalar() - return bool(result) + return bool(result) # None/0 become False; 1 becomes True def exists_plain(session: Session, tablename: str, *criteria: Any) -> bool: @@ -192,8 +338,9 @@ def exists_plain(session: Session, tablename: str, *criteria: Any) -> bool: should be faster than COUNT(*) in some circumstances. Args: - session: SQLAlchemy :class:`Session`, :class:`Engine`, or - :class:`Connection` object + session: + SQLAlchemy :class:`Session`, :class:`Engine` (SQL Alchemy 1.4 + only), or :class:`Connection` object tablename: name of the table criteria: optional SQLAlchemy "where" criteria @@ -236,7 +383,7 @@ def fetch_all_first_values( a list of the first value of each result row """ - rows = session.execute(select_statement) # type: CursorResult + rows = session.execute(select_statement) try: return [row[0] for row in rows] except ValueError as e: diff --git a/cardinal_pythonlib/sqlalchemy/dialect.py b/cardinal_pythonlib/sqlalchemy/dialect.py index 2fde813..76f6c0a 100644 --- a/cardinal_pythonlib/sqlalchemy/dialect.py +++ b/cardinal_pythonlib/sqlalchemy/dialect.py @@ -28,9 +28,9 @@ from typing import Union -# noinspection PyProtectedMember -from sqlalchemy.engine import create_engine, Engine +from sqlalchemy.engine import Connection, create_engine, Engine from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.orm.session import Session from sqlalchemy.sql.compiler import IdentifierPreparer, SQLCompiler @@ -44,6 +44,8 @@ class SqlaDialectName(object): Dialect names used by SQLAlchemy. 
""" + # SQLAlchemy itself: + FIREBIRD = "firebird" MYSQL = "mysql" MSSQL = "mssql" @@ -53,6 +55,15 @@ class SqlaDialectName(object): SQLSERVER = MSSQL # synonym SYBASE = "sybase" + # Additional third-party dialects: + # - https://docs.sqlalchemy.org/en/20/dialects/ + # Interface: + # - https://docs.sqlalchemy.org/en/20/core/internals.html#sqlalchemy.engine.Dialect # noqa: E501 + + DATABRICKS = "databricks" + # ... https://github.com/databricks/databricks-sqlalchemy + # ... https://docs.databricks.com/en/sql/language-manual/index.html + ALL_SQLA_DIALECTS = list( set( @@ -70,13 +81,16 @@ class SqlaDialectName(object): # ============================================================================= -def get_dialect(mixed: Union[SQLCompiler, Engine, Dialect]) -> Dialect: +def get_dialect( + mixed: Union[Engine, Dialect, Session, SQLCompiler] +) -> Union[Dialect, type(Dialect)]: """ Finds the SQLAlchemy dialect in use. Args: - mixed: an SQLAlchemy :class:`SQLCompiler`, :class:`Engine`, or - :class:`Dialect` object + mixed: + An SQLAlchemy engine, bound session, SQLCompiler, or Dialect + object. Returns: the SQLAlchemy :class:`Dialect` being used @@ -85,19 +99,30 @@ def get_dialect(mixed: Union[SQLCompiler, Engine, Dialect]) -> Dialect: return mixed elif isinstance(mixed, Engine): return mixed.dialect + elif isinstance(mixed, Session): + if mixed.bind is None: + raise ValueError("get_dialect: parameter is an unbound session") + bind = mixed.bind + assert isinstance(bind, (Engine, Connection)) + return bind.dialect elif isinstance(mixed, SQLCompiler): return mixed.dialect else: - raise ValueError("get_dialect: 'mixed' parameter of wrong type") + raise ValueError( + f"get_dialect: 'mixed' parameter of wrong type: {mixed!r}" + ) -def get_dialect_name(mixed: Union[SQLCompiler, Engine, Dialect]) -> str: +def get_dialect_name( + mixed: Union[Engine, Dialect, Session, SQLCompiler] +) -> str: """ Finds the name of the SQLAlchemy dialect in use. 
Args: - mixed: an SQLAlchemy :class:`SQLCompiler`, :class:`Engine`, or - :class:`Dialect` object + mixed: + An SQLAlchemy engine, bound session, SQLCompiler, or Dialect + object. Returns: the SQLAlchemy dialect name being used """ @@ -107,15 +132,16 @@ def get_dialect_name(mixed: Union[SQLCompiler, Engine, Dialect]) -> str: def get_preparer( - mixed: Union[SQLCompiler, Engine, Dialect] + mixed: Union[Engine, Dialect, Session, SQLCompiler] ) -> IdentifierPreparer: """ Returns the SQLAlchemy :class:`IdentifierPreparer` in use for the dialect being used. Args: - mixed: an SQLAlchemy :class:`SQLCompiler`, :class:`Engine`, or - :class:`Dialect` object + mixed: + An SQLAlchemy engine, bound session, SQLCompiler, or Dialect + object. Returns: an :class:`IdentifierPreparer` @@ -126,7 +152,7 @@ def get_preparer( def quote_identifier( - identifier: str, mixed: Union[SQLCompiler, Engine, Dialect] + identifier: str, mixed: Union[Engine, Dialect, Session, SQLCompiler] ) -> str: """ Converts an SQL identifier to a quoted version, via the SQL dialect in @@ -134,14 +160,15 @@ def quote_identifier( Args: identifier: the identifier to be quoted - mixed: an SQLAlchemy :class:`SQLCompiler`, :class:`Engine`, or - :class:`Dialect` object + mixed: + An SQLAlchemy engine, bound session, SQLCompiler, or Dialect + object. 
Returns: the quoted identifier """ - # See also http://sqlalchemy-utils.readthedocs.io/en/latest/_modules/sqlalchemy_utils/functions/orm.html # noqa + # See also http://sqlalchemy-utils.readthedocs.io/en/latest/_modules/sqlalchemy_utils/functions/orm.html # noqa: E501 return get_preparer(mixed).quote(identifier) @@ -155,6 +182,9 @@ def null_executor(querysql, *multiparams, **params): pass engine = create_engine( - f"{dialect_name}://", strategy="mock", executor=null_executor + f"{dialect_name}://", + strategy="mock", + executor=null_executor, + future=True, ) return engine.dialect diff --git a/cardinal_pythonlib/sqlalchemy/dump.py b/cardinal_pythonlib/sqlalchemy/dump.py index c398095..8aec710 100644 --- a/cardinal_pythonlib/sqlalchemy/dump.py +++ b/cardinal_pythonlib/sqlalchemy/dump.py @@ -33,7 +33,6 @@ import pendulum -# noinspection PyProtectedMember from sqlalchemy.engine import Connectable, create_mock_engine from sqlalchemy.engine.base import Engine from sqlalchemy.engine.default import DefaultDialect @@ -53,11 +52,12 @@ from cardinal_pythonlib.sqlalchemy.dialect import SqlaDialectName from cardinal_pythonlib.sqlalchemy.orm_inspect import walk_orm_tree from cardinal_pythonlib.sqlalchemy.schema import get_table_names +from cardinal_pythonlib.sqlalchemy.session import get_safe_url_from_engine log = get_brace_style_log_with_null_handler(__name__) -SEP1 = sql_comment("=" * 76) -SEP2 = sql_comment("-" * 76) +COMMENT_SEP1 = sql_comment("=" * 76) +COMMENT_SEP2 = sql_comment("-" * 76) # ============================================================================= @@ -67,7 +67,7 @@ def dump_connection_info(engine: Engine, fileobj: TextIO = sys.stdout) -> None: """ - Dumps some connection info, as an SQL comment. Obscures passwords. + Dumps the engine's connection info, as an SQL comment. Obscures passwords. 
Args: engine: the SQLAlchemy :class:`Engine` to dump metadata information @@ -75,8 +75,8 @@ def dump_connection_info(engine: Engine, fileobj: TextIO = sys.stdout) -> None: fileobj: the file-like object (default ``sys.stdout``) to write information to """ - meta = MetaData(bind=engine) - writeline_nl(fileobj, sql_comment(f"Database info: {meta}")) + url = get_safe_url_from_engine(engine) + writeline_nl(fileobj, sql_comment(f"Database info: {url}")) def dump_ddl( @@ -97,8 +97,8 @@ def dump_ddl( equivalent. """ - # http://docs.sqlalchemy.org/en/rel_0_8/faq.html#how-can-i-get-the-create-table-drop-table-output-as-a-string # noqa - # https://stackoverflow.com/questions/870925/how-to-generate-a-file-with-ddl-in-the-engines-sql-dialect-in-sqlalchemy # noqa + # http://docs.sqlalchemy.org/en/rel_0_8/faq.html#how-can-i-get-the-create-table-drop-table-output-as-a-string # noqa: E501 + # https://stackoverflow.com/questions/870925/how-to-generate-a-file-with-ddl-in-the-engines-sql-dialect-in-sqlalchemy # noqa: E501 # https://github.com/plq/scripts/blob/master/pg_dump.py # noinspection PyUnusedLocal def dump(querysql, *multiparams, **params): @@ -135,17 +135,17 @@ def dump_table_as_insert_sql( include_ddl: if ``True``, include the DDL to create the table as well multirow: write multi-row ``INSERT`` statements """ - # https://stackoverflow.com/questions/5631078/sqlalchemy-print-the-actual-query # noqa + # https://stackoverflow.com/questions/5631078/sqlalchemy-print-the-actual-query # noqa: E501 # http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html - # http://www.tylerlesmann.com/2009/apr/27/copying-databases-across-platforms-sqlalchemy/ # noqa + # http://www.tylerlesmann.com/2009/apr/27/copying-databases-across-platforms-sqlalchemy/ # noqa: E501 # https://github.com/plq/scripts/blob/master/pg_dump.py log.info("dump_data_as_insert_sql: table_name={}", table_name) writelines_nl( fileobj, [ - SEP1, + COMMENT_SEP1, sql_comment(f"Data for table: {table_name}"), - SEP2, + 
COMMENT_SEP2, sql_comment(f"Filters: {wheredict}"), ], ) @@ -153,7 +153,7 @@ def dump_table_as_insert_sql( dialect = engine.dialect # type: DefaultDialect # "supports_multivalues_insert" is part of DefaultDialect, but not Dialect # -- nevertheless, it should be there: - # https://docs.sqlalchemy.org/en/20/core/internals.html#sqlalchemy.engine.default.DefaultDialect.supports_multivalues_insert # noqa + # https://docs.sqlalchemy.org/en/20/core/internals.html#sqlalchemy.engine.default.DefaultDialect.supports_multivalues_insert # noqa: E501 if not dialect.supports_multivalues_insert: multirow = False if multirow: @@ -163,59 +163,46 @@ def dump_table_as_insert_sql( ) multirow = False - # literal_query = make_literal_query_fn(dialect) - - meta = MetaData(bind=engine) + meta = MetaData() log.debug("... retrieving schema") - table = Table(table_name, meta, autoload=True) + table = Table(table_name, meta, autoload_with=engine) if include_ddl: log.debug("... producing DDL") # noinspection PyUnresolvedReferences dump_ddl( - table.metadata, dialect_name=engine.dialect.name, fileobj=fileobj + metadata=table.metadata, + dialect_name=engine.dialect.name, + fileobj=fileobj, ) - # NewRecord = quick_mapper(table) - # columns = table.columns.keys() log.debug("... 
fetching records") - # log.debug("meta: {}", meta) # obscures password - # log.debug("table: {}", table) - # log.debug("table.columns: {!r}", table.columns) - # log.debug("multirow: {}", multirow) - query = select(table.columns) + query = select(*table.columns) if wheredict: for k, v in wheredict.items(): col = table.columns.get(k) query = query.where(col == v) - # log.debug("query: {}", query) - cursor = engine.execute(query) - if multirow: - row_dict_list = [] - for r in cursor: - row_dict_list.append(dict(r)) - # log.debug("row_dict_list: {}", row_dict_list) - if row_dict_list: - statement = table.insert().values(row_dict_list) - # log.debug("statement: {!r}", statement) - # insert_str = literal_query(statement) - insert_str = get_literal_query(statement, bind=engine) - # NOT WORKING FOR MULTIROW INSERTS. ONLY SUBSTITUTES FIRST ROW. - writeline_nl(fileobj, insert_str) + with engine.begin() as connection: + cursor = connection.execute(query) + if multirow: + row_dict_list = [] + for r in cursor.mappings(): + row_dict_list.append(dict(r)) + if row_dict_list: + statement = table.insert().values(row_dict_list) + insert_str = get_literal_query(statement, bind=engine) + # NOT WORKING FOR MULTIROW INSERTS. ONLY SUBSTITUTES FIRST ROW. 
+ writeline_nl(fileobj, insert_str) + else: + writeline_nl(fileobj, sql_comment("No data!")) else: - writeline_nl(fileobj, sql_comment("No data!")) - else: - found_one = False - for r in cursor: - found_one = True - row_dict = dict(r) - statement = table.insert(values=row_dict) - # insert_str = literal_query(statement) - insert_str = get_literal_query(statement, bind=engine) - # log.debug("row_dict: {}", row_dict) - # log.debug("insert_str: {}", insert_str) - writeline_nl(fileobj, insert_str) - if not found_one: - writeline_nl(fileobj, sql_comment("No data!")) - writeline_nl(fileobj, SEP2) + found_one = False + for r in cursor.mappings(): + found_one = True + statement = table.insert().values(dict(r)) + insert_str = get_literal_query(statement, bind=engine) + writeline_nl(fileobj, insert_str) + if not found_one: + writeline_nl(fileobj, sql_comment("No data!")) + writeline_nl(fileobj, COMMENT_SEP2) log.debug("... done") @@ -260,9 +247,9 @@ def dump_orm_object_as_insert_sql( # literal_query = make_literal_query_fn(engine.dialect) insp = inspect(obj) # insp: an InstanceState - # http://docs.sqlalchemy.org/en/latest/orm/internals.html#sqlalchemy.orm.state.InstanceState # noqa + # http://docs.sqlalchemy.org/en/latest/orm/internals.html#sqlalchemy.orm.state.InstanceState # noqa: E501 # insp.mapper: a Mapper - # http://docs.sqlalchemy.org/en/latest/orm/mapping_api.html#sqlalchemy.orm.mapper.Mapper # noqa + # http://docs.sqlalchemy.org/en/latest/orm/mapping_api.html#sqlalchemy.orm.mapper.Mapper # noqa: E501 # Don't do this: # table = insp.mapper.mapped_table @@ -271,26 +258,18 @@ def dump_orm_object_as_insert_sql( # from the database itself. 
meta = MetaData(bind=engine) table_name = insp.mapper.mapped_table.name - # log.debug("table_name: {}", table_name) table = Table(table_name, meta, autoload=True) - # log.debug("table: {}", table) - # NewRecord = quick_mapper(table) - # columns = table.columns.keys() query = select(table.columns) - # log.debug("query: {}", query) for orm_pkcol in insp.mapper.primary_key: core_pkcol = table.columns.get(orm_pkcol.name) pkval = getattr(obj, orm_pkcol.name) query = query.where(core_pkcol == pkval) - # log.debug("query: {}", query) - cursor = engine.execute(query) - row = cursor.fetchone() # should only be one... + with engine.begin() as connection: + cursor = connection.execute(query) + row = cursor.fetchone() # should only be one... row_dict = dict(row) - # log.debug("obj: {}", obj) - # log.debug("row_dict: {}", row_dict) statement = table.insert(values=row_dict) - # insert_str = literal_query(statement) insert_str = get_literal_query(statement, bind=engine) writeline_nl(fileobj, insert_str) @@ -311,13 +290,13 @@ def dump_orm_tree_as_insert_sql( - MySQL/InnoDB doesn't wait to the end of a transaction to check FK integrity (which it should): - https://stackoverflow.com/questions/5014700/in-mysql-can-i-defer-referential-integrity-checks-until-commit # noqa + https://stackoverflow.com/questions/5014700/in-mysql-can-i-defer-referential-integrity-checks-until-commit # noqa: E501 - PostgreSQL can. - Anyway, slightly ugly hacks... https://dev.mysql.com/doc/refman/5.5/en/optimizing-innodb-bulk-data-loading.html - Not so obvious how we can iterate through the list of ORM objects and guarantee correct insertion order with respect to all FKs. 
- """ # noqa + """ writeline_nl( fileobj, sql_comment("Data for all objects related to the first below:"), @@ -349,7 +328,7 @@ def quick_mapper(table: Table) -> Type[DeclarativeMeta]: Returns: a :class:`DeclarativeMeta` class - """ # noqa + """ # noinspection PyPep8Naming Base = declarative_base() @@ -369,7 +348,8 @@ def bulk_insert_extras( dialect_name: str, fileobj: TextIO, start: bool ) -> None: """ - Writes bulk ``INSERT`` preamble (start=True) or end (start=False). + Writes bulk ``INSERT`` preamble (start=True) or end (start=False) to our + text file. For MySQL, this temporarily switches off autocommit behaviour and index/FK checks, for speed, then re-enables them at the end and commits. @@ -407,22 +387,32 @@ def bulk_insert_extras( class StringLiteral(String): """ - Teach SQLAlchemy how to literalize various things. Used by - `make_literal_query_fn`, below. See - https://stackoverflow.com/questions/5631078/sqlalchemy-print-the-actual-query + Teach sqlalchemy how to literalize various things. Used by + make_literal_query_fn, below. See + https://stackoverflow.com/questions/5631078. """ + def __init__(self, *args, **kwargs) -> None: + """ + This __init__() function exists purely because the docstring in the + SQLAlchemy superclass (String) has "pycon+sql" as its source type, + which Sphinx warns about. + """ + super().__init__(*args, **kwargs) + def literal_processor( self, dialect: DefaultDialect ) -> Callable[[Any], str]: - """Returns a function to translate any value to a string.""" - # Docstring above necessary to stop sphinx build error: - # undefined label: types_typedecorator - + """ + Returns a function to translate any value to a string. + """ super_processor = super().literal_processor(dialect) def process(value: Any) -> str: - log.debug("process: {!r}", value) + """ + Translate any value to a string. 
+ """ + # log.debug("process: {!r}", value) if isinstance(value, int): return str(value) if not isinstance(value, str): @@ -453,7 +443,7 @@ class LiteralDialect(DialectClass): override the encode of various kinds of data to literal values. """ - # https://stackoverflow.com/questions/5631078/sqlalchemy-print-the-actual-query # noqa + # https://stackoverflow.com/questions/5631078/sqlalchemy-print-the-actual-query # noqa: E501 colspecs = { # prevent various encoding explosions String: StringLiteral, @@ -468,7 +458,7 @@ def literal_query(statement: Union[ClauseElement, Query]) -> str: Produce an SQL query with literal values. NOTE: This is entirely insecure. DO NOT execute the resulting strings. """ - # https://stackoverflow.com/questions/5631078/sqlalchemy-print-the-actual-query # noqa + # https://stackoverflow.com/questions/5631078/sqlalchemy-print-the-actual-query # noqa: E501 if isinstance(statement, Query): statement = statement.statement return ( @@ -511,7 +501,7 @@ def get_literal_query( Returns: a string literal version of the query. - """ # noqa + """ # log.debug("statement: {!r}", statement) # log.debug("statement.bind: {!r}", statement.bind) if isinstance(statement, Query): diff --git a/cardinal_pythonlib/sqlalchemy/engine_func.py b/cardinal_pythonlib/sqlalchemy/engine_func.py index f1111cd..831fcc5 100644 --- a/cardinal_pythonlib/sqlalchemy/engine_func.py +++ b/cardinal_pythonlib/sqlalchemy/engine_func.py @@ -64,7 +64,7 @@ def is_sqlserver(engine: "Engine") -> bool: return dialect_name == SqlaDialectName.SQLSERVER -def get_sqlserver_product_version(engine: "Engine") -> Tuple[int]: +def get_sqlserver_product_version(engine: "Engine") -> Tuple[int, ...]: """ Gets SQL Server version information. 
@@ -75,7 +75,7 @@ def get_sqlserver_product_version(engine: "Engine") -> Tuple[int]: from sqlalchemy import create_engine url = "mssql+pyodbc://USER:PASSWORD@ODBC_NAME" - engine = create_engine(url) + engine = create_engine(url, future=True) dialect = engine.dialect vi = dialect.server_version_info @@ -104,13 +104,14 @@ def get_sqlserver_product_version(engine: "Engine") -> Tuple[int]: "instances." ) sql = "SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)" - rp = engine.execute(sql) # type: Result - row = rp.fetchone() + with engine.begin() as connection: + rp = connection.execute(sql) # type: Result + row = rp.fetchone() dotted_version = row[0] # type: str # e.g. '12.0.5203.0' return tuple(int(x) for x in dotted_version.split(".")) -# https://www.mssqltips.com/sqlservertip/1140/how-to-tell-what-sql-server-version-you-are-running/ # noqa +# https://www.mssqltips.com/sqlservertip/1140/how-to-tell-what-sql-server-version-you-are-running/ # noqa: E501 SQLSERVER_MAJOR_VERSION_2000 = 8 SQLSERVER_MAJOR_VERSION_2005 = 9 SQLSERVER_MAJOR_VERSION_2008 = 10 @@ -129,3 +130,16 @@ def is_sqlserver_2008_or_later(engine: "Engine") -> bool: return False version_tuple = get_sqlserver_product_version(engine) return version_tuple >= (SQLSERVER_MAJOR_VERSION_2008,) + + +# ============================================================================= +# Helper functions for Databricks +# ============================================================================= + + +def is_databricks(engine: "Engine") -> bool: + """ + Is the SQLAlchemy :class:`Engine` a Databricks database? 
+ """ + dialect_name = get_dialect_name(engine) + return dialect_name == SqlaDialectName.DATABRICKS diff --git a/cardinal_pythonlib/sqlalchemy/insert_on_duplicate.py b/cardinal_pythonlib/sqlalchemy/insert_on_duplicate.py index a17b376..6057505 100644 --- a/cardinal_pythonlib/sqlalchemy/insert_on_duplicate.py +++ b/cardinal_pythonlib/sqlalchemy/insert_on_duplicate.py @@ -24,6 +24,9 @@ **Add "INSERT ON DUPLICATE KEY UPDATE" functionality to SQLAlchemy for MySQL.** +OLD VERSION (before SQLAlchemy 1.4/future=True or SQLAlchemy 2.0): +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + - https://www.reddit.com/r/Python/comments/p5grh/sqlalchemy_whats_the_idiomatic_way_of_writing/ - https://github.com/bedwards/sqlalchemy_mysql_ext/blob/master/duplicate.py ... modified @@ -38,159 +41,119 @@ q = sqla_table.insert_on_duplicate().values(destvalues) session.execute(q) -**Note: superseded by SQLAlchemy v1.2:** +**Then: this partly superseded by SQLAlchemy v1.2:** - https://docs.sqlalchemy.org/en/latest/changelog/migration_12.html - https://docs.sqlalchemy.org/en/latest/dialects/mysql.html#mysql-insert-on-duplicate-key-update -""" # noqa - -import re -from typing import Any - -from sqlalchemy.ext.compiler import compiles -from sqlalchemy.sql.compiler import SQLCompiler -from sqlalchemy.sql.expression import Insert, TableClause - -from cardinal_pythonlib.logs import get_brace_style_log_with_null_handler -from cardinal_pythonlib.sqlalchemy.dialect import SqlaDialectName - -log = get_brace_style_log_with_null_handler(__name__) +FOR SQLAlchemy 1.4/future=True OR SQLAlchemy 2.0: +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# noinspection PyAbstractClass -class InsertOnDuplicate(Insert): - """ - Class that derives from :class:`Insert`, so we can hook in to operations - involving it. - """ +New function: insert_with_upsert_if_supported(). 
- pass +""" # noqa: E501 +# ============================================================================= +# Imports +# ============================================================================= -def insert_on_duplicate( - tablename: str, values: Any = None, inline: bool = False, **kwargs -): - """ - Command to produce an :class:`InsertOnDuplicate` object. +import logging +from typing import Dict - Args: - tablename: name of the table - values: values to ``INSERT`` - inline: as per - https://docs.sqlalchemy.org/en/latest/core/dml.html#sqlalchemy.sql.expression.insert - kwargs: additional parameters +from cardinal_pythonlib.sqlalchemy.dialect import get_dialect_name +from sqlalchemy.dialects.mysql import insert as insert_mysql +from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.schema import Table +from sqlalchemy.orm.session import Session +from sqlalchemy.sql.expression import Insert - Returns: - an :class:`InsertOnDuplicate` object +from cardinal_pythonlib.sqlalchemy.dialect import SqlaDialectName - """ # noqa - return InsertOnDuplicate(tablename, values, inline=inline, **kwargs) +log = logging.getLogger(__name__) -# noinspection PyPep8Naming -def monkeypatch_TableClause() -> None: - """ - Modifies :class:`sqlalchemy.sql.expression.TableClause` to insert - a ``insert_on_duplicate`` member that is our :func:`insert_on_duplicate` - function as above. 
- """ - log.debug( - "Adding 'INSERT ON DUPLICATE KEY UPDATE' support for MySQL " - "to SQLAlchemy" - ) - TableClause.insert_on_duplicate = insert_on_duplicate +# ============================================================================= +# insert_with_upsert_if_supported +# ============================================================================= -# noinspection PyPep8Naming -def unmonkeypatch_TableClause() -> None: +def insert_with_upsert_if_supported( + table: Table, + values: Dict, + session: Session = None, + dialect: Dialect = None, +) -> Insert: """ - Reverses the action of :func:`monkeypatch_TableClause`. - """ - del TableClause.insert_on_duplicate - - -STARTSEPS = "`" -ENDSEPS = "`" -INSERT_FIELDNAMES_REGEX = ( - r"^INSERT\sINTO\s[{startseps}]?(?P\w+)[{endseps}]?\s+" - r"\((?P[, {startseps}{endseps}\w]+)\)\s+VALUES".format( - startseps=STARTSEPS, endseps=ENDSEPS - ) -) -# http://pythex.org/ ! -RE_INSERT_FIELDNAMES = re.compile(INSERT_FIELDNAMES_REGEX) + Creates an "upsert" (INSERT ... ON DUPLICATE KEY UPDATE) statement if + possible (e.g. MySQL/MariaDB). Failing that, returns an INSERT statement. + Args: + table: + SQLAlchemy Table in which to insert values. + values: + Values to insert (column: value dictionary). + session: + Session from which to extract a dialect. + dialect: + Explicit dialect. -@compiles(InsertOnDuplicate, SqlaDialectName.MYSQL) -def compile_insert_on_duplicate_key_update( - insert: Insert, compiler: SQLCompiler, **kw -) -> str: - """ - Hooks into the use of the :class:`InsertOnDuplicate` class - for the MySQL dialect. Compiles the relevant SQL for an ``INSERT... - ON DUPLICATE KEY UPDATE`` statement. + Previously (prior to 2025-01-05 and prior to SQLAlchemy 2), we did this: - Notes: + .. code-block:: python - - We can't get the fieldnames directly from ``insert`` or ``compiler``. - - We could rewrite the innards of the visit_insert statement - (https://github.com/bedwards/sqlalchemy_mysql_ext/blob/master/duplicate.py)... 
- but, like that, it will get outdated. - - We could use a hack-in-by-hand method - (https://stackoverflow.com/questions/6611563/sqlalchemy-on-duplicate-key-update) - ... but a little automation would be nice. - - So, regex to the rescue. - - NOTE THAT COLUMNS ARE ALREADY QUOTED by this stage; no need to repeat. - """ # noqa - # log.critical(compiler.__dict__) - # log.critical(compiler.dialect.__dict__) - # log.critical(insert.__dict__) - s = compiler.visit_insert(insert, **kw) - # log.critical(s) - m = RE_INSERT_FIELDNAMES.match(s) - if m is None: - raise ValueError("compile_insert_on_duplicate_key_update: no match") - columns = [c.strip() for c in m.group("columns").split(",")] - # log.critical(columns) - updates = ", ".join([f"{c} = VALUES({c})" for c in columns]) - s += f" ON DUPLICATE KEY UPDATE {updates}" - # log.critical(s) - return s + q = sqla_table.insert_on_duplicate().values(destvalues) + This "insert_on_duplicate" member was available because + crate_anon/anonymise/config.py ran monkeypatch_TableClause(), from + cardinal_pythonlib.sqlalchemy.insert_on_duplicate. The function did dialect + detection via "@compiles(InsertOnDuplicate, SqlaDialectName.MYSQL)". But + it did nasty text-based hacking to get the column names. -_TEST_CODE = """ + However, SQLAlchemy now supports "upsert" for MySQL: + https://docs.sqlalchemy.org/en/20/dialects/mysql.html#insert-on-duplicate-key-update-upsert -from sqlalchemy import Column, String, Integer, create_engine -from sqlalchemy.orm import Session -from sqlalchemy.ext.declarative import declarative_base + Note the varying argument forms possible. -Base = declarative_base() + The only other question: if the dialect is not MySQL, will the reference to + insert_stmt.on_duplicate_key_update crash or just not do anything? To test: + .. 
code-block:: python -class OrmObject(Base): - __tablename__ = "sometable" - id = Column(Integer, primary_key=True) - name = Column(String) + from sqlalchemy import table + t = table("tablename") + destvalues = {"a": 1} + insert_stmt = t.insert().values(destvalues) + on_dup_key_stmt = insert_stmt.on_duplicate_key_update(destvalues) -engine = create_engine("sqlite://", echo=True) -Base.metadata.create_all(engine) + This does indeed crash (AttributeError: 'Insert' object has no attribute + 'on_duplicate_key_update'). In contrast, this works: -session = Session(engine) + .. code-block:: python -d1 = dict(id=1, name="One") -d2 = dict(id=2, name="Two") + from sqlalchemy.dialects.mysql import insert as insert_mysql -insert_1 = OrmObject.__table__.insert(values=d1) -insert_2 = OrmObject.__table__.insert(values=d2) -session.execute(insert_1) -session.execute(insert_2) -session.execute(insert_1) # raises sqlalchemy.exc.IntegrityError + insert2 = insert_mysql(t).values(destvalues) + on_dup_key2 = insert2.on_duplicate_key_update(destvalues) + Note also that an insert() statement doesn't gain a + "on_duplicate_key_update" attribute just because MySQL is used (the insert + statement doesn't know that yet). -# ... recommended cross-platform way is SELECT then INSERT or UPDATE -# accordingly; see -# https://groups.google.com/forum/#!topic/sqlalchemy/aQLqeHmLPQY + The old way was good for dialect detection but ugly for textual analysis of + the query. The new way is more elegant in the query, but less for dialect + detection. Overall, new way likely preferable. 
-""" + """ + if bool(session) + bool(dialect) != 1: + raise ValueError( + f"Must specify exactly one of: {session=}, {dialect=}" + ) + dialect_name = get_dialect_name(dialect or session) + if dialect_name == SqlaDialectName.MYSQL: + return ( + insert_mysql(table).values(values).on_duplicate_key_update(values) + ) + else: + return table.insert().values(values) diff --git a/cardinal_pythonlib/sqlalchemy/list_types.py b/cardinal_pythonlib/sqlalchemy/list_types.py index df92ebf..2ff9890 100644 --- a/cardinal_pythonlib/sqlalchemy/list_types.py +++ b/cardinal_pythonlib/sqlalchemy/list_types.py @@ -178,7 +178,7 @@ def add_to_mylist(self, text: str) -> None: # noinspection PyAugmentAssignment self.mylist = self.mylist + [text] # not "append()", not "+=" - """ # noqa + """ # noqa: E501 impl = UnicodeText() @property diff --git a/cardinal_pythonlib/sqlalchemy/merge_db.py b/cardinal_pythonlib/sqlalchemy/merge_db.py index 31138af..4c5ef52 100644 --- a/cardinal_pythonlib/sqlalchemy/merge_db.py +++ b/cardinal_pythonlib/sqlalchemy/merge_db.py @@ -226,7 +226,7 @@ def get_all_dependencies( """ extra_dependencies = ( extra_dependencies or [] - ) # type: List[TableDependency] # noqa + ) # type: List[TableDependency] for td in extra_dependencies: td.set_metadata_if_none(metadata) dependencies = set([td.sqla_tuple() for td in extra_dependencies]) @@ -236,7 +236,7 @@ def get_all_dependencies( for table in tables: for fkc in table.foreign_key_constraints: if fkc.use_alter is True: - # http://docs.sqlalchemy.org/en/latest/core/constraints.html#sqlalchemy.schema.ForeignKeyConstraint.params.use_alter # noqa + # http://docs.sqlalchemy.org/en/latest/core/constraints.html#sqlalchemy.schema.ForeignKeyConstraint.params.use_alter # noqa: E501 continue dependent_on = fkc.referred_table @@ -698,16 +698,16 @@ def my_translate_fn(trcon: TranslationContext) -> None: only_tables = only_tables or [] # type: List[TableIdentity] tables_to_keep_pks_for = ( tables_to_keep_pks_for or [] - ) # type: 
List[TableIdentity] # noqa + ) # type: List[TableIdentity] extra_table_dependencies = ( extra_table_dependencies or [] - ) # type: List[TableDependency] # noqa + ) # type: List[TableDependency] trcon_info = trcon_info or {} # type: Dict[str, Any] # We need both Core and ORM for the source. # noinspection PyUnresolvedReferences metadata = base_class.metadata # type: MetaData - src_session = sessionmaker(bind=src_engine)() # type: Session + src_session = sessionmaker(bind=src_engine, future=True)() # type: Session dst_engine = get_engine_from_session(dst_session) tablename_to_ormclass = get_orm_classes_by_table_name_from_base(base_class) @@ -723,7 +723,7 @@ def my_translate_fn(trcon: TranslationContext) -> None: only_table_names = [ti.tablename for ti in only_tables] tables_to_keep_pks_for = [ ti.tablename for ti in tables_to_keep_pks_for - ] # type: List[str] # noqa + ] # type: List[str] # ... now all are of type List[str] # Safety check: this is an imperfect check for source == destination, but @@ -899,7 +899,7 @@ def wipe_primary_key(inst: object) -> None: # This doesn't work: # - process tables in order of dependencies, eager-loading # relationships with - # for relationship in insp.mapper.relationships: # type: RelationshipProperty # noqa + # for relationship in insp.mapper.relationships: # type: RelationshipProperty # noqa: E501 # related_col = getattr(orm_class, relationship.key) # query = query.options(joinedload(related_col)) # - expunge from old session / make_transient / wipe_primary_key/ add @@ -937,7 +937,7 @@ def wipe_primary_key(inst: object) -> None: # simply move the instance from one session to the other, # blanking primary keys. 
- # https://stackoverflow.com/questions/14636192/sqlalchemy-modification-of-detached-object # noqa + # https://stackoverflow.com/questions/14636192/sqlalchemy-modification-of-detached-object # noqa: E501 src_session.expunge(instance) make_transient(instance) if wipe_pk: diff --git a/cardinal_pythonlib/sqlalchemy/orm_inspect.py b/cardinal_pythonlib/sqlalchemy/orm_inspect.py index 8a703e0..5fff4dc 100644 --- a/cardinal_pythonlib/sqlalchemy/orm_inspect.py +++ b/cardinal_pythonlib/sqlalchemy/orm_inspect.py @@ -45,7 +45,7 @@ from sqlalchemy.orm.session import Session from sqlalchemy.sql.schema import Column, MetaData from sqlalchemy.sql.type_api import TypeEngine -from sqlalchemy.sql.visitors import VisitableType +from sqlalchemy.sql.visitors import Visitable from sqlalchemy.util import OrderedProperties from cardinal_pythonlib.classes import gen_all_subclasses @@ -60,6 +60,13 @@ log = get_brace_style_log_with_null_handler(__name__) +# ============================================================================= +# Constants +# ============================================================================= + +VisitableType = Type[Visitable] # for SQLAlchemy 2.0 + + # ============================================================================= # Creating ORM objects conveniently, etc. # ============================================================================= @@ -78,10 +85,15 @@ class type. This function ensures that such classes are converted to .. code-block:: python + from sqlalchemy.sql.schema import Column + from sqlalchemy.sql.sqltypes import Integer, String, TypeEngine + a = Column("a", Integer) b = Column("b", Integer()) c = Column("c", String(length=50)) + # In SQLAlchemy to 1.4: + isinstance(Integer, TypeEngine) # False isinstance(Integer(), TypeEngine) # True isinstance(String(length=50), TypeEngine) # True @@ -91,11 +103,30 @@ class type. 
This function ensures that such classes are converted to type(String) # type(String(length=50)) # + # In SQLAlchemy 2.0, VisitableType has gone. Though there is Visitable. + # (So we can also recreate VisitableType, as above, although only for + # type hints, not e.g. isinstance.) + + from sqlalchemy.sql.visitors import Visitable + + isinstance(Integer, TypeEngine) # False + isinstance(Integer(), TypeEngine) # True + isinstance(String(length=50), TypeEngine) # True + + type(Integer) # + issubclass(Integer, Visitable) # True + type(Integer()) # + type(String) # + issubclass(String, Visitable) # True + type(String(length=50)) # + This function coerces things to a :class:`TypeEngine`. """ if isinstance(coltype, TypeEngine): return coltype - return coltype() + instance = coltype() + assert isinstance(instance, TypeEngine) + return instance # ============================================================================= @@ -200,16 +231,16 @@ def walk_orm_tree( # http://docs.sqlalchemy.org/en/latest/faq/sessions.html#faq-walk-objects skip_relationships_always = ( skip_relationships_always or [] - ) # type: List[str] # noqa + ) # type: List[str] skip_relationships_by_tablename = ( skip_relationships_by_tablename or {} - ) # type: Dict[str, List[str]] # noqa + ) # type: Dict[str, List[str]] skip_all_relationships_for_tablenames = ( skip_all_relationships_for_tablenames or [] - ) # type: List[str] # noqa + ) # type: List[str] skip_all_objects_for_tablenames = ( skip_all_objects_for_tablenames or [] - ) # type: List[str] # noqa + ) # type: List[str] stack = [obj] if seen is None: seen = set() @@ -227,7 +258,7 @@ def walk_orm_tree( insp = inspect(obj) # type: InstanceState for ( relationship - ) in insp.mapper.relationships: # type: RelationshipProperty # noqa + ) in insp.mapper.relationships: # type: RelationshipProperty attrname = relationship.key # Skip? if attrname in skip_relationships_always: @@ -260,7 +291,7 @@ def walk_orm_tree( # dependent upon X, i.e. 
traverse relationships. # # https://groups.google.com/forum/#!topic/sqlalchemy/wb2M_oYkQdY -# https://groups.google.com/forum/#!searchin/sqlalchemy/cascade%7Csort:date/sqlalchemy/eIOkkXwJ-Ms/JLnpI2wJAAAJ # noqa +# https://groups.google.com/forum/#!searchin/sqlalchemy/cascade%7Csort:date/sqlalchemy/eIOkkXwJ-Ms/JLnpI2wJAAAJ # noqa: E501 def copy_sqla_object( @@ -385,13 +416,13 @@ def rewrite_relationships( attrname_rel ) in ( insp.mapper.relationships.items() - ): # type: Tuple[str, RelationshipProperty] # noqa + ): # type: Tuple[str, RelationshipProperty] attrname = attrname_rel[0] rel_prop = attrname_rel[1] if rel_prop.viewonly: if debug: log.debug("Skipping viewonly relationship") - continue # don't attempt to write viewonly relationships # noqa + continue # don't attempt to write viewonly relationships related_class = rel_prop.mapper.class_ related_table_name = related_class.__tablename__ # type: str if related_table_name in skip_table_names: @@ -627,7 +658,7 @@ def gen_columns_for_uninstrumented_class( SAWarning: Unmanaged access of declarative attribute id from non-mapped class GenericTabletRecordMixin Try to use :func:`gen_columns` instead. - """ # noqa + """ # noqa: E501 for attrname in dir(cls): potential_column = getattr(cls, attrname) if isinstance(potential_column, Column): @@ -677,12 +708,9 @@ def gen_relationships( rel_prop, ) in ( insp.mapper.relationships.items() - ): # type: Tuple[str, RelationshipProperty] # noqa + ): # type: Tuple[str, RelationshipProperty] # noinspection PyUnresolvedReferences related_class = rel_prop.mapper.class_ - # log.critical("gen_relationships: attrname={!r}, " - # "rel_prop={!r}, related_class={!r}, rel_prop.info={!r}", - # attrname, rel_prop, related_class, rel_prop.info) yield attrname, rel_prop, related_class @@ -698,9 +726,9 @@ def get_orm_columns(cls: Type) -> List[Column]: """ mapper = inspect(cls) # type: Mapper # ... 
returns InstanceState if called with an ORM object - # http://docs.sqlalchemy.org/en/latest/orm/session_state_management.html#session-object-states # noqa + # http://docs.sqlalchemy.org/en/latest/orm/session_state_management.html#session-object-states # noqa: E501 # ... returns Mapper if called with an ORM class - # http://docs.sqlalchemy.org/en/latest/orm/mapping_api.html#sqlalchemy.orm.mapper.Mapper # noqa + # http://docs.sqlalchemy.org/en/latest/orm/mapping_api.html#sqlalchemy.orm.mapper.Mapper # noqa: E501 colmap = mapper.columns # type: OrderedProperties return colmap.values() diff --git a/cardinal_pythonlib/sqlalchemy/orm_query.py b/cardinal_pythonlib/sqlalchemy/orm_query.py index 2a3e9e5..c104221 100644 --- a/cardinal_pythonlib/sqlalchemy/orm_query.py +++ b/cardinal_pythonlib/sqlalchemy/orm_query.py @@ -26,14 +26,13 @@ """ -from typing import Any, Dict, Sequence, Tuple, Union +from typing import Any, Dict, List, Tuple, Type, Union from sqlalchemy.engine.base import Connection, Engine -from sqlalchemy.engine import CursorResult from sqlalchemy.orm import DeclarativeMeta from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session -from sqlalchemy.sql.expression import ClauseElement, literal +from sqlalchemy.sql.expression import ClauseElement, literal, select from sqlalchemy.sql import func from sqlalchemy.sql.selectable import Exists @@ -48,31 +47,76 @@ # ============================================================================= +# noinspection PyUnusedLocal def get_rows_fieldnames_from_query( session: Union[Session, Engine, Connection], query: Query -) -> Tuple[Sequence[Sequence[Any]], Sequence[str]]: +) -> Tuple[List[Tuple[Any, ...]], List[str]]: """ - Returns results and column names from a query. 
- - Args: - session: SQLAlchemy :class:`Session`, :class:`Engine`, or - :class:`Connection` object - query: SQLAlchemy :class:`Query` - - Returns: - ``(rows, fieldnames)`` where ``rows`` is the usual set of results and - ``fieldnames`` are the name of the result columns/fields. - + Superseded. It used to be fine to use a Query object to run a SELECT + statement. But as of SQLAlchemy 2.0 (or 1.4 with future=True), this has + been removed. + + Also, it isn't worth coercing here. Some details are in the source code, + but usually we are not seeking to run a query that fetches ORM objects + themselves and then map those to fieldnames/values. Instead, we used to use + a Query object made from selectable elements like columns and COUNT() + clauses. That is what the select() system is meant for. So this code will + now raise an error. """ - # https://stackoverflow.com/questions/6455560/how-to-get-column-names-from-sqlalchemy-result-declarative-syntax # noqa - # No! Returns e.g. "User" for session.Query(User)... - # fieldnames = [cd['name'] for cd in query.column_descriptions] - result = session.execute(query) # type: CursorResult - fieldnames = result.keys() - # ... yes! Comes out as "_table_field", which is how SQLAlchemy SELECTs - # things. - rows = result.fetchall() - return rows, fieldnames + raise NotImplementedError( + "From SQLAlchemy 2.0, don't perform queries directly with a " + "sqlalchemy.orm.query.Query object; use a " + "sqlalchemy.sql.selectable.Select object, e.g. from select(). Use " + "cardinal_pythonlib.sqlalchemy.core_query." + "get_rows_fieldnames_from_select() instead." + ) + + # - Old and newer advice: + # https://stackoverflow.com/questions/6455560/how-to-get-column-names-from-sqlalchemy-result-declarative-syntax # noqa: E501 + # + # 1. query.column_descriptions + # fieldnames = [cd['name'] for cd in query.column_descriptions] + # No. Returns e.g. "User" for session.Query(User), i.e. ORM class names. + # 2. 
Formerly (prior to SQLAlchemy 1.4+/future=True), result.keys() worked. + # It came out as "_table_field", which is how SQLAlchemy SELECTs things. + # 3. But now, use query.statement.columns.keys(). + # Or possibly query.statement.subquery().columns.keys(). + # + # In SQLAlchemy 2, the result of session.execute(query) is typically a + # sqlalchemy.engine.result.ChunkedIteratorResult. Then, "result.mappings()" + # gives a sqlalchemy.engine.result.MappingResult. See + # https://docs.sqlalchemy.org/en/20/core/connections.html#sqlalchemy.engine.Result # noqa: E501 + # In the context of the Core/declarative methods, result.mappings() is + # useful and gives a dictionary. But when you do it here, you get a + # dictionary of {classname: classinstance}, which is less helpful. + + # FIELDNAMES ARE ACHIEVABLE LIKE THIS: + # + # fieldnames = query.statement.subquery().columns.keys() + # + # Without "subquery()": + # SADeprecationWarning: The SelectBase.c and SelectBase.columns attributes + # are deprecated and will be removed in a future release; these attributes + # implicitly create a subquery that should be explicit. Please call + # SelectBase.subquery() first in order to create a subquery, which then + # contains this attribute. To access the columns that this SELECT object + # SELECTs from, use the SelectBase.selected_columns attribute. (deprecated + # since: 1.4) + + # VALUES ARE ACHIEVABLE ALONG THESE LINES [although session.execute(query) + # is no longer legitimate] BUT IT IS A BIT SILLY. 
+ # + # https://docs.sqlalchemy.org/en/14/errors.html#error-89ve + # + # result = session.execute(query) + # rows_as_object_tuples = result.fetchall() + # orm_objects = tuple(row[0] for row in rows_as_object_tuples) + # rows = [ + # tuple(getattr(obj, k) for k in fieldnames) + # for obj in orm_objects + # ] + + # return rows, fieldnames # ============================================================================= @@ -99,6 +143,7 @@ def bool_from_exists_clause(session: Session, exists_clause: Exists) -> bool: SELECT 1 WHERE EXISTS (SELECT 1 FROM table WHERE ...) -- ... giving 1 or None (no rows) -- ... fine for SQL Server, but invalid for MySQL (no FROM clause) + -- ... also fine for SQLite, giving 1 or None (no rows) *Others, including MySQL* @@ -107,8 +152,9 @@ def bool_from_exists_clause(session: Session, exists_clause: Exists) -> bool: SELECT EXISTS (SELECT 1 FROM table WHERE ...) -- ... giving 1 or 0 -- ... fine for MySQL, but invalid syntax for SQL Server + -- ... also fine for SQLite, giving 1 or 0 - """ # noqa + """ # noqa: E501 if session.get_bind().dialect.name == SqlaDialectName.MSSQL: # SQL Server result = session.query(literal(True)).filter(exists_clause).scalar() @@ -119,7 +165,7 @@ def bool_from_exists_clause(session: Session, exists_clause: Exists) -> bool: def exists_orm( - session: Session, ormclass: DeclarativeMeta, *criteria: Any + session: Session, ormclass: Type[DeclarativeMeta], *criteria: Any ) -> bool: """ Detects whether a database record exists for the specified ``ormclass`` @@ -148,7 +194,7 @@ def exists_orm( def get_or_create( session: Session, - model: DeclarativeMeta, + model: Type[DeclarativeMeta], defaults: Dict[str, Any] = None, **kwargs: Any ) -> Tuple[Any, bool]: @@ -187,12 +233,17 @@ def get_or_create( # Extend Query to provide an optimized COUNT(*) # ============================================================================= + # noinspection PyAbstractClass -class CountStarSpecializedQuery(Query): - def __init__(self, *args, 
**kwargs) -> None: +class CountStarSpecializedQuery: + def __init__(self, model: Type[DeclarativeMeta], session: Session) -> None: """ Optimizes ``COUNT(*)`` queries. + Given an ORM class, and a session, creates a query that counts + instances of that ORM class. (You can filter later using the filter() + command, which chains as usual.) + See https://stackoverflow.com/questions/12941416/how-to-count-rows-with-select-count-with-sqlalchemy @@ -200,18 +251,32 @@ def __init__(self, *args, **kwargs) -> None: .. code-block:: python - q = CountStarSpecializedQuery([cls], session=dbsession)\ + q = CountStarSpecializedQuery(cls, session=dbsession)\ .filter(cls.username == username) return q.count_star() - """ # noqa - super().__init__(*args, **kwargs) + Note that in SQLAlchemy <1.4, Query(ormclass) implicitly added "from + the table of that ORM class". But SQLAlchemy 2.0 doesn't. That means + that Query(ormclass) leads ultimately to "SELECT COUNT(*)" by itself; + somewhat surprisingly to me, that gives 1 rather than an error, at + least in SQLite. So now we inherit from Select, not Query. + + """ + # https://docs.sqlalchemy.org/en/20/core/selectable.html#sqlalchemy.sql.expression.select # noqa: E501 + # ... accepts "series of ColumnElement and / or FromClause objects" + # But passing the table to select() just means you select too many + # columns. So let's do this by embedding, not inheriting from, a + # select()-type object (Select). + self.select_query = select(func.count()).select_from(model.__table__) + self.session = session + + def filter(self, *args, **kwargs) -> "CountStarSpecializedQuery": + self.select_query = self.select_query.filter(*args, **kwargs) + return self def count_star(self) -> int: """ Implements the ``COUNT(*)`` specialization. 
""" - count_query = self.statement.with_only_columns( - [func.count()] - ).order_by(None) + count_query = self.select_query.order_by(None) return self.session.execute(count_query).scalar() diff --git a/cardinal_pythonlib/sqlalchemy/orm_schema.py b/cardinal_pythonlib/sqlalchemy/orm_schema.py index 2a88b78..75a6cff 100644 --- a/cardinal_pythonlib/sqlalchemy/orm_schema.py +++ b/cardinal_pythonlib/sqlalchemy/orm_schema.py @@ -26,18 +26,20 @@ """ -from typing import TYPE_CHECKING +import logging +from typing import Type, TYPE_CHECKING -from cardinal_pythonlib.logs import get_brace_style_log_with_null_handler -from cardinal_pythonlib.sqlalchemy.session import get_safe_url_from_engine from sqlalchemy.engine.base import Engine from sqlalchemy.orm import DeclarativeMeta from sqlalchemy.schema import CreateTable +from cardinal_pythonlib.sqlalchemy.schema import execute_ddl +from cardinal_pythonlib.sqlalchemy.session import get_safe_url_from_engine + if TYPE_CHECKING: from sqlalchemy.sql.schema import Table -log = get_brace_style_log_with_null_handler(__name__) +log = logging.getLogger(__name__) # ============================================================================= @@ -47,7 +49,7 @@ def create_table_from_orm_class( engine: Engine, - ormclass: DeclarativeMeta, + ormclass: Type[DeclarativeMeta], without_constraints: bool = False, ) -> None: """ @@ -61,12 +63,11 @@ def create_table_from_orm_class( """ table = ormclass.__table__ # type: Table log.info( - "Creating table {} on engine {}{}", - table.name, - get_safe_url_from_engine(engine), - " (omitting constraints)" if without_constraints else "", + f"Creating table {table.name} " + f"on engine {get_safe_url_from_engine(engine)}" + f"{' (omitting constraints)' if without_constraints else ''}" ) - # https://stackoverflow.com/questions/19175311/how-to-create-only-one-table-with-sqlalchemy # noqa + # https://stackoverflow.com/questions/19175311/how-to-create-only-one-table-with-sqlalchemy # noqa: E501 if 
without_constraints: include_foreign_key_constraints = [] else: @@ -74,4 +75,4 @@ def create_table_from_orm_class( creator = CreateTable( table, include_foreign_key_constraints=include_foreign_key_constraints ) - creator.execute(bind=engine) + execute_ddl(engine, ddl=creator) diff --git a/cardinal_pythonlib/sqlalchemy/schema.py b/cardinal_pythonlib/sqlalchemy/schema.py index a36fe48..77b9fb1 100644 --- a/cardinal_pythonlib/sqlalchemy/schema.py +++ b/cardinal_pythonlib/sqlalchemy/schema.py @@ -25,23 +25,33 @@ **Functions to work with SQLAlchemy schemas (schemata) directly, via SQLAlchemy Core.** +Functions that have to work with specific dialect information are marked +DIALECT-AWARE. + """ import ast -import contextlib import copy import csv from functools import lru_cache import io import re -from typing import Any, Dict, Generator, List, Optional, Type, Union +from typing import ( + Any, + Dict, + Generator, + List, + Optional, + Type, + Union, + TYPE_CHECKING, +) from sqlalchemy import inspect -from sqlalchemy.dialects import mssql, mysql -# noinspection PyProtectedMember -from sqlalchemy.engine import Connection, Engine, CursorResult +from sqlalchemy.engine import Connection, Engine from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.dialects import postgresql, mssql, mysql, sqlite from sqlalchemy.dialects.mssql.base import TIMESTAMP as MSSQL_TIMESTAMP from sqlalchemy.schema import ( Column, @@ -52,14 +62,31 @@ Table, ) from sqlalchemy.sql import sqltypes, text -from sqlalchemy.sql.sqltypes import BigInteger, TypeEngine -from sqlalchemy.sql.visitors import VisitableType +from sqlalchemy.sql.ddl import DDLElement +from sqlalchemy.sql.sqltypes import ( + BigInteger, + Boolean, + Date, + DateTime, + Double, + Float, + Integer, + Numeric, + SmallInteger, + Text, + TypeEngine, +) +from sqlalchemy.sql.visitors import Visitable from cardinal_pythonlib.logs import get_brace_style_log_with_null_handler from cardinal_pythonlib.sqlalchemy.dialect import ( 
quote_identifier, SqlaDialectName, ) +from cardinal_pythonlib.sqlalchemy.orm_inspect import coltype_as_typeengine + +if TYPE_CHECKING: + from sqlalchemy.engine.interfaces import ReflectedIndex log = get_brace_style_log_with_null_handler(__name__) @@ -68,10 +95,29 @@ # Constants # ============================================================================= +VisitableType = Type[Visitable] # for SQLAlchemy 2.0 + MIN_TEXT_LENGTH_FOR_FREETEXT_INDEX = 1000 MSSQL_DEFAULT_SCHEMA = "dbo" POSTGRES_DEFAULT_SCHEMA = "public" +DATABRICKS_SQLCOLTYPE_TO_SQLALCHEMY_GENERIC = { + # A bit nasty: https://github.com/databricks/databricks-sqlalchemy + # Part of the reverse mapping is via + # from databricks.sqlalchemy import DatabricksDialect + # print(DatabricksDialect.colspecs) + "BIGINT": BigInteger, + "BOOLEAN": Boolean, + "DATE": Date, + "TIMESTAMP_NTZ": DateTime, + "DOUBLE": Double, + "FLOAT": Float, + "INT": Integer, + "DECIMAL": Numeric, + "SMALLINT": SmallInteger, + "STRING": Text, +} + # ============================================================================= # Inspect tables (SQLAlchemy Core) @@ -133,14 +179,14 @@ def __init__(self, sqla_info_dict: Dict[str, Any]) -> None: - https://docs.sqlalchemy.org/en/latest/core/reflection.html#sqlalchemy.engine.reflection.Inspector.get_columns - https://bitbucket.org/zzzeek/sqlalchemy/issues/4051/sqlalchemyenginereflectioninspectorget_col - """ # noqa + """ # noqa: E501 # log.debug(repr(sqla_info_dict)) self.name = sqla_info_dict["name"] # type: str self.type = sqla_info_dict["type"] # type: TypeEngine self.nullable = sqla_info_dict["nullable"] # type: bool self.default = sqla_info_dict[ "default" - ] # type: str # SQL string expression + ] # type: Optional[str] # SQL string expression self.attrs = sqla_info_dict.get("attrs", {}) # type: Dict[str, Any] self.comment = sqla_info_dict.get("comment", "") # ... NB not appearing in @@ -154,7 +200,7 @@ def gen_columns_info( :class:`SqlaColumnInspectionInfo` objects. 
""" # Dictionary structure: see - # http://docs.sqlalchemy.org/en/latest/core/reflection.html#sqlalchemy.engine.reflection.Inspector.get_columns # noqa + # http://docs.sqlalchemy.org/en/latest/core/reflection.html#sqlalchemy.engine.reflection.Inspector.get_columns # noqa: E501 insp = inspect(engine) for d in insp.get_columns(tablename): yield SqlaColumnInspectionInfo(d) @@ -236,6 +282,35 @@ def get_single_int_pk_colname(table_: Table) -> Optional[str]: return None +def is_int_autoincrement_column(c: Column, t: Table) -> bool: + """ + Is this an integer AUTOINCREMENT column? Used by + get_single_int_autoincrement_colname(); q.v. + """ + # https://docs.sqlalchemy.org/en/20/core/metadata.html#sqlalchemy.schema.Column.params.autoincrement # noqa: E501 + # "The setting only has an effect for columns which are: + # - Integer derived (i.e. INT, SMALLINT, BIGINT). + # - Part of the primary key + # - Not referring to another column via ForeignKey, unless the value is + # specified as 'ignore_fk':" + if not c.primary_key or not is_sqlatype_integer(c.type): + return False + a = c.autoincrement + if isinstance(a, bool): + # Specified as True or False. + return a + if a == "auto": + # "indicates that a single-column (i.e. non-composite) primary key that + # is of an INTEGER type with no other client-side or server-side + # default constructs indicated should receive auto increment semantics + # automatically." Therefore: + n_pk = sum(x.primary_key for x in t.columns) + return n_pk == 1 and c.default is None + if c.foreign_keys: + return a == "ignore_fk" + return False + + def get_single_int_autoincrement_colname(table_: Table) -> Optional[str]: """ If a table has a single integer ``AUTOINCREMENT`` column, this will @@ -267,19 +342,19 @@ def get_single_int_autoincrement_colname(table_: Table) -> Optional[str]: ... which is what SQLAlchemy does (``dialects/mssql/base.py``, in :func:`get_columns`). 
""" - n_autoinc = 0 - int_autoinc_names = [] + int_autoinc_names = [] # type: List[str] for col in table_.columns: - if col.autoincrement: - n_autoinc += 1 - if is_sqlatype_integer(col.type): - int_autoinc_names.append(col.name) + if is_int_autoincrement_column(col, table_): + int_autoinc_names.append(col.name) + n_autoinc = len(int_autoinc_names) + if n_autoinc == 1: + return int_autoinc_names[0] if n_autoinc > 1: log.warning( - "Table {!r} has {} autoincrement columns", table_.name, n_autoinc + "Table {!r} has {} integer autoincrement columns", + table_.name, + n_autoinc, ) - if n_autoinc == 1 and len(int_autoinc_names) == 1: - return int_autoinc_names[0] return None @@ -295,17 +370,71 @@ def get_effective_int_pk_col(table_: Table) -> Optional[str]: ) +# ============================================================================= +# Execute DDL +# ============================================================================= + + +def execute_ddl( + engine: Engine, sql: str = None, ddl: DDLElement = None +) -> None: + """ + Execute DDL, either from a plain SQL string, or from an SQLAlchemy DDL + element. + + Previously we would use DDL(sql, bind=engine).execute(), but this has gone + in SQLAlchemy 2.0. + + If you want dialect-conditional execution, create the DDL object with e.g. + ddl = DDL(sql).execute_if(dialect=SqlaDialectName.SQLSERVER), and pass that + DDL object to this function. + """ + assert bool(sql) ^ (ddl is not None) # one or the other. + if sql: + ddl = DDL(sql) + with engine.connect() as connection: + # DDL doesn't need a COMMIT. 
+ connection.execute(ddl) + + # ============================================================================= # Indexes # ============================================================================= -def index_exists(engine: Engine, tablename: str, indexname: str) -> bool: +def index_exists( + engine: Engine, + tablename: str, + indexname: str = None, + colnames: Union[str, List[str]] = None, + raise_if_nonexistent_table: bool = True, +) -> bool: """ Does the specified index exist for the specified table? + + You can specify either the name of the index, or the name(s) of columns. + But not both. + + If the table doesn't exist, then if raise_if_nonexistent_table is True, + raise sqlalchemy.exc.NoSuchTableError; otherwise, warn and return False. """ + assert bool(indexname) ^ bool(colnames) # one or the other insp = inspect(engine) - return any(i["name"] == indexname for i in insp.get_indexes(tablename)) + if not raise_if_nonexistent_table and not insp.has_table(tablename): + log.warning(f"index_exists(): no such table {tablename!r}") + return False + indexes = insp.get_indexes(tablename) # type: List[ReflectedIndex] + if indexname: + # Look up by index name. + return any(i["name"] == indexname for i in indexes) + else: + # Look up by column names. All must be present in a given index. + if isinstance(colnames, str): + colnames = [colnames] + return any( + all(colname in i["column_names"] for colname in colnames) + for i in indexes + ) def mssql_get_pk_index_name( @@ -316,10 +445,10 @@ def mssql_get_pk_index_name( for the specified table (in the specified schema), or ``''`` if none is found. 
""" - # http://docs.sqlalchemy.org/en/latest/core/connections.html#sqlalchemy.engine.Connection.execute # noqa - # http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text # noqa - # http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.TextClause.bindparams # noqa - # http://docs.sqlalchemy.org/en/latest/core/connections.html#sqlalchemy.engine.CursorResult # noqa + # http://docs.sqlalchemy.org/en/latest/core/connections.html#sqlalchemy.engine.Connection.execute # noqa: E501 + # http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text # noqa: E501 + # http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.TextClause.bindparams # noqa: E501 + # http://docs.sqlalchemy.org/en/latest/core/connections.html#sqlalchemy.engine.CursorResult # noqa: E501 query = text( """ SELECT @@ -332,10 +461,10 @@ def mssql_get_pk_index_name( kc.[type] = 'PK' AND ta.name = :tablename AND s.name = :schemaname - """ + """ ).bindparams(tablename=tablename, schemaname=schemaname) - with contextlib.closing(engine.execute(query)) as result: - result: CursorResult + with engine.begin() as connection: + result = connection.execute(query) row = result.fetchone() return row[0] if row else "" @@ -359,11 +488,10 @@ def mssql_table_has_ft_index( WHERE ta.name = :tablename AND s.name = :schemaname - """ + """ ).bindparams(tablename=tablename, schemaname=schemaname) - with contextlib.closing( - engine.execute(query) - ) as result: # type: CursorResult + with engine.begin() as connection: + result = connection.execute(query) row = result.fetchone() return row[0] > 0 @@ -375,12 +503,17 @@ def mssql_transaction_count(engine_or_conn: Union[Connection, Engine]) -> int: https://docs.microsoft.com/en-us/sql/t-sql/functions/trancount-transact-sql?view=sql-server-2017). Returns ``None`` if it can't be found (unlikely?). 
""" - sql = "SELECT @@TRANCOUNT" - with contextlib.closing( - engine_or_conn.execute(sql) - ) as result: # type: CursorResult + query = text("SELECT @@TRANCOUNT") + if isinstance(engine_or_conn, Connection): + result = engine_or_conn.execute(query) row = result.fetchone() - return row[0] if row else None + elif isinstance(engine_or_conn, Engine): + with engine_or_conn.begin() as connection: + result = connection.execute(query) + row = result.fetchone() + else: + raise ValueError(f"Unexpected {engine_or_conn=}") + return row[0] if row else None def add_index( @@ -397,6 +530,8 @@ def add_index( The table name is worked out from the :class:`Column` object. + DIALECT-AWARE. + Args: engine: SQLAlchemy :class:`Engine` object sqla_column: single column to index @@ -416,7 +551,7 @@ def add_index( """ # We used to process a table as a unit; this makes index creation faster # (using ALTER TABLE). - # http://dev.mysql.com/doc/innodb/1.1/en/innodb-create-index-examples.html # noqa + # http://dev.mysql.com/doc/innodb/1.1/en/innodb-create-index-examples.html # noqa: E501 # ... ignored in transition to SQLAlchemy def quote(identifier: str) -> str: @@ -424,6 +559,7 @@ def quote(identifier: str) -> str: is_mssql = engine.dialect.name == SqlaDialectName.MSSQL is_mysql = engine.dialect.name == SqlaDialectName.MYSQL + is_sqlite = engine.dialect.name == SqlaDialectName.SQLITE multiple_sqla_columns = multiple_sqla_columns or [] # type: List[Column] if multiple_sqla_columns and not (fulltext and is_mssql): @@ -458,6 +594,9 @@ def quote(identifier: str) -> str: idxname = "_idxft_{}".format("_".join(colnames)) else: idxname = "_idx_{}".format("_".join(colnames)) + if is_sqlite: + # SQLite doesn't allow indexes with the same names on different tables. 
+ idxname = f"{tablename}_{idxname}" if idxname and index_exists(engine, tablename, idxname): log.info( f"Skipping creation of index {idxname} on " @@ -488,8 +627,7 @@ def quote(identifier: str) -> str: colnames=", ".join(quote(c) for c in colnames), ) ) - # DDL(sql, bind=engine).execute_if(dialect=SqlaDialectName.MYSQL) - DDL(sql, bind=engine).execute() + execute_ddl(engine, sql=sql) elif is_mssql: # Microsoft SQL Server # https://msdn.microsoft.com/library/ms187317(SQL.130).aspx @@ -551,7 +689,7 @@ def quote(identifier: str) -> str: ) # Executing serial COMMITs or a ROLLBACK won't help here if # this transaction is due to Python DBAPI default behaviour. - DDL(sql, bind=engine).execute() + execute_ddl(engine, sql=sql) # The reversal procedure is DROP FULLTEXT INDEX ON tablename; @@ -571,36 +709,41 @@ def quote(identifier: str) -> str: # More DDL # ============================================================================= +# https://stackoverflow.com/questions/18835740/does-bigint-auto-increment-work-for-sqlalchemy-with-sqlite # noqa: E501 + +BigIntegerForAutoincrementType = BigInteger() +BigIntegerForAutoincrementType = BigIntegerForAutoincrementType.with_variant( + postgresql.BIGINT(), SqlaDialectName.POSTGRES +) +BigIntegerForAutoincrementType = BigIntegerForAutoincrementType.with_variant( + mssql.BIGINT(), SqlaDialectName.MSSQL +) +BigIntegerForAutoincrementType = BigIntegerForAutoincrementType.with_variant( + mysql.BIGINT(), SqlaDialectName.MYSQL +) +BigIntegerForAutoincrementType = BigIntegerForAutoincrementType.with_variant( + sqlite.INTEGER(), SqlaDialectName.SQLITE +) + def make_bigint_autoincrement_column( - column_name: str, dialect: Dialect, nullable=False + column_name: str, nullable: bool = False, comment: str = None ) -> Column: """ Returns an instance of :class:`Column` representing a :class:`BigInteger` - ``AUTOINCREMENT`` column in the specified :class:`Dialect`. 
- """ - - # https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column.params.nullable # noqa: E501 - # Different behaviour of nullable flag observed to the documentation. See - # sqlalchemy/tests/schema_tests.py. - - # noinspection PyUnresolvedReferences - if dialect.name == SqlaDialectName.MSSQL: - return Column( - column_name, - BigInteger, - Identity(start=1, increment=1), - nullable=nullable, - autoincrement=True, - ) - else: - # return Column(column_name, BigInteger, autoincrement=True) - # noinspection PyUnresolvedReferences - raise AssertionError( - f"SQLAlchemy doesn't support non-PK autoincrement fields yet for " - f"dialect {dialect.name!r}" - ) - # see https://stackoverflow.com/questions/2937229 + ``AUTOINCREMENT`` column, or the closest that the database engine can + manage. + """ + return Column( + column_name, + BigIntegerForAutoincrementType, + Identity(start=1, increment=1), + # https://docs.sqlalchemy.org/en/20/core/defaults.html#identity-ddl + autoincrement=True, + nullable=nullable, + comment=comment, + ) + # see also: https://stackoverflow.com/questions/2937229 def column_creation_ddl(sqla_column: Column, dialect: Dialect) -> str: @@ -608,35 +751,13 @@ def column_creation_ddl(sqla_column: Column, dialect: Dialect) -> str: Returns DDL to create a column, using the specified dialect. The column should already be bound to a table (because e.g. the SQL Server - dialect requires this for DDL generation). - - Manual testing: - - .. code-block:: python + dialect requires this for DDL generation). If you don't append the column + to a Table object, the DDL generation step gives + "sqlalchemy.exc.CompileError: mssql requires Table-bound columns in order + to generate DDL". 
- from sqlalchemy.schema import Column, CreateColumn, MetaData, Sequence, Table - from sqlalchemy.sql.sqltypes import BigInteger - from sqlalchemy.dialects.mssql.base import MSDialect - dialect = MSDialect() - col1 = Column('hello', BigInteger, nullable=True) - col2 = Column('world', BigInteger, autoincrement=True) # does NOT generate IDENTITY - col3 = Column('you', BigInteger, Sequence('dummy_name', start=1, increment=1)) - metadata = MetaData() - t = Table('mytable', metadata) - t.append_column(col1) - t.append_column(col2) - t.append_column(col3) - print(str(CreateColumn(col1).compile(dialect=dialect))) # hello BIGINT NULL - print(str(CreateColumn(col2).compile(dialect=dialect))) # world BIGINT NULL - print(str(CreateColumn(col3).compile(dialect=dialect))) # you BIGINT NOT NULL IDENTITY(1,1) - - If you don't append the column to a Table object, the DDL generation step - gives: - - .. code-block:: none - - sqlalchemy.exc.CompileError: mssql requires Table-bound columns in order to generate DDL - """ # noqa + Testing: see schema_tests.py + """ return str(CreateColumn(sqla_column).compile(dialect=dialect)) @@ -646,18 +767,35 @@ def giant_text_sqltype(dialect: Dialect) -> str: Returns the SQL column type used to make very large text columns for a given dialect. + DIALECT-AWARE. + Args: dialect: a SQLAlchemy :class:`Dialect` Returns: the SQL data type of "giant text", typically 'LONGTEXT' for MySQL and 'NVARCHAR(MAX)' for SQL Server. 
""" - if dialect.name == SqlaDialectName.SQLSERVER: + dname = dialect.name + if dname == SqlaDialectName.MSSQL: return "NVARCHAR(MAX)" - elif dialect.name == SqlaDialectName.MYSQL: + # https://learn.microsoft.com/en-us/sql/t-sql/data-types/nchar-and-nvarchar-transact-sql?view=sql-server-ver16 # noqa: E501 + elif dname == SqlaDialectName.MYSQL: return "LONGTEXT" + # https://dev.mysql.com/doc/refman/8.4/en/blob.html + elif dname == SqlaDialectName.ORACLE: + return "LONG" + # https://docs.oracle.com/cd/A58617_01/server.804/a58241/ch5.htm + elif dname == SqlaDialectName.POSTGRES: + return "TEXT" + # https://www.postgresql.org/docs/current/datatype-character.html + elif dname == SqlaDialectName.SQLITE: + return "TEXT" + # https://www.sqlite.org/datatype3.html + elif dname == SqlaDialectName.DATABRICKS: + return "STRING" + # https://github.com/databricks/databricks-sqlalchemy else: - raise ValueError(f"Unknown dialect: {dialect.name}") + raise ValueError(f"Unknown dialect: {dname}") # ============================================================================= @@ -688,16 +826,40 @@ def _get_sqla_coltype_class_from_str( Returns the SQLAlchemy class corresponding to a particular SQL column type in a given dialect. + DIALECT-AWARE. + Performs an upper- and lower-case search. For example, the SQLite dialect uses upper case, and the MySQL dialect uses lower case. + + For exploratory thinking, see + dev_notes/convert_sql_string_coltype_to_sqlalchemy_type.py. + + DISCUSSION AT: https://github.com/sqlalchemy/sqlalchemy/discussions/12230 """ - # noinspection PyUnresolvedReferences - ischema_names = dialect.ischema_names - try: - return ischema_names[coltype.upper()] - except KeyError: - return ischema_names[coltype.lower()] + if hasattr(dialect, "ischema_names"): + # The built-in dialects all have this, even though it's an internal + # detail. 
+ ischema_names = dialect.ischema_names + try: + return ischema_names[coltype.upper()] + except KeyError: + return ischema_names[coltype.lower()] + elif dialect.name == SqlaDialectName.DATABRICKS: + # Ugly hack. + # Databricks is an example that doesn't have ischema_names. + try: + return DATABRICKS_SQLCOLTYPE_TO_SQLALCHEMY_GENERIC[coltype.upper()] + except KeyError: + raise ValueError( + f"Don't know how to convert SQL column type {coltype!r} " + f"to SQLAlchemy dialect {dialect!r}" + ) + else: + raise ValueError( + f"Don't know a generic way to convert SQL column types " + f"(in text format) to SQLAlchemy dialect {dialect.name!r}. " + ) def get_list_of_sql_string_literals_from_quoted_csv(x: str) -> List[str]: @@ -731,6 +893,8 @@ def get_sqla_coltype_from_dialect_str( ``coltype.compile()`` or ``coltype.compile(dialect)``; see :class:`TypeEngine`. + DIALECT-AWARE. + Args: dialect: a SQLAlchemy :class:`Dialect` class @@ -900,6 +1064,8 @@ def convert_sqla_type_for_dialect( """ Converts an SQLAlchemy column type from one SQL dialect to another. + DIALECT-AWARE. + Args: coltype: SQLAlchemy column type in the source dialect @@ -925,9 +1091,7 @@ def convert_sqla_type_for_dialect( """ assert coltype is not None - # noinspection PyUnresolvedReferences to_mysql = dialect.name == SqlaDialectName.MYSQL - # noinspection PyUnresolvedReferences to_mssql = dialect.name == SqlaDialectName.MSSQL typeclass = type(coltype) @@ -979,11 +1143,11 @@ def convert_sqla_type_for_dialect( if is_mssql_timestamp and to_mssql and convert_mssql_timestamp: # You cannot write explicitly to a TIMESTAMP field in SQL Server; it's # used for autogenerated values only. 
- # - https://stackoverflow.com/questions/10262426/sql-server-cannot-insert-an-explicit-value-into-a-timestamp-column # noqa - # - https://social.msdn.microsoft.com/Forums/sqlserver/en-US/5167204b-ef32-4662-8e01-00c9f0f362c2/how-to-tranfer-a-column-with-timestamp-datatype?forum=transactsql # noqa + # - https://stackoverflow.com/questions/10262426/sql-server-cannot-insert-an-explicit-value-into-a-timestamp-column # noqa: E501 + # - https://social.msdn.microsoft.com/Forums/sqlserver/en-US/5167204b-ef32-4662-8e01-00c9f0f362c2/how-to-tranfer-a-column-with-timestamp-datatype?forum=transactsql # noqa: E501 # ... suggesting BINARY(8) to store the value. # MySQL is more helpful: - # - https://stackoverflow.com/questions/409286/should-i-use-field-datetime-or-timestamp # noqa + # - https://stackoverflow.com/questions/409286/should-i-use-field-datetime-or-timestamp # noqa: E501 return mssql.base.BINARY(8) # ------------------------------------------------------------------------- @@ -996,36 +1160,6 @@ def convert_sqla_type_for_dialect( # Questions about SQLAlchemy column types # ============================================================================= -# Note: -# x = String } type(x) == VisitableType # metaclass -# x = BigInteger } -# but: -# x = String() } type(x) == TypeEngine -# x = BigInteger() } -# -# isinstance also cheerfully handles multiple inheritance, i.e. if you have -# class A(object), class B(object), and class C(A, B), followed by x = C(), -# then all of isinstance(x, A), isinstance(x, B), isinstance(x, C) are True - - -def _coltype_to_typeengine( - coltype: Union[TypeEngine, VisitableType] -) -> TypeEngine: - """ - An example is simplest: if you pass in ``Integer()`` (an instance of - :class:`TypeEngine`), you'll get ``Integer()`` back. If you pass in - ``Integer`` (an instance of :class:`VisitableType`), you'll also get - ``Integer()`` back. The function asserts that its return type is an - instance of :class:`TypeEngine`. 
- - See also - :func:`cardinal_pythonlib.sqlalchemy.orm_inspect.coltype_as_typeengine`. - """ - if isinstance(coltype, VisitableType): - coltype = coltype() - assert isinstance(coltype, TypeEngine) - return coltype - def is_sqlatype_binary(coltype: Union[TypeEngine, VisitableType]) -> bool: """ @@ -1033,7 +1167,7 @@ def is_sqlatype_binary(coltype: Union[TypeEngine, VisitableType]) -> bool: """ # Several binary types inherit internally from _Binary, making that the # easiest to check. - coltype = _coltype_to_typeengine(coltype) + coltype = coltype_as_typeengine(coltype) # noinspection PyProtectedMember return isinstance(coltype, sqltypes._Binary) @@ -1042,9 +1176,7 @@ def is_sqlatype_date(coltype: Union[TypeEngine, VisitableType]) -> bool: """ Is the SQLAlchemy column type a date type? """ - coltype = _coltype_to_typeengine(coltype) - # No longer valid in SQLAlchemy 1.2.11: - # return isinstance(coltype, sqltypes._DateAffinity) + coltype = coltype_as_typeengine(coltype) return isinstance(coltype, sqltypes.DateTime) or isinstance( coltype, sqltypes.Date ) @@ -1054,7 +1186,7 @@ def is_sqlatype_integer(coltype: Union[TypeEngine, VisitableType]) -> bool: """ Is the SQLAlchemy column type an integer type? """ - coltype = _coltype_to_typeengine(coltype) + coltype = coltype_as_typeengine(coltype) return isinstance(coltype, sqltypes.Integer) @@ -1062,8 +1194,10 @@ def is_sqlatype_numeric(coltype: Union[TypeEngine, VisitableType]) -> bool: """ Is the SQLAlchemy column type one that inherits from :class:`Numeric`, such as :class:`Float`, :class:`Decimal`? + + Note that integers don't count as Numeric! """ - coltype = _coltype_to_typeengine(coltype) + coltype = coltype_as_typeengine(coltype) return isinstance(coltype, sqltypes.Numeric) # includes Float, Decimal @@ -1071,7 +1205,7 @@ def is_sqlatype_string(coltype: Union[TypeEngine, VisitableType]) -> bool: """ Is the SQLAlchemy column type a string type? 
""" - coltype = _coltype_to_typeengine(coltype) + coltype = coltype_as_typeengine(coltype) return isinstance(coltype, sqltypes.String) @@ -1083,7 +1217,7 @@ def is_sqlatype_text_of_length_at_least( Is the SQLAlchemy column type a string type that's at least the specified length? """ - coltype = _coltype_to_typeengine(coltype) + coltype = coltype_as_typeengine(coltype) if not isinstance(coltype, sqltypes.String): return False # not a string/text type at all if coltype.length is None: @@ -1098,7 +1232,6 @@ def is_sqlatype_text_over_one_char( Is the SQLAlchemy column type a string type that's more than one character long? """ - coltype = _coltype_to_typeengine(coltype) return is_sqlatype_text_of_length_at_least(coltype, 2) @@ -1110,7 +1243,6 @@ def does_sqlatype_merit_fulltext_index( Is the SQLAlchemy column type a type that might merit a ``FULLTEXT`` index (meaning a string type of at least ``min_length``)? """ - coltype = _coltype_to_typeengine(coltype) return is_sqlatype_text_of_length_at_least(coltype, min_length) @@ -1125,7 +1257,7 @@ def does_sqlatype_require_index_len( ``TEXT`` columns: https://dev.mysql.com/doc/refman/5.7/en/create-index.html.) """ - coltype = _coltype_to_typeengine(coltype) + coltype = coltype_as_typeengine(coltype) if isinstance(coltype, sqltypes.Text): return True if isinstance(coltype, sqltypes.LargeBinary): @@ -1134,30 +1266,10 @@ def does_sqlatype_require_index_len( # ============================================================================= -# Hack in new type +# hack_in_mssql_xml_type # ============================================================================= - - -def hack_in_mssql_xml_type(): - r""" - Modifies SQLAlchemy's type map for Microsoft SQL Server to support XML. - - SQLAlchemy does not support the XML type in SQL Server (mssql). - Upon reflection, we get: - - .. code-block:: none - - sqlalchemy\dialects\mssql\base.py:1921: SAWarning: Did not recognize type 'xml' of column '...' 
- - We will convert anything of type ``XML`` into type ``TEXT``. - - """ # noqa - log.debug("Adding type 'xml' to SQLAlchemy reflection for dialect 'mssql'") - mssql.base.ischema_names["xml"] = mssql.base.TEXT - # https://stackoverflow.com/questions/32917867/sqlalchemy-making-schema-reflection-find-use-a-custom-type-for-all-instances # noqa - - # print(repr(mssql.base.ischema_names.keys())) - # print(repr(mssql.base.ischema_names)) +# +# Removed, as mssql.base.ischema_names["xml"] is now defined. # ============================================================================= @@ -1173,7 +1285,7 @@ def column_types_equal(a_coltype: TypeEngine, b_coltype: TypeEngine) -> bool: See https://stackoverflow.com/questions/34787794/sqlalchemy-column-type-comparison. IMPERFECT. - """ # noqa + """ # noqa: E501 return str(a_coltype) == str(b_coltype) diff --git a/cardinal_pythonlib/sqlalchemy/semantic_version_coltype.py b/cardinal_pythonlib/sqlalchemy/semantic_version_coltype.py index 78de49e..a55d47c 100644 --- a/cardinal_pythonlib/sqlalchemy/semantic_version_coltype.py +++ b/cardinal_pythonlib/sqlalchemy/semantic_version_coltype.py @@ -142,7 +142,7 @@ class comparator_factory(TypeDecorator.Comparator): which will be alphabetical and therefore wrong. Disabled on 2019-04-28. - """ # noqa + """ # noqa: E501 def operate(self, op, *other, **kwargs): assert len(other) == 1 diff --git a/cardinal_pythonlib/sqlalchemy/session.py b/cardinal_pythonlib/sqlalchemy/session.py index 811ef70..a59c480 100644 --- a/cardinal_pythonlib/sqlalchemy/session.py +++ b/cardinal_pythonlib/sqlalchemy/session.py @@ -120,5 +120,5 @@ def get_safe_url_from_url(url: str) -> str: """ Converts an SQLAlchemy URL into a safe version that obscures the password. 
""" - engine = create_engine(url) + engine = create_engine(url, future=True) return get_safe_url_from_engine(engine) diff --git a/cardinal_pythonlib/sqlalchemy/sqla_version.py b/cardinal_pythonlib/sqlalchemy/sqla_version.py index fe66034..d08ec67 100644 --- a/cardinal_pythonlib/sqlalchemy/sqla_version.py +++ b/cardinal_pythonlib/sqlalchemy/sqla_version.py @@ -26,10 +26,7 @@ """ -from semantic_version import Version -import sqlalchemy - -SQLA_VERSION = Version(sqlalchemy.__version__) -SQLA_SUPPORTS_POOL_PRE_PING = SQLA_VERSION >= Version("1.2.0") -SQLA_SUPPORTS_MYSQL_UPSERT = SQLA_VERSION >= Version("1.2.0") -# "upsert" = INSERT ... ON DUPLICATE KEY UPDATE +# SQLA_VERSION = Version(sqlalchemy.__version__) +# SQLA_SUPPORTS_POOL_PRE_PING = SQLA_VERSION >= Version("1.2.0") +# SQLA_SUPPORTS_MYSQL_UPSERT = SQLA_VERSION >= Version("1.2.0") +# ... "upsert" = INSERT ... ON DUPLICATE KEY UPDATE diff --git a/cardinal_pythonlib/sqlalchemy/sqlfunc.py b/cardinal_pythonlib/sqlalchemy/sqlfunc.py index 9c1d146..74f3914 100644 --- a/cardinal_pythonlib/sqlalchemy/sqlfunc.py +++ b/cardinal_pythonlib/sqlalchemy/sqlfunc.py @@ -138,7 +138,7 @@ def extract_year_default( def extract_year_year( element: "ClauseElement", compiler: "SQLCompiler", **kw ) -> str: - # https://dev.mysql.com/doc/refman/5.5/en/date-and-time-functions.html#function_year # noqa + # https://dev.mysql.com/doc/refman/5.5/en/date-and-time-functions.html#function_year # noqa: E501 # https://docs.microsoft.com/en-us/sql/t-sql/functions/year-transact-sql clause = fetch_processed_single_clause(element, compiler) return f"YEAR({clause})" diff --git a/cardinal_pythonlib/sqlalchemy/sqlserver.py b/cardinal_pythonlib/sqlalchemy/sqlserver.py index ec5c701..93b9c8a 100644 --- a/cardinal_pythonlib/sqlalchemy/sqlserver.py +++ b/cardinal_pythonlib/sqlalchemy/sqlserver.py @@ -28,10 +28,15 @@ from contextlib import contextmanager -from sqlalchemy.orm import Session as SqlASession - -from cardinal_pythonlib.sqlalchemy.dialect import 
quote_identifier -from cardinal_pythonlib.sqlalchemy.engine_func import is_sqlserver +from sqlalchemy.engine.base import Engine +from sqlalchemy.orm.session import Session as SqlASession +from sqlalchemy.schema import DDL + +from cardinal_pythonlib.sqlalchemy.dialect import ( + quote_identifier, + SqlaDialectName, +) +from cardinal_pythonlib.sqlalchemy.schema import execute_ddl from cardinal_pythonlib.sqlalchemy.session import get_engine_from_session @@ -40,6 +45,14 @@ # ============================================================================= +def _exec_ddl_if_sqlserver(engine: Engine, sql: str) -> None: + """ + Execute DDL only if we are running on Microsoft SQL Server. + """ + ddl = DDL(sql).execute_if(dialect=SqlaDialectName.SQLSERVER) + execute_ddl(engine, ddl=ddl) + + @contextmanager def if_sqlserver_disable_constraints( session: SqlASession, tablename: str @@ -54,19 +67,18 @@ def if_sqlserver_disable_constraints( See https://stackoverflow.com/questions/123558/sql-server-2005-t-sql-to-temporarily-disable-a-trigger - """ # noqa + """ engine = get_engine_from_session(session) - if is_sqlserver(engine): - quoted_tablename = quote_identifier(tablename, engine) - session.execute( - f"ALTER TABLE {quoted_tablename} NOCHECK CONSTRAINT all" - ) - yield - session.execute( - f"ALTER TABLE {quoted_tablename} WITH CHECK CHECK CONSTRAINT all" - ) - else: - yield + quoted_tablename = quote_identifier(tablename, engine) + _exec_ddl_if_sqlserver( + engine, f"ALTER TABLE {quoted_tablename} NOCHECK CONSTRAINT all" + ) + yield + _exec_ddl_if_sqlserver( + engine, + f"ALTER TABLE {quoted_tablename} WITH CHECK CHECK CONSTRAINT all", + ) + # "CHECK CHECK" is correct here. 
@contextmanager @@ -83,15 +95,16 @@ def if_sqlserver_disable_triggers( See https://stackoverflow.com/questions/123558/sql-server-2005-t-sql-to-temporarily-disable-a-trigger - """ # noqa + """ engine = get_engine_from_session(session) - if is_sqlserver(engine): - quoted_tablename = quote_identifier(tablename, engine) - session.execute(f"ALTER TABLE {quoted_tablename} DISABLE TRIGGER all") - yield - session.execute(f"ALTER TABLE {quoted_tablename} ENABLE TRIGGER all") - else: - yield + quoted_tablename = quote_identifier(tablename, engine) + _exec_ddl_if_sqlserver( + engine, f"ALTER TABLE {quoted_tablename} DISABLE TRIGGER all" + ) + yield + _exec_ddl_if_sqlserver( + engine, f"ALTER TABLE {quoted_tablename} ENABLE TRIGGER all" + ) @contextmanager diff --git a/cardinal_pythonlib/sqlalchemy/tests/core_query_tests.py b/cardinal_pythonlib/sqlalchemy/tests/core_query_tests.py new file mode 100644 index 0000000..18403ef --- /dev/null +++ b/cardinal_pythonlib/sqlalchemy/tests/core_query_tests.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python +# cardinal_pythonlib/sqlalchemy/tests/core_query_tests.py + +""" +=============================================================================== + + Original code copyright (C) 2009-2022 Rudolf Cardinal (rudolf@pobox.com). + + This file is part of cardinal_pythonlib. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +=============================================================================== + +**Unit tests.** + +""" + +# ============================================================================= +# Imports +# ============================================================================= + +from unittest import TestCase + +from sqlalchemy.engine import create_engine +from sqlalchemy.orm.session import sessionmaker, Session +from sqlalchemy.sql.expression import column, select, table, text +from sqlalchemy.sql.schema import MetaData + +from cardinal_pythonlib.sqlalchemy.core_query import ( + count_star_and_max, + exists_in_table, + exists_plain, + fetch_all_first_values, + get_rows_fieldnames_from_raw_sql, + get_rows_fieldnames_from_select, +) +from cardinal_pythonlib.sqlalchemy.session import SQLITE_MEMORY_URL + + +# ============================================================================= +# Unit tests +# ============================================================================= + + +class CoreQueryTests(TestCase): + def __init__(self, *args, echo: bool = False, **kwargs) -> None: + self.echo = echo + super().__init__(*args, **kwargs) + + def setUp(self) -> None: + self.engine = create_engine( + SQLITE_MEMORY_URL, echo=self.echo, future=True + ) + self.tablename = "t" + self.a = "a" + self.b = "b" + self.a_val1 = 1 + self.a_val2 = 2 + self.b_val1 = 101 + self.b_val2 = 102 + self.emptytablename = "emptytable" + with self.engine.begin() as con: + con.execute( + text( + f"CREATE TABLE {self.tablename} " + f"(a INTEGER PRIMARY KEY, b INTEGER)" + ) + ) + con.execute( + text( + f"INSERT INTO {self.tablename} " + f"({self.a}, {self.b}) " + f"VALUES ({self.a_val1}, {self.b_val1})" + ) + ) + con.execute( + text( + f"INSERT INTO {self.tablename} " + f"({self.a}, {self.b}) " + f"VALUES ({self.a_val2}, {self.b_val2})" + ) + ) + con.execute( + text(f"CREATE TABLE {self.emptytablename} (x INTEGER)") + ) + self.session = sessionmaker( + bind=self.engine, future=True + )() # 
type: Session + self.metadata = MetaData() + self.metadata.reflect(bind=self.engine) + self.table = self.metadata.tables[self.tablename] + self.emptytable = self.metadata.tables[self.emptytablename] + + # noinspection DuplicatedCode + def test_get_rows_fieldnames_from_raw_sql(self) -> None: + sql = f"SELECT {self.a}, {self.b} FROM {self.tablename}" + rows, fieldnames = get_rows_fieldnames_from_raw_sql(self.session, sql) + self.assertEqual(fieldnames, [self.a, self.b]) + self.assertEqual(len(rows), 2) + self.assertEqual(rows[0], (self.a_val1, self.b_val1)) + self.assertEqual(rows[1], (self.a_val2, self.b_val2)) + + # noinspection DuplicatedCode + def test_get_rows_fieldnames_from_select(self) -> None: + query = select(self.table.c.a, self.table.c.b).select_from(self.table) + rows, fieldnames = get_rows_fieldnames_from_select(self.session, query) + self.assertEqual(fieldnames, [self.a, self.b]) + self.assertEqual(len(rows), 2) + self.assertEqual(rows[0], (self.a_val1, self.b_val1)) + self.assertEqual(rows[1], (self.a_val2, self.b_val2)) + + def test_count_star_and_max(self) -> None: + count, maximum = count_star_and_max( + self.session, self.tablename, self.b + ) + self.assertEqual(count, 2) + self.assertEqual(maximum, self.b_val2) + + def test_exists_in_table(self) -> None: + # exists: + exists1 = exists_in_table(self.session, self.table) + self.assertTrue(exists1) + exists2 = exists_in_table( + self.session, self.table, column(self.a) == 1 + ) + self.assertTrue(exists2) + # does not exist: + exists3 = exists_in_table( + self.session, self.table, column(self.a) == 99 + ) + self.assertFalse(exists3) + exists4 = exists_in_table(self.session, self.emptytable) + self.assertFalse(exists4) + + def test_exists_plain(self) -> None: + # exists: + exists1 = exists_plain(self.session, self.tablename) + self.assertTrue(exists1) + exists2 = exists_plain( + self.session, self.tablename, column(self.a) == 1 + ) + self.assertTrue(exists2) + # does not exist: + exists3 = 
exists_plain( + self.session, self.tablename, column(self.a) == 99 + ) + self.assertFalse(exists3) + exists4 = exists_plain( + self.session, + self.emptytablename, + ) + self.assertFalse(exists4) + + def test_fetch_all_first_values(self) -> None: + select_stmt = select(text("*")).select_from(table(self.tablename)) + firstvalues = fetch_all_first_values(self.session, select_stmt) + self.assertEqual(len(firstvalues), 2) + self.assertEqual(firstvalues, [self.a_val1, self.a_val2]) diff --git a/cardinal_pythonlib/sqlalchemy/tests/dump_tests.py b/cardinal_pythonlib/sqlalchemy/tests/dump_tests.py index e682861..5e14ba3 100644 --- a/cardinal_pythonlib/sqlalchemy/tests/dump_tests.py +++ b/cardinal_pythonlib/sqlalchemy/tests/dump_tests.py @@ -27,29 +27,35 @@ """ import logging +from io import StringIO +import re import unittest from sqlalchemy.engine import create_engine from sqlalchemy.engine.base import Engine from sqlalchemy.orm import declarative_base from sqlalchemy.orm.session import Session, sessionmaker -from sqlalchemy.schema import Column, MetaData, Table -from sqlalchemy.sql.expression import select +from sqlalchemy.schema import Column, Table +from sqlalchemy.sql.expression import select, text from sqlalchemy.sql.sqltypes import Integer, String +from cardinal_pythonlib.sqlalchemy.dialect import SqlaDialectName from cardinal_pythonlib.sqlalchemy.dump import ( + dump_connection_info, + dump_ddl, + dump_table_as_insert_sql, get_literal_query, make_literal_query_fn, + COMMENT_SEP1, + COMMENT_SEP2, ) from cardinal_pythonlib.sqlalchemy.session import SQLITE_MEMORY_URL log = logging.getLogger(__name__) -Base = declarative_base() - # ============================================================================= -# Unit tests +# Helper functions # ============================================================================= @@ -57,7 +63,36 @@ def simplify_whitespace(statement: str) -> str: """ Standardize SQL by simplifying whitespace. 
""" - return statement.replace("\n", " ").replace(" ", " ") + x = statement.replace("\n", " ").replace("\t", " ") + x = re.sub(" +", " ", x) # replace multiple spaces with single space + return x.strip() + + +# ============================================================================= +# SQLAlchemy test framework +# ============================================================================= + +Base = declarative_base() + + +class Person(Base): + __tablename__ = "person" + pk = Column("pk", Integer, primary_key=True, autoincrement=True) + name = Column("name", Integer, index=True) + address = Column("address", Integer, index=False) + + +PET_TABLE = Table( + "pet", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), +) + + +# ============================================================================= +# Unit tests +# ============================================================================= class DumpTests(unittest.TestCase): @@ -69,31 +104,27 @@ def __init__(self, *args, echo: bool = False, **kwargs) -> None: self.echo = echo super().__init__(*args, **kwargs) - class Person(Base): - __tablename__ = "person" - pk = Column("pk", Integer, primary_key=True, autoincrement=True) - name = Column("name", Integer, index=True) - address = Column("address", Integer, index=False) - def setUp(self) -> None: + # NB This function gets executed for each test. Therefore, don't set + # up tables here using a class-specific metadata. 
super().setUp() self.engine = create_engine( - SQLITE_MEMORY_URL, echo=self.echo + SQLITE_MEMORY_URL, echo=self.echo, future=True ) # type: Engine self.dialect = self.engine.dialect - self.session = sessionmaker(bind=self.engine)() # type: Session - self.metadata = MetaData() - - self.pet = Table( - "pet", - self.metadata, - Column("id", Integer, primary_key=True), - Column("name", String(50)), - ) + self.session = sessionmaker( + bind=self.engine, future=True + )() # type: Session + + Base.metadata.create_all(bind=self.engine) + with self.engine.begin() as connection: + connection.execute( + text("INSERT INTO pet (id, name) VALUES (1, 'Garfield')") + ) def test_literal_query_method_1_base(self) -> None: - namecol = self.pet.columns.name + namecol = PET_TABLE.columns.name base_q = select(namecol).where(namecol == "Garfield") literal_query = make_literal_query_fn(self.dialect) literal = simplify_whitespace(literal_query(base_q)) @@ -102,7 +133,7 @@ def test_literal_query_method_1_base(self) -> None: ) def test_literal_query_method_1_orm(self) -> None: - orm_q = select(self.Person.name).where(self.Person.name == "Jon") + orm_q = select(Person.name).where(Person.name == "Jon") literal_query = make_literal_query_fn(self.dialect) literal = simplify_whitespace(literal_query(orm_q)) self.assertEqual( @@ -111,7 +142,7 @@ def test_literal_query_method_1_orm(self) -> None: ) def test_literal_query_method_2_base(self) -> None: - namecol = self.pet.columns.name + namecol = PET_TABLE.columns.name base_q = select(namecol).where(namecol == "Garfield") literal = simplify_whitespace( get_literal_query(base_q, bind=self.engine) @@ -121,10 +152,59 @@ def test_literal_query_method_2_base(self) -> None: ) def test_literal_query_method_2_orm(self) -> None: - orm_q = select(self.Person.pk).where(self.Person.name == "Jon") + orm_q = select(Person.pk).where(Person.name == "Jon") literal = simplify_whitespace( get_literal_query(orm_q, bind=self.engine) ) self.assertEqual( literal, "SELECT 
person.pk FROM person WHERE person.name = 'Jon';" ) + + def test_dump_connection_info(self) -> None: + s = StringIO() + dump_connection_info(engine=self.engine, fileobj=s) + txt = simplify_whitespace(s.getvalue()) + self.assertEqual(txt, f"-- Database info: {SQLITE_MEMORY_URL}") + + def test_dump_ddl(self) -> None: + s = StringIO() + dump_ddl( + metadata=Base.metadata, + dialect_name=SqlaDialectName.SQLITE, + fileobj=s, + ) + txt = simplify_whitespace(s.getvalue()) + self.assertEqual( + txt, + "-- Schema (for dialect sqlite): " + "CREATE TABLE person ( " + "pk INTEGER NOT NULL, " + "name INTEGER, " + "address INTEGER, " + "PRIMARY KEY (pk) " + ") " + "; " + "CREATE INDEX ix_person_name ON person (name); " + "CREATE TABLE pet ( " + "id INTEGER NOT NULL, " + "name VARCHAR(50), " + "PRIMARY KEY (id) " + ") " + ";", + ) + + def test_dump_table_as_insert_sql(self) -> None: + s = StringIO() + dump_table_as_insert_sql( + engine=self.engine, table_name="pet", fileobj=s, include_ddl=False + ) + txt = simplify_whitespace(s.getvalue()) + self.assertEqual( + txt, + f"{COMMENT_SEP1} " + f"-- Data for table: pet " + f"{COMMENT_SEP2} " + f"-- Filters: None " + f"INSERT INTO pet (id, name) VALUES (1, 'Garfield'); " + f"{COMMENT_SEP2}", + ) diff --git a/cardinal_pythonlib/sqlalchemy/tests/insert_on_duplicate_tests.py b/cardinal_pythonlib/sqlalchemy/tests/insert_on_duplicate_tests.py new file mode 100644 index 0000000..ddd16ea --- /dev/null +++ b/cardinal_pythonlib/sqlalchemy/tests/insert_on_duplicate_tests.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python +# cardinal_pythonlib/sqlalchemy/tests/insert_on_duplicate_tests.py + +""" +=============================================================================== + + Original code copyright (C) 2009-2022 Rudolf Cardinal (rudolf@pobox.com). + + This file is part of cardinal_pythonlib. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +=============================================================================== + +**Unit tests.** + +""" + +# ============================================================================= +# Imports +# ============================================================================= + +import logging +from unittest import TestCase + +from sqlalchemy import Column, String, Integer, create_engine +from sqlalchemy.dialects.mysql.base import MySQLDialect +from sqlalchemy.orm import declarative_base +from sqlalchemy.orm.session import Session +from sqlalchemy.exc import IntegrityError + +from cardinal_pythonlib.sqlalchemy.insert_on_duplicate import ( + insert_with_upsert_if_supported, +) + +log = logging.getLogger(__name__) + + +# ============================================================================= +# Unit tests +# ============================================================================= + + +class InsertOnDuplicateKeyUpdateTests(TestCase): + def test_insert_with_upsert_if_supported_syntax(self) -> None: + # noinspection PyPep8Naming + Base = declarative_base() + + class OrmObject(Base): + __tablename__ = "sometable" + id = Column(Integer, primary_key=True) + name = Column(String) + + sqlite_engine = create_engine("sqlite://", echo=True, future=True) + Base.metadata.create_all(sqlite_engine) + + session = Session(sqlite_engine, future=True) + + d1 = dict(id=1, name="One") + d2 = dict(id=2, name="Two") + + table = OrmObject.__table__ + + insert_1 = table.insert().values(d1) + insert_2 = table.insert().values(d2) + session.execute(insert_1) + 
session.execute(insert_2) + with self.assertRaises(IntegrityError): + session.execute(insert_1) + + upsert_1 = insert_with_upsert_if_supported( + table=table, values=d1, session=session + ) + odku = "ON DUPLICATE KEY UPDATE" + self.assertFalse(odku in str(upsert_1)) + + upsert_2 = insert_with_upsert_if_supported( + table=table, values=d1, dialect=MySQLDialect() + ) + self.assertTrue(odku in str(upsert_2)) + + # We can't test fully here without a MySQL connection. + # But syntax tested separately in upsert_test_1.sql diff --git a/cardinal_pythonlib/sqlalchemy/tests/merge_db_tests.py b/cardinal_pythonlib/sqlalchemy/tests/merge_db_tests.py index 050de5e..8087eb3 100644 --- a/cardinal_pythonlib/sqlalchemy/tests/merge_db_tests.py +++ b/cardinal_pythonlib/sqlalchemy/tests/merge_db_tests.py @@ -72,13 +72,17 @@ class MergeTestMixin(object): def setUp(self) -> None: super().setUp() - self.src_engine = create_engine(SQLITE_MEMORY_URL) # type: Engine - self.dst_engine = create_engine(SQLITE_MEMORY_URL) # type: Engine + self.src_engine = create_engine( + SQLITE_MEMORY_URL, future=True + ) # type: Engine + self.dst_engine = create_engine( + SQLITE_MEMORY_URL, future=True + ) # type: Engine self.src_session = sessionmaker( - bind=self.src_engine + bind=self.src_engine, future=True )() # type: Session self.dst_session = sessionmaker( - bind=self.dst_engine + bind=self.dst_engine, future=True )() # type: Session def do_merge(self, dummy_run: bool = False) -> None: @@ -121,7 +125,7 @@ class MergeTestPlain(MergeTestMixin, unittest.TestCase): - If you use mixins, they go AFTER :class:`unittest.TestCase`; see https://stackoverflow.com/questions/1323455/python-unit-test-with-base-and-sub-class - """ # noqa: E501 + """ def setUp(self) -> None: super().setUp() diff --git a/cardinal_pythonlib/sqlalchemy/tests/orm_query_tests.py b/cardinal_pythonlib/sqlalchemy/tests/orm_query_tests.py new file mode 100644 index 0000000..8758b24 --- /dev/null +++ 
b/cardinal_pythonlib/sqlalchemy/tests/orm_query_tests.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python +# cardinal_pythonlib/sqlalchemy/tests/orm_query_tests.py + +""" +=============================================================================== + + Original code copyright (C) 2009-2022 Rudolf Cardinal (rudolf@pobox.com). + + This file is part of cardinal_pythonlib. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +=============================================================================== + +**Unit tests.** + +""" + +# ============================================================================= +# Imports +# ============================================================================= + +import logging +from unittest import TestCase + +from sqlalchemy.engine import create_engine +from sqlalchemy.orm import declarative_base +from sqlalchemy.orm.session import sessionmaker +from sqlalchemy.schema import Column +from sqlalchemy.sql.expression import select +from sqlalchemy.sql.sqltypes import Integer, String + +from cardinal_pythonlib.sqlalchemy.core_query import ( + get_rows_fieldnames_from_select, +) +from cardinal_pythonlib.sqlalchemy.dialect import SqlaDialectName +from cardinal_pythonlib.sqlalchemy.orm_query import ( + CountStarSpecializedQuery, + bool_from_exists_clause, + exists_orm, + get_or_create, + get_rows_fieldnames_from_query, +) +from cardinal_pythonlib.sqlalchemy.session import SQLITE_MEMORY_URL + +log = logging.getLogger(__name__) + + +# 
============================================================================= +# SQLAlchemy test framework +# ============================================================================= + +Base = declarative_base() + + +class Person(Base): + __tablename__ = "person" + pk = Column("pk", Integer, primary_key=True, autoincrement=True) + name = Column("name", Integer, index=True) + address = Column("address", Integer, index=False) + + +class Pet(Base): + __tablename__ = "pet" + id = Column("id", Integer, primary_key=True) + name = Column("name", String(50)) + + +# ============================================================================= +# Unit tests +# ============================================================================= + + +class OrmQueryTests(TestCase): + def __init__(self, *args, echo: bool = False, **kwargs) -> None: + self.echo = echo + super().__init__(*args, **kwargs) + + def setUp(self) -> None: + self.engine = create_engine( + SQLITE_MEMORY_URL, echo=self.echo, future=True + ) + self.session = sessionmaker(bind=self.engine, future=True)() # for ORM + Base.metadata.create_all(bind=self.engine) + self._pet_1_name = "Garfield" + self.pet1 = Pet(id=1, name=self._pet_1_name) + self.session.add(self.pet1) + self.session.flush() + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # get_rows_fieldnames_from_select + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + def test_get_rows_fieldnames_old_function_fails(self) -> None: + # Superseded + query = select(Pet.id, Pet.name).select_from(Pet.__table__) + with self.assertRaises(NotImplementedError): + get_rows_fieldnames_from_query(self.session, query) + + def test_get_rows_fieldnames_old_style_fails(self) -> None: + # Wrong type of object + query = self.session.query(Pet) + with self.assertRaises(ValueError): + get_rows_fieldnames_from_select(self.session, query) + + def test_get_rows_fieldnames_select_works(self) -> None: + # How it should be 
done in SQLAlchemy 2: select(), either with ORM + # classes or columns/column-like things. + query = select(Pet.id, Pet.name).select_from(Pet.__table__) + rows, fieldnames = get_rows_fieldnames_from_select(self.session, query) + self.assertEqual(fieldnames, ["id", "name"]) + self.assertEqual(rows, [(1, self._pet_1_name)]) + + def test_get_rows_fieldnames_whole_object_q_fails(self) -> None: + # We want to disallow querying + query = select(Pet) + with self.assertRaises(ValueError): + get_rows_fieldnames_from_select(self.session, query) + + def test_get_rows_fieldnames_no_rows_returns_fieldnames(self) -> None: + # Check zero-result queries still give fieldnames + query = select(Pet.id, Pet.name).where(Pet.name == "missing") + rows, fieldnames = get_rows_fieldnames_from_select(self.session, query) + self.assertEqual(fieldnames, ["id", "name"]) + self.assertEqual(rows, []) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # bool_from_exists_clause + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + def test_bool_from_exists_clause_sqlite(self) -> None: + exists_q = self.session.query(Pet).exists() + b = bool_from_exists_clause(self.session, exists_q) + self.assertIsInstance(b, bool) + + def test_bool_from_exists_clause_sqlite_pretending_mysql(self) -> None: + self.session.get_bind().dialect.name = SqlaDialectName.MSSQL + # NB setUp() is called for each test, so this won't break others + exists_q = self.session.query(Pet).exists() + b = bool_from_exists_clause(self.session, exists_q) + self.assertIsInstance(b, bool) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # exists_orm + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + def test_exists_orm_when_exists(self) -> None: + b = exists_orm(self.session, Pet) + self.assertIsInstance(b, bool) + self.assertEqual(b, True) + + def test_exists_orm_when_not_exists(self) -> None: + b = 
exists_orm(self.session, Person) + self.assertIsInstance(b, bool) + self.assertEqual(b, False) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # get_or_create + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + def test_get_or_create_get(self) -> None: + p, newly_created = get_or_create(self.session, Pet, id=1) + self.assertIsInstance(p, Pet) + self.assertIsInstance(newly_created, bool) + self.assertEqual(p.id, 1) + self.assertEqual(p.name, self._pet_1_name) + self.assertEqual(newly_created, False) + + def test_get_or_create_create(self) -> None: + newid = 3 + newname = "Nermal" + p, newly_created = get_or_create( + self.session, Pet, id=newid, name=newname + ) + self.assertIsInstance(p, Pet) + self.assertIsInstance(newly_created, bool) + self.assertEqual(p.id, newid) + self.assertEqual(p.name, newname) + self.assertEqual(newly_created, True) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # CountStarSpecializedQuery + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + def test_count_star_specialized_one(self) -> None: + # echo output also inspected + q = CountStarSpecializedQuery(Pet, session=self.session) + n = q.count_star() + self.assertIsInstance(n, int) + self.assertEqual(n, 1) + + def test_count_star_specialized_none(self) -> None: + # echo output also inspected + q = CountStarSpecializedQuery(Person, session=self.session) + n = q.count_star() + self.assertIsInstance(n, int) + self.assertEqual(n, 0) + + def test_count_star_specialized_filter(self) -> None: + # echo output also inspected + q = CountStarSpecializedQuery(Pet, session=self.session).filter( + Pet.name == self._pet_1_name + ) + n = q.count_star() + self.assertIsInstance(n, int) + self.assertEqual(n, 1) diff --git a/cardinal_pythonlib/sqlalchemy/tests/orm_schema_tests.py b/cardinal_pythonlib/sqlalchemy/tests/orm_schema_tests.py new file mode 100644 index 0000000..e20afc4 
--- /dev/null +++ b/cardinal_pythonlib/sqlalchemy/tests/orm_schema_tests.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +# cardinal_pythonlib/sqlalchemy/tests/orm_schema_tests.py + +""" +=============================================================================== + + Original code copyright (C) 2009-2022 Rudolf Cardinal (rudolf@pobox.com). + + This file is part of cardinal_pythonlib. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +=============================================================================== + +**Unit tests.** + +""" + +# ============================================================================= +# Imports +# ============================================================================= + +import logging +from unittest import TestCase + +from sqlalchemy.engine import create_engine +from sqlalchemy.orm import declarative_base +from sqlalchemy.schema import Column +from sqlalchemy.sql.sqltypes import Integer, String + +from cardinal_pythonlib.sqlalchemy.orm_schema import ( + create_table_from_orm_class, +) +from cardinal_pythonlib.sqlalchemy.session import SQLITE_MEMORY_URL + +log = logging.getLogger(__name__) + + +# ============================================================================= +# SQLAlchemy test framework +# ============================================================================= + +Base = declarative_base() + + +class Pet(Base): + __tablename__ = "pet" + id = Column("id", Integer, primary_key=True) + name = Column("name", String(50)) + + 
+# ============================================================================= +# Unit tests +# ============================================================================= + + +class OrmQueryTests(TestCase): + def __init__(self, *args, echo: bool = False, **kwargs) -> None: + self.echo = echo + super().__init__(*args, **kwargs) + + def setUp(self) -> None: + self.engine = create_engine( + SQLITE_MEMORY_URL, echo=self.echo, future=True + ) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # create_table_from_orm_class + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + def test_create_table_from_orm_class(self) -> None: + create_table_from_orm_class(self.engine, Pet) diff --git a/cardinal_pythonlib/sqlalchemy/tests/schema_tests.py b/cardinal_pythonlib/sqlalchemy/tests/schema_tests.py index 987e966..a7020d8 100644 --- a/cardinal_pythonlib/sqlalchemy/tests/schema_tests.py +++ b/cardinal_pythonlib/sqlalchemy/tests/schema_tests.py @@ -30,21 +30,72 @@ import unittest from sqlalchemy import event, inspect, select -from sqlalchemy.dialects.mssql.base import MSDialect +from sqlalchemy.dialects.mssql.base import MSDialect, DECIMAL as MS_DECIMAL from sqlalchemy.dialects.mysql.base import MySQLDialect from sqlalchemy.engine import create_engine +from sqlalchemy.exc import NoSuchTableError, OperationalError from sqlalchemy.ext import compiler from sqlalchemy.orm import declarative_base -from sqlalchemy.schema import Column, DDLElement, MetaData, Table +from sqlalchemy.schema import ( + Column, + CreateTable, + DDLElement, + Index, + MetaData, + Sequence, + Table, +) from sqlalchemy.sql import table -from sqlalchemy.sql.sqltypes import BigInteger, Integer, String +from sqlalchemy.sql.selectable import Select +from sqlalchemy.sql.sqltypes import ( + BigInteger, + Date, + DateTime, + Float, + LargeBinary, + Integer, + String, + Text, + Time, +) from cardinal_pythonlib.sqlalchemy.schema import ( + add_index, 
column_creation_ddl, + column_lists_equal, + column_types_equal, + columns_equal, + convert_sqla_type_for_dialect, + does_sqlatype_require_index_len, + execute_ddl, + gen_columns_info, + get_column_info, + get_column_names, + get_column_type, + get_effective_int_pk_col, + get_list_of_sql_string_literals_from_quoted_csv, + get_pk_colnames, + get_single_int_autoincrement_colname, + get_single_int_pk_colname, get_sqla_coltype_from_dialect_str, + get_table_names, get_view_names, index_exists, + is_sqlatype_binary, + is_sqlatype_date, + is_sqlatype_integer, + is_sqlatype_numeric, + is_sqlatype_string, + is_sqlatype_text_of_length_at_least, + is_sqlatype_text_over_one_char, make_bigint_autoincrement_column, + mssql_get_pk_index_name, + mssql_table_has_ft_index, + mssql_transaction_count, + remove_collation, + table_exists, + table_or_view_exists, + view_exists, ) from cardinal_pythonlib.sqlalchemy.session import SQLITE_MEMORY_URL @@ -52,6 +103,15 @@ log = logging.getLogger(__name__) +# ============================================================================= +# NOT TESTED from schema.py: +# ============================================================================= +# giant_text_sqltype +# does_sqlatype_merit_fulltext_index (see is_sqlatype_text_of_length_at_least) +# indexes_equal +# index_lists_equal + + # ============================================================================= # Tests # ============================================================================= @@ -59,6 +119,10 @@ class SchemaTests(unittest.TestCase): def test_schema_functions(self) -> None: + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # make_bigint_autoincrement_column + # column_creation_ddl + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ d_mssql = MSDialect() d_mysql = MySQLDialect() big_int_null_col = Column("hello", BigInteger, nullable=True) @@ -98,6 +162,9 @@ def test_schema_functions(self) -> None: ) # not 
big_int_autoinc_sequence_col; not supported by MySQL + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # get_sqla_coltype_from_dialect_str + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ log.info("Checking SQL type -> SQL Alchemy type") to_check = [ # mssql @@ -117,23 +184,74 @@ def test_schema_functions(self) -> None: ) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# index_exists +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + class IndexExistsTests(unittest.TestCase): - class Person(Base): - __tablename__ = "person" - pk = Column("pk", Integer, primary_key=True, autoincrement=True) - name = Column("name", Integer, index=True) - address = Column("address", Integer, index=False) + def __init__(self, *args, echo: bool = False, **kwargs) -> None: + self.echo = echo + super().__init__(*args, **kwargs) def setUp(self) -> None: super().setUp() + self.engine = create_engine( + SQLITE_MEMORY_URL, echo=self.echo, future=True + ) + metadata = MetaData() + self.person = Table( + "person", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50), index=True), + Column("address", String(50)), + Index("my_index2", "id", "name"), + ) + # Expected indexes, therefore: + # 1. "ix_person_name" (by default naming convention) on person.name + # 2. 
"my_index2" on {person.id, person.name} + with self.engine.begin() as conn: + metadata.create_all(conn) - self.engine = create_engine(SQLITE_MEMORY_URL) + def test_bad_table(self) -> None: + with self.assertRaises(NoSuchTableError): + index_exists( + self.engine, + "nonexistent_table", + "does_not_matter", + raise_if_nonexistent_table=True, + ) def test_exists(self) -> None: - self.assertFalse(index_exists(self.engine, "person", "name")) + # First index: + self.assertTrue(index_exists(self.engine, "person", colnames="name")) + self.assertTrue(index_exists(self.engine, "person", colnames=["name"])) + # And by the default naming convention: + self.assertTrue( + index_exists(self.engine, "person", indexname="ix_person_name") + ) + # Second index: + self.assertTrue( + index_exists(self.engine, "person", colnames=["id", "name"]) + ) + self.assertTrue( + index_exists(self.engine, "person", indexname="my_index2") + ) def test_does_not_exist(self) -> None: - self.assertFalse(index_exists(self.engine, "person", "address")) + self.assertFalse(index_exists(self.engine, "person", indexname="name")) + self.assertFalse( + index_exists(self.engine, "person", colnames="address") + ) + self.assertFalse( + index_exists(self.engine, "person", colnames=["name", "address"]) + ) + + +# ----------------------------------------------------------------------------- +# Support code for view testing +# ----------------------------------------------------------------------------- # https://github.com/sqlalchemy/sqlalchemy/wiki/Views @@ -148,30 +266,41 @@ def __init__(self, name): self.name = name +# noinspection PyUnusedLocal @compiler.compiles(CreateView) -def _create_view(element, compiler, **kw): +def _create_view(element, compiler_, **kw): return "CREATE VIEW %s AS %s" % ( element.name, - compiler.sql_compiler.process(element.selectable, literal_binds=True), + compiler_.sql_compiler.process(element.selectable, literal_binds=True), ) +# noinspection PyUnusedLocal 
@compiler.compiles(DropView) -def _drop_view(element, compiler, **kw): - return "DROP VIEW %s" % (element.name) +def _drop_view(element, compiler_, **kw) -> str: + return "DROP VIEW %s" % element.name -def view_exists(ddl, target, connection, **kw): +# noinspection PyUnusedLocal +def _view_exists(ddl, target, connection, **kw) -> bool: return ddl.name in inspect(connection).get_view_names() -def view_doesnt_exist(ddl, target, connection, **kw): - return not view_exists(ddl, target, connection, **kw) +def _view_doesnt_exist(ddl, target, connection, **kw): + return not _view_exists(ddl, target, connection, **kw) -def view(name, metadata, selectable): - t = table(name) +def _attach_view( + tablename: str, metadata: MetaData, selectable: Select +) -> None: + """ + Attaches a view to a table of the given name, such that the view (which is + of "selectable") is created after the table is created, and dropped before + the table is dropped, via listeners. + """ + t = table(tablename) + # noinspection PyProtectedMember t._columns._populate_separate_keys( col._make_proxy(t) for col in selectable.selected_columns ) @@ -179,53 +308,440 @@ def view(name, metadata, selectable): event.listen( metadata, "after_create", - CreateView(name, selectable).execute_if(callable_=view_doesnt_exist), + CreateView(tablename, selectable).execute_if( + callable_=_view_doesnt_exist + ), ) event.listen( metadata, "before_drop", - DropView(name).execute_if(callable_=view_exists), + DropView(tablename).execute_if(callable_=_view_exists), ) - return t -class GetViewNamesTests(unittest.TestCase): +class MoreSchemaTests(unittest.TestCase): def setUp(self) -> None: super().setUp() - self.engine = create_engine(SQLITE_MEMORY_URL) - - def test_returns_list_of_database_views(self) -> None: + self.engine = create_engine(SQLITE_MEMORY_URL, future=True) metadata = MetaData() - person = Table( + self.person = Table( "person", metadata, Column("id", Integer, primary_key=True), Column("name", String(50)), ) - 
view( + _attach_view( "one", metadata, - select(person.c.id.label("name")), + select(self.person.c.id.label("name")), ) - view( + _attach_view( "two", metadata, - select(person.c.id.label("name")), + select(self.person.c.id.label("name")), ) - view( + _attach_view( "three", metadata, - select(person.c.id.label("name")), + select(self.person.c.id.label("name")), ) with self.engine.begin() as conn: metadata.create_all(conn) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # get_table_names + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_get_table_names(self) -> None: + table_names = get_table_names(self.engine) + self.assertEqual(len(table_names), 1) + self.assertIn("person", table_names) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # get_view_names + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_get_view_names(self) -> None: view_names = get_view_names(self.engine) self.assertEqual(len(view_names), 3) self.assertIn("one", view_names) self.assertIn("two", view_names) self.assertIn("three", view_names) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # table_exists + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_table_exists(self) -> None: + self.assertTrue(table_exists(self.engine, "person")) + self.assertFalse(table_exists(self.engine, "nope")) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # view_exists + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_view_exists(self) -> None: + self.assertTrue(view_exists(self.engine, "one")) + self.assertFalse(view_exists(self.engine, "nope")) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # table_or_view_exists + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def 
test_table_or_view_exists(self) -> None: + self.assertTrue(table_or_view_exists(self.engine, "person")) # table + self.assertTrue(table_or_view_exists(self.engine, "one")) # view + self.assertFalse(table_or_view_exists(self.engine, "nope")) + + def test_get_column_info(self) -> None: + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # gen_columns_info + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ci_list = list(gen_columns_info(self.engine, "person")) + self.assertEqual(len(ci_list), 2) + ci_id = ci_list[0] + self.assertEqual(ci_id.name, "id") + self.assertIsInstance(ci_id.type, Integer) + self.assertEqual(ci_id.nullable, False) + self.assertEqual(ci_id.default, None) + self.assertEqual(ci_id.attrs, {}) + self.assertEqual(ci_id.comment, "") + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # get_column_info + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ci_id_2 = get_column_info(self.engine, "person", "id") + for a in ("name", "nullable", "default", "attrs", "comment"): + self.assertEqual(getattr(ci_id_2, a), getattr(ci_id, a)) + self.assertEqual(type(ci_id_2.type), type(ci_id.type)) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # get_column_type + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ci_id_3_type = get_column_type(self.engine, "person", "id") + self.assertEqual(type(ci_id_3_type), type(ci_id.type)) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # get_column_names + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_get_column_names(self) -> None: + colnames = get_column_names(self.engine, "person") + self.assertEqual(colnames, ["id", "name"]) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # get_pk_colnames + #
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_get_pk_colnames(self) -> None: + pknames = get_pk_colnames(self.person) + self.assertEqual(pknames, ["id"]) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # get_single_int_pk_colname + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_get_single_int_pk_colname(self) -> None: + pkname = get_single_int_pk_colname(self.person) + self.assertEqual(pkname, "id") + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # get_single_int_autoincrement_colname (partial test) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_get_single_int_autoincrement_colname_a(self) -> None: + pkname = get_single_int_autoincrement_colname(self.person) + self.assertEqual(pkname, "id") + # This is present based on SQLAlchemy's default "auto" and its + # semantics for integer PKs. See below for one where it's forced off. + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # get_effective_int_pk_col + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_get_effective_int_pk_col(self) -> None: + pkname = get_effective_int_pk_col(self.person) + self.assertEqual(pkname, "id") + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # mssql_get_pk_index_name (partial test) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_mssql_get_pk_index_name(self) -> None: + with self.assertRaises(OperationalError): + # Bad SQL for SQLite. But should not raise NotImplementedError, + # which is what happens with query methods incompatible with + # SQLAlchemy 2.0. 
+ mssql_get_pk_index_name(self.engine, "person") + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # mssql_table_has_ft_index (partial test) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_mssql_table_has_ft_index(self) -> None: + with self.assertRaises(OperationalError): + # As above + mssql_table_has_ft_index(self.engine, "person") + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # mssql_transaction_count (partial test) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_mssql_transaction_count(self) -> None: + with self.assertRaises(OperationalError): + # As above + mssql_transaction_count(self.engine) + + +class YetMoreSchemaTests(unittest.TestCase): + def __init__(self, *args, echo: bool = False, **kwargs) -> None: + self.echo = echo + super().__init__(*args, **kwargs) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # make_bigint_autoincrement_column + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def setUp(self) -> None: + super().setUp() + self.engine = create_engine( + SQLITE_MEMORY_URL, echo=self.echo, future=True + ) + self.metadata = MetaData() + self.person = Table( + "person", + self.metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("name", String(50)), + make_bigint_autoincrement_column("bigthing"), + ) + with self.engine.begin() as conn: + self.metadata.create_all(conn) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # get_single_int_autoincrement_colname (again) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_get_single_int_autoincrement_colname_b(self) -> None: + pkname = get_single_int_autoincrement_colname(self.person) + self.assertIsNone(pkname) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # 
add_index + # indexes_equal + # index_lists_equal + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_add_index(self) -> None: + add_index(self.engine, self.person.columns.name) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # column_creation_ddl + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_column_creation_ddl(self) -> None: + mssql_dialect = MSDialect() + mysql_dialect = MySQLDialect() + + col1 = Column("hello", BigInteger, nullable=True) + col2 = Column( + "world", BigInteger, autoincrement=True + ) # does NOT generate IDENTITY + col3 = Column( + "you", BigInteger, Sequence("dummy_name", start=1, increment=1) + ) + + self.metadata = MetaData() + t = Table("mytable", self.metadata) + t.append_column(col1) + t.append_column(col2) + t.append_column(col3) + # See column_creation_ddl() for reasons for attaching to a Table. + + self.assertEqual( + column_creation_ddl(col1, mssql_dialect), + "hello BIGINT NULL", + ) + self.assertEqual( + column_creation_ddl(col2, mssql_dialect), + "world BIGINT NOT NULL IDENTITY", + # used to be "world BIGINT NULL" + ) + self.assertEqual( + column_creation_ddl(col3, mssql_dialect), + "you BIGINT NOT NULL" + # used to be "you BIGINT NOT NULL IDENTITY(1,1)", + ) + + self.assertEqual( + column_creation_ddl(col1, mysql_dialect), + "hello BIGINT", + ) + self.assertEqual( + column_creation_ddl(col2, mysql_dialect), + "world BIGINT", + ) + self.assertEqual( + column_creation_ddl(col3, mysql_dialect), "you BIGINT" + ) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # remove_collation + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_remove_collation(self) -> None: + remove_collation(self.person.columns.name) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # convert_sqla_type_for_dialect (very basic only!) 
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_convert_sqla_type_for_dialect(self) -> None: + to_dialect = MySQLDialect() + c1 = convert_sqla_type_for_dialect(self.person.columns.id, to_dialect) + self.assertIsInstance(c1, Column) + c2 = convert_sqla_type_for_dialect( + self.person.columns.name, to_dialect + ) + self.assertIsInstance(c2, Column) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # column_types_equal + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_column_types_equal(self) -> None: + self.assertTrue(column_types_equal(self.person.c.id, self.person.c.id)) + self.assertFalse( + column_types_equal(self.person.c.id, self.person.c.name) + ) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # columns_equal + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_columns_equal(self) -> None: + self.assertTrue(columns_equal(self.person.c.id, self.person.c.id)) + self.assertFalse(columns_equal(self.person.c.id, self.person.c.name)) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # column_lists_equal + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_column_lists_equal(self) -> None: + a = self.person.c.id + b = self.person.c.name + self.assertTrue(column_lists_equal([a, b], [a, b])) + self.assertFalse(column_lists_equal([a, b], [b, a])) + self.assertFalse(column_lists_equal([a, b], [a])) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # execute_ddl + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_execute_ddl(self) -> None: + sql = "CREATE TABLE x (a INT)" + execute_ddl(self.engine, sql=sql) + + ddl = CreateTable(Table("y", self.metadata, Column("z", Integer))) + execute_ddl(self.engine, ddl=ddl) + + with
self.assertRaises(AssertionError): + execute_ddl(self.engine, sql=sql, ddl=ddl) # both + with self.assertRaises(AssertionError): + execute_ddl(self.engine) # neither + + +class SchemaAbstractTests(unittest.TestCase): + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # get_list_of_sql_string_literals_from_quoted_csv + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_get_list_of_sql_string_literals_from_quoted_csv(self) -> None: + self.assertEqual( + get_list_of_sql_string_literals_from_quoted_csv("'a', 'b', 'c'"), + ["a", "b", "c"], + ) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # is_sqlatype_binary + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_is_sqlatype_binary(self) -> None: + self.assertTrue(is_sqlatype_binary(LargeBinary())) + + self.assertFalse(is_sqlatype_binary(Integer())) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # is_sqlatype_date + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_is_sqlatype_date(self) -> None: + self.assertTrue(is_sqlatype_date(Date())) + self.assertTrue(is_sqlatype_date(DateTime())) + + self.assertFalse(is_sqlatype_date(Integer())) + self.assertFalse(is_sqlatype_date(Time())) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # is_sqlatype_integer + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_is_sqlatype_integer(self) -> None: + self.assertTrue(is_sqlatype_integer(Integer())) + + self.assertFalse(is_sqlatype_integer(Float())) + self.assertFalse(is_sqlatype_integer(Date())) + self.assertFalse(is_sqlatype_integer(DateTime())) + self.assertFalse(is_sqlatype_integer(Time())) + self.assertFalse(is_sqlatype_integer(String())) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # is_sqlatype_numeric + # 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_is_sqlatype_numeric(self) -> None: + self.assertTrue(is_sqlatype_numeric(Float())) + self.assertTrue(is_sqlatype_numeric(MS_DECIMAL())) + + self.assertFalse(is_sqlatype_numeric(Integer())) # False! + + self.assertFalse(is_sqlatype_numeric(Date())) + self.assertFalse(is_sqlatype_numeric(DateTime())) + self.assertFalse(is_sqlatype_numeric(Time())) + self.assertFalse(is_sqlatype_numeric(String())) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # is_sqlatype_string + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_is_sqlatype_string(self) -> None: + self.assertTrue(is_sqlatype_string(String())) + self.assertTrue(is_sqlatype_string(Text())) + + self.assertFalse(is_sqlatype_string(Integer())) + self.assertFalse(is_sqlatype_string(Float())) + self.assertFalse(is_sqlatype_string(Date())) + self.assertFalse(is_sqlatype_string(DateTime())) + self.assertFalse(is_sqlatype_string(Time())) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # is_sqlatype_text_of_length_at_least + # ... 
and thus the trivial function does_sqlatype_merit_fulltext_index + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_is_sqlatype_text_of_length_at_least(self) -> None: + testlen = 5 + self.assertTrue( + is_sqlatype_text_of_length_at_least(String(testlen), testlen) + ) + self.assertTrue( + is_sqlatype_text_of_length_at_least(String(testlen + 1), testlen) + ) + self.assertTrue(is_sqlatype_text_of_length_at_least(Text(), testlen)) + + self.assertFalse( + is_sqlatype_text_of_length_at_least(String(testlen - 1), testlen) + ) + self.assertFalse( + is_sqlatype_text_of_length_at_least(Integer(), testlen) + ) + self.assertFalse(is_sqlatype_text_of_length_at_least(Float(), testlen)) + self.assertFalse(is_sqlatype_text_of_length_at_least(Date(), testlen)) + self.assertFalse( + is_sqlatype_text_of_length_at_least(DateTime(), testlen) + ) + self.assertFalse(is_sqlatype_text_of_length_at_least(Time(), testlen)) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # is_sqlatype_text_over_one_char + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_is_sqlatype_text_over_one_char(self) -> None: + self.assertTrue(is_sqlatype_text_over_one_char(String(2))) + self.assertTrue(is_sqlatype_text_over_one_char(Text())) + + self.assertFalse(is_sqlatype_text_over_one_char(String(1))) + self.assertFalse(is_sqlatype_text_over_one_char(Integer())) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # does_sqlatype_require_index_len + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + def test_does_sqlatype_require_index_len(self) -> None: + self.assertTrue(does_sqlatype_require_index_len(Text())) + self.assertTrue(does_sqlatype_require_index_len(LargeBinary())) + + self.assertFalse(does_sqlatype_require_index_len(String(1))) + self.assertFalse(does_sqlatype_require_index_len(Integer())) diff --git 
a/cardinal_pythonlib/sqlalchemy/tests/upsert_test_1.sql b/cardinal_pythonlib/sqlalchemy/tests/upsert_test_1.sql new file mode 100644 index 0000000..3532a72 --- /dev/null +++ b/cardinal_pythonlib/sqlalchemy/tests/upsert_test_1.sql @@ -0,0 +1,17 @@ +-- cardinal_pythonlib/sqlalchemy/tests/upsert_test_1.sql + +-- For MySQL: + +USE crud; -- pre-existing database + +CREATE TABLE ut (a INTEGER PRIMARY KEY, b INTEGER); + +INSERT INTO ut (a, b) VALUES (1, 101); -- OK +INSERT INTO ut (a, b) VALUES (2, 102); -- OK + +INSERT INTO ut (a, b) VALUES (1, 101); -- fails; duplicate key + +INSERT INTO ut (a, b) VALUES (1, 101) ON DUPLICATE KEY UPDATE a = 1, b = 103; + -- succeeds and changes only one row + +SELECT * FROM ut; diff --git a/cardinal_pythonlib/stringfunc.py b/cardinal_pythonlib/stringfunc.py index 8c574b5..f6f3eed 100644 --- a/cardinal_pythonlib/stringfunc.py +++ b/cardinal_pythonlib/stringfunc.py @@ -45,7 +45,7 @@ def find_nth(s: str, x: str, n: int = 0, overlap: bool = False) -> int: As per https://stackoverflow.com/questions/1883980/find-the-nth-occurrence-of-substring-in-a-string - """ # noqa: E501 + """ length_of_fragment = 1 if overlap else len(x) i = -length_of_fragment for _ in range(n + 1): diff --git a/cardinal_pythonlib/subproc.py b/cardinal_pythonlib/subproc.py index f6206a3..9bd69fb 100644 --- a/cardinal_pythonlib/subproc.py +++ b/cardinal_pythonlib/subproc.py @@ -302,7 +302,7 @@ def run_multiple_processes( # f"(indicating failure); its args were: " # f"{args_list[procidx]!r}") # if die_on_failure: - # # https://stackoverflow.com/questions/29177490/how-do-you-kill-futures-once-they-have-started # noqa + # # https://stackoverflow.com/questions/29177490/how-do-you-kill-futures-once-they-have-started # noqa: E501 # for f2 in future_to_procidx.keys(): # f2.cancel() # # ... 
prevents more jobs being scheduled, except the @@ -516,7 +516,7 @@ def mimic_user_input( print_stdout=show_zip_output, print_stdin=show_zip_output) - """ # noqa + """ # noqa: E501 line_terminators = line_terminators or ["\n"] # type: List[str] stdin_encoding = stdin_encoding or sys.getdefaultencoding() stdout_encoding = stdout_encoding or sys.getdefaultencoding() diff --git a/cardinal_pythonlib/tee.py b/cardinal_pythonlib/tee.py index 326e1a4..dddbd2f 100644 --- a/cardinal_pythonlib/tee.py +++ b/cardinal_pythonlib/tee.py @@ -131,7 +131,7 @@ def tee(infile: IO, *files: IO) -> Thread: x = t.readline() # "hello\n" y = b.readline() # b"world\n" - """ # noqa: E501 + """ def fanout(_infile: IO, *_files: IO): for line in iter(_infile.readline, ""): @@ -170,7 +170,7 @@ def teed_call( encoding: encoding to apply to ``stdout`` and ``stderr`` kwargs: additional arguments for :class:`subprocess.Popen` - """ # noqa: E501 + """ # Make a copy so we can append without damaging the original: stdout_targets = ( stdout_targets.copy() if stdout_targets else [] diff --git a/cardinal_pythonlib/tests/dogpile_cache_tests.py b/cardinal_pythonlib/tests/dogpile_cache_tests.py index 54d99a8..5621862 100644 --- a/cardinal_pythonlib/tests/dogpile_cache_tests.py +++ b/cardinal_pythonlib/tests/dogpile_cache_tests.py @@ -348,7 +348,7 @@ def no_params_instance_cache(self) -> str: return f"TestClass.no_params_instance_cache: a={self.a}" # Decorator order is critical here: - # https://stackoverflow.com/questions/1987919/why-can-decorator-not-decorate-a-staticmethod-or-a-classmethod # noqa + # https://stackoverflow.com/questions/1987919/why-can-decorator-not-decorate-a-staticmethod-or-a-classmethod # noqa: E501 @classmethod @mycache.cache_on_arguments(function_key_generator=plain_fkg) def classy(cls) -> str: @@ -410,7 +410,7 @@ def no_params_instance_cache(self) -> str: return f"Inherited.no_params_instance_cache: a={self.a}" # Decorator order is critical here: - # 
https://stackoverflow.com/questions/1987919/why-can-decorator-not-decorate-a-staticmethod-or-a-classmethod # noqa + # https://stackoverflow.com/questions/1987919/why-can-decorator-not-decorate-a-staticmethod-or-a-classmethod # noqa: E501 @classmethod @mycache.cache_on_arguments(function_key_generator=plain_fkg) def classy(cls) -> str: diff --git a/cardinal_pythonlib/tests/rounding_tests.py b/cardinal_pythonlib/tests/rounding_tests.py index 3fa4c16..6b57f18 100644 --- a/cardinal_pythonlib/tests/rounding_tests.py +++ b/cardinal_pythonlib/tests/rounding_tests.py @@ -201,7 +201,7 @@ def test_range_roundable_up_to(self) -> None: Decimal("-0.5"), ) - # range_roundable_up_to(0.5, 0) # bad input (not correctly rounded); would assert # noqa + # range_roundable_up_to(0.5, 0) # bad input (not correctly rounded); would assert # noqa: E501 for x in [Decimal("100.332"), Decimal("-150.12")]: for dp in [-2, -1, 0, 1, 2]: @@ -243,7 +243,7 @@ def test_range_truncatable_to(self) -> None: Decimal("-1"), ) - # range_truncatable_to(0.5, 0) # bad input (not correctly rounded); would assert # noqa + # range_truncatable_to(0.5, 0) # bad input (not correctly rounded); would assert # noqa: E501 for x in [Decimal("100.332"), Decimal("-150.12")]: for dp in [-2, -1, 0, 1, 2]: diff --git a/cardinal_pythonlib/tests/rpm_tests.py b/cardinal_pythonlib/tests/rpm_tests.py index 9cd501d..58f151d 100644 --- a/cardinal_pythonlib/tests/rpm_tests.py +++ b/cardinal_pythonlib/tests/rpm_tests.py @@ -162,7 +162,7 @@ def test_fast_rpm_twochoice(self) -> None: seen.add(forwards) successes = [success_this, success_other] failures = [failure_this, failure_other] - p_fast_this = rpm_probabilities_successes_failures_twochoice_fast( # noqa + p_fast_this = rpm_probabilities_successes_failures_twochoice_fast( # noqa: E501 success_this, failure_this, success_other, diff --git a/cardinal_pythonlib/text.py b/cardinal_pythonlib/text.py index 7f266b1..78ffa02 100644 --- a/cardinal_pythonlib/text.py +++ 
b/cardinal_pythonlib/text.py @@ -158,7 +158,7 @@ def _unicode_def_src_to_str(srclist: List[Union[str, int]]) -> str: # https://stackoverflow.com/questions/13233076/determine-if-a-unicode-character-is-alphanumeric-without-using-a-regular-express # noqa: E501 _UNICODE_CATEGORY_SRC = { - # From https://github.com/slevithan/xregexp/blob/master/tools/scripts/property-regex.py # noqa + # From https://github.com/slevithan/xregexp/blob/master/tools/scripts/property-regex.py # noqa: E501 "ASCII": ["0000-007F"], "Alphabetic": [ "0041-005A", @@ -2241,7 +2241,7 @@ def _unicode_def_src_to_str(srclist: List[Union[str, int]]) -> str: "1F130-1F149", "1F150-1F169", "1F170-1F189", - ], # noqa + ], "White_Space": [ "0009-000D", 0x0020, @@ -2254,7 +2254,7 @@ def _unicode_def_src_to_str(srclist: List[Union[str, int]]) -> str: 0x202F, 0x205F, 0x3000, - ], # noqa + ], # From https://en.wikipedia.org/wiki/Latin_script_in_Unicode "Latin": [ "0000-007F", # Basic Latin; this block corresponds to ASCII. @@ -2286,7 +2286,7 @@ def _unicode_def_src_to_str(srclist: List[Union[str, int]]) -> str: # more symbols "00C0-00D6", # Basic Latin: accented capitals # multiplication symbol - "00D8-00F6", # Basic Latin: more accented capitals, something odd, Eszett, accented lower case # noqa + "00D8-00F6", # Basic Latin: more accented capitals, something odd, Eszett, accented lower case # noqa: E501 # division symbol "00F8-00FF", # Basic Latin: more accented... 
"0100-017F", # Latin Extended-A @@ -2304,7 +2304,7 @@ def _unicode_def_src_to_str(srclist: List[Union[str, int]]) -> str: "A7B0-A7B7", # Latin Extended-D: part 2 "A7F7-A7FF", # Latin Extended-D: part 3 "AB30-AB65", # Latin Extended-E: those assigned - "FB00-FB06", # Alphabetic Presentation Forms (Latin ligatures): those assigned # noqa + "FB00-FB06", # Alphabetic Presentation Forms (Latin ligatures): those assigned # noqa: E501 "FF20-FF5F", # Halfwidth and Fullwidth Forms: those assigned ], } diff --git a/cardinal_pythonlib/timing.py b/cardinal_pythonlib/timing.py index 801a7ef..242c091 100644 --- a/cardinal_pythonlib/timing.py +++ b/cardinal_pythonlib/timing.py @@ -160,7 +160,7 @@ def report(self) -> None: "total": total_sec, "description": ( f"- {name}: {total_sec:.3f} s " - f"({(100 * total_sec / grand_total.total_seconds()):.2f}%, " # noqa + f"({(100 * total_sec / grand_total.total_seconds()):.2f}%, " # noqa: E501 f"n={n}, mean={mean:.3f}s)" ), } diff --git a/cardinal_pythonlib/tools/convert_athena_ohdsi_codes.py b/cardinal_pythonlib/tools/convert_athena_ohdsi_codes.py index 398239d..4b4e61a 100644 --- a/cardinal_pythonlib/tools/convert_athena_ohdsi_codes.py +++ b/cardinal_pythonlib/tools/convert_athena_ohdsi_codes.py @@ -34,7 +34,7 @@ cardinalpythonlib_convert_athena_ohdsi_codes 175898006 118677009 265764009 --src_vocabulary SNOMED --descendants --dest_vocabulary OPCS4 > renal_procedures_opcs4.txt # ... kidney operation, procedure on urinary system, renal dialysis -""" # noqa +""" # noqa: E501 import argparse import logging diff --git a/cardinal_pythonlib/tools/convert_mdb_to_mysql.py b/cardinal_pythonlib/tools/convert_mdb_to_mysql.py index 54d9e2f..8f2b397 100644 --- a/cardinal_pythonlib/tools/convert_mdb_to_mysql.py +++ b/cardinal_pythonlib/tools/convert_mdb_to_mysql.py @@ -61,7 +61,7 @@ - REVISED 16 Jan 2017: conversion to Python 3. - Fixed a bit more, 2020-01-19. Also type hinting. 
-""" # noqa +""" # noqa: E501 import argparse import getpass @@ -146,7 +146,7 @@ class PasswordPromptAction(argparse.Action): Modified from https://stackoverflow.com/questions/27921629/python-using-getpass-with-argparse - """ # noqa + """ # noinspection PyShadowingBuiltins def __init__( @@ -318,7 +318,7 @@ def main() -> None: # ... BUT (Jan 2013): now mdb-tools is better, text-processing not # necessary - can use temporary disk file # Turns out the bottleneck is the import to MySQL, not the export from MDB. - # So see http://dev.mysql.com/doc/refman/5.5/en/optimizing-innodb-bulk-data-loading.html # noqa + # So see http://dev.mysql.com/doc/refman/5.5/en/optimizing-innodb-bulk-data-loading.html # noqa: E501 # The massive improvement is by disabling autocommit. (Example source # database is 208M; largest table here is 554M as a textfile; it has # 1,686,075 rows.) This improvement was from 20 Hz to the whole database diff --git a/cardinal_pythonlib/tools/explore_clang_format_config.py b/cardinal_pythonlib/tools/explore_clang_format_config.py index 92c7f28..94cb8b0 100644 --- a/cardinal_pythonlib/tools/explore_clang_format_config.py +++ b/cardinal_pythonlib/tools/explore_clang_format_config.py @@ -71,7 +71,7 @@ def clang_format( # you can specify "-style=PATH_TO_CONFIG", but no; you have to use # "-style=file" literally, and have the config file correctly named in the # current directory. 
- # https://stackoverflow.com/questions/46373858/how-do-i-specify-a-clang-format-file # noqa + # https://stackoverflow.com/questions/46373858/how-do-i-specify-a-clang-format-file # noqa: E501 fixed_config_filename = ".clang-format" fixed_config_path = os.path.join(dir, fixed_config_filename) diff --git a/cardinal_pythonlib/tools/pdf_to_booklet.py b/cardinal_pythonlib/tools/pdf_to_booklet.py index 73e6cdb..0c1c035 100644 --- a/cardinal_pythonlib/tools/pdf_to_booklet.py +++ b/cardinal_pythonlib/tools/pdf_to_booklet.py @@ -245,7 +245,7 @@ def make_blank_pdf(filename: str, paper: str = "A4") -> None: NOT USED. Makes a blank single-page PDF, using ImageMagick's ``convert``. """ - # https://unix.stackexchange.com/questions/277892/how-do-i-create-a-blank-pdf-from-the-command-line # noqa + # https://unix.stackexchange.com/questions/277892/how-do-i-create-a-blank-pdf-from-the-command-line # noqa: E501 require(CONVERT, HELP_MISSING_IMAGEMAGICK) run([CONVERT, "xc:none", "-page", paper, filename]) diff --git a/cardinal_pythonlib/tools/remove_duplicate_files.py b/cardinal_pythonlib/tools/remove_duplicate_files.py index ab2f091..2763dbb 100644 --- a/cardinal_pythonlib/tools/remove_duplicate_files.py +++ b/cardinal_pythonlib/tools/remove_duplicate_files.py @@ -25,7 +25,7 @@ **Command-line tool to remove duplicate files from a path.** Largely based on -https://code.activestate.com/recipes/362459-dupinator-detect-and-delete-duplicate-files/ # noqa +https://code.activestate.com/recipes/362459-dupinator-detect-and-delete-duplicate-files/ # noqa: E501 """ @@ -69,7 +69,7 @@ def deduplicate( # ------------------------------------------------------------------------- files_by_size = ( {} - ) # type: Dict[int, List[str]] # maps size to list of filenames # noqa + ) # type: Dict[int, List[str]] # maps size to list of filenames num_considered = 0 for filename in gen_filenames(directories, recursive=recursive): if not os.path.isfile(filename): diff --git a/cardinal_pythonlib/typing_helpers.py 
b/cardinal_pythonlib/typing_helpers.py index a111403..e599226 100644 --- a/cardinal_pythonlib/typing_helpers.py +++ b/cardinal_pythonlib/typing_helpers.py @@ -68,7 +68,7 @@ def with_typehint(baseclass: Type[T]) -> Type[T]: class MyMixin1(with_typehint(SomeBaseClass))): # ... - """ # noqa: E501 + """ if TYPE_CHECKING: return baseclass return object @@ -88,7 +88,7 @@ def with_typehints(*baseclasses: Type[T]) -> Type[T]: class MyMixin2(*with_typehints(SomeBaseClass, AnotherBaseClass))): # ... - """ # noqa: E501 + """ if TYPE_CHECKING: return baseclasses return object diff --git a/cardinal_pythonlib/version_string.py b/cardinal_pythonlib/version_string.py index 32c15a8..7788efd 100644 --- a/cardinal_pythonlib/version_string.py +++ b/cardinal_pythonlib/version_string.py @@ -31,5 +31,5 @@ """ -VERSION_STRING = "1.1.27" -# Use semantic versioning: http://semver.org/ +VERSION_STRING = "2.0.0" +# Use semantic versioning: https://semver.org/ diff --git a/cardinal_pythonlib/winservice.py b/cardinal_pythonlib/winservice.py index c230009..6f6c223 100644 --- a/cardinal_pythonlib/winservice.py +++ b/cardinal_pythonlib/winservice.py @@ -716,7 +716,7 @@ def _kill(self) -> None: https://stackoverflow.com/questions/1230669/subprocess-deleting-child-processes-in-windows, which uses ``psutil``. 
- """ # noqa + """ self.warning("Using a recursive hard kill; will assume it worked") pid = self.process.pid gone, still_alive = kill_proc_tree( diff --git a/cardinal_pythonlib/wsgi/cache_mw.py b/cardinal_pythonlib/wsgi/cache_mw.py index fa0ca0b..213e64e 100644 --- a/cardinal_pythonlib/wsgi/cache_mw.py +++ b/cardinal_pythonlib/wsgi/cache_mw.py @@ -45,8 +45,8 @@ # ============================================================================= # DisableClientSideCachingMiddleware # ============================================================================= -# https://stackoverflow.com/questions/49547/making-sure-a-web-page-is-not-cached-across-all-browsers # noqa -# https://stackoverflow.com/questions/3859097/how-to-add-http-headers-in-wsgi-middleware # noqa +# https://stackoverflow.com/questions/49547/making-sure-a-web-page-is-not-cached-across-all-browsers # noqa: E501 +# https://stackoverflow.com/questions/3859097/how-to-add-http-headers-in-wsgi-middleware # noqa: E501 def add_never_cache_headers(headers: TYPE_WSGI_RESPONSE_HEADERS) -> None: @@ -55,7 +55,7 @@ def add_never_cache_headers(headers: TYPE_WSGI_RESPONSE_HEADERS) -> None: """ headers.append( ("Cache-Control", "no-cache, no-store, must-revalidate") - ) # HTTP 1.1 # noqa + ) # HTTP 1.1 headers.append(("Pragma", "no-cache")) # HTTP 1.0 headers.append(("Expires", "0")) # Proxies diff --git a/cardinal_pythonlib/wsgi/constants.py b/cardinal_pythonlib/wsgi/constants.py index c926444..e786669 100644 --- a/cardinal_pythonlib/wsgi/constants.py +++ b/cardinal_pythonlib/wsgi/constants.py @@ -119,10 +119,10 @@ class WsgiEnvVar(object): # optional); https://wsgi.readthedocs.io/en/latest/definitions.html # [3] Also standard WSGI, but not CGI; must always be present. 
# [4] From non-standard but common HTTP request fields; - # https://en.wikipedia.org/wiki/List_of_HTTP_header_fields#Common_non-standard_request_fields # noqa - # https://github.com/omnigroup/Apache/blob/master/httpd/modules/proxy/mod_proxy_http.c # noqa + # https://en.wikipedia.org/wiki/List_of_HTTP_header_fields#Common_non-standard_request_fields # noqa: E501 + # https://github.com/omnigroup/Apache/blob/master/httpd/modules/proxy/mod_proxy_http.c # noqa: E501 # [5] Non-standard; Nginx-specific? Nonetheless, all "HTTP_" variables in # WSGI should follow the HTTP request headers. # [6] Protocols (i.e. http versus https): - # https://stackoverflow.com/questions/16042647/whats-the-de-facto-standard-for-a-reverse-proxy-to-tell-the-backend-ssl-is-used # noqa - # [7] https://modwsgi.readthedocs.io/en/develop/release-notes/version-4.4.9.html # noqa + # https://stackoverflow.com/questions/16042647/whats-the-de-facto-standard-for-a-reverse-proxy-to-tell-the-backend-ssl-is-used # noqa: E501 + # [7] https://modwsgi.readthedocs.io/en/develop/release-notes/version-4.4.9.html # noqa: E501 diff --git a/cardinal_pythonlib/wsgi/request_logging_mw.py b/cardinal_pythonlib/wsgi/request_logging_mw.py index 8854ad5..90b2874 100644 --- a/cardinal_pythonlib/wsgi/request_logging_mw.py +++ b/cardinal_pythonlib/wsgi/request_logging_mw.py @@ -107,8 +107,8 @@ def __call__( ) -> TYPE_WSGI_APP_RESULT: query_string = environ.get(WsgiEnvVar.QUERY_STRING, "") try: - # https://stackoverflow.com/questions/7835030/obtaining-client-ip-address-from-a-wsgi-app-using-eventlet # noqa - # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For # noqa + # https://stackoverflow.com/questions/7835030/obtaining-client-ip-address-from-a-wsgi-app-using-eventlet # noqa: E501 + # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For # noqa: E501 forwarded_for = " [forwarded for {}]".format( environ[WsgiEnvVar.HTTP_X_FORWARDED_FOR] ) diff --git 
a/cardinal_pythonlib/wsgi/reverse_proxied_mw.py b/cardinal_pythonlib/wsgi/reverse_proxied_mw.py index 7e7005e..782af88 100644 --- a/cardinal_pythonlib/wsgi/reverse_proxied_mw.py +++ b/cardinal_pythonlib/wsgi/reverse_proxied_mw.py @@ -63,7 +63,7 @@ def ip_addresses_from_xff(value: str) -> List[str]: See: - https://en.wikipedia.org/wiki/X-Forwarded-For - - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For # noqa + - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For # noqa: E501 - NOT THIS: https://tools.ietf.org/html/rfc7239 """ if not value: @@ -170,7 +170,7 @@ def first_from_xff(value: str) -> str: Require all granted -""" # noqa +""" # noqa: E501 class ReverseProxiedConfig(object): @@ -281,7 +281,7 @@ class ReverseProxiedMiddleware(object): - http://modwsgi.readthedocs.io/en/develop/release-notes/version-4.4.9.html - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers - """ # noqa + """ # noqa: E501 CANDIDATES_HTTP_HOST = [ # These are variables that may contain a value for HTTP_HOST. @@ -295,7 +295,7 @@ class ReverseProxiedMiddleware(object): CANDIDATES_REMOTE_ADDR = [ # These are variables that may contain a value for REMOTE_ADDR. # However, they differ: - WsgiEnvVar.HTTP_X_FORWARDED_FOR, # may contain many values; first is taken # noqa + WsgiEnvVar.HTTP_X_FORWARDED_FOR, # may contain many values; first is taken # noqa: E501 WsgiEnvVar.HTTP_X_REAL_IP, # may contain only one ] _CANDIDATES_URL_SCHEME_GIVING_PROTOCOL = [ @@ -496,13 +496,13 @@ def __call__( Should we be looking at HTTP_X_FORWARDED_HOST or HTTP_X_FORWARDED_SERVER? - See https://github.com/omnigroup/Apache/blob/master/httpd/modules/proxy/mod_proxy_http.c # noqa + See https://github.com/omnigroup/Apache/blob/master/httpd/modules/proxy/mod_proxy_http.c ... and let's follow mod_wsgi. 
----------------------------------------------------------------------- HTTP_HOST versus SERVER_NAME ----------------------------------------------------------------------- - https://stackoverflow.com/questions/2297403/what-is-the-difference-between-http-host-and-server-name-in-php # noqa + https://stackoverflow.com/questions/2297403/what-is-the-difference-between-http-host-and-server-name-in-php ----------------------------------------------------------------------- REWRITING THE PROTOCOL @@ -527,7 +527,7 @@ def __call__( from pprint import pformat; import logging; log = logging.getLogger(__name__); log.critical("Request headers:\n" + pformat(req.inheaders)) - """ # noqa + """ # noqa: E501 if self.debug: log.debug("Starting WSGI environment: \n{}", pformat(environ)) oldenv = environ.copy() @@ -564,7 +564,7 @@ def __call__( newpath = path_info[len(script_name) :] if ( not newpath - ): # e.g. trailing slash omitted from incoming path # noqa + ): # e.g. trailing slash omitted from incoming path newpath = "/" environ[WsgiEnvVar.PATH_INFO] = newpath diff --git a/dev_notes/convert_sql_string_coltype_to_sqlalchemy_type.py b/dev_notes/convert_sql_string_coltype_to_sqlalchemy_type.py new file mode 100644 index 0000000..15e190f --- /dev/null +++ b/dev_notes/convert_sql_string_coltype_to_sqlalchemy_type.py @@ -0,0 +1,52 @@ +# EXPLORATORY CODE ONLY. +# +# PROBLEM: Take a SQL string fragment representing a column type (e.g. +# "VARCHAR(32)", "STRING") and an SQLAlchemy dialect (a core one like mysql or +# sqlite, or a third-party one like databricks), and return the appropriate +# SQLAlchemy type as a TypeEngine class/instance. +# +# CURRENT IMPLEMENTATION: +# cardinal_pythonlib.sqlalchemy.schema.get_sqla_coltype_from_dialect_str() +# ... 
with its sub-function, _get_sqla_coltype_class_from_str() +# +# DISCUSSION AT: https://github.com/sqlalchemy/sqlalchemy/discussions/12230 + + +# For exploring some files directly: +from sqlalchemy.inspection import inspect # noqa: F401 +import sqlalchemy.dialects.sqlite.base +import sqlalchemy.dialects.sqlite.pysqlite # noqa: F401 + +# Test code for dialects: +from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.dialects.mssql import dialect as MSSQLDialect +from sqlalchemy.dialects.mysql import dialect as MySQLDialect +from sqlalchemy.dialects.postgresql import dialect as PostgreSQLDialect +from sqlalchemy.dialects.sqlite import dialect as SQLiteDialect + +# Third-party dialect +from databricks.sqlalchemy import DatabricksDialect + +# Create instances to explore: +default_dialect = DefaultDialect() +postgresql_dialect = PostgreSQLDialect() +mssql_dialect = MSSQLDialect() +mysql_dialect = MySQLDialect() +sqlite_dialect = SQLiteDialect() +databricks_dialect = DatabricksDialect() + +print(sqlite_dialect.ischema_names) + +# The native ones all have an "ischema_names" dictionary, apart from +# DefaultDialect. The Databricks one doesn't. + +# The way SQLAlchemy does this for real is via an Inspector, which passes on +# to the Dialect. +# Inspector: https://docs.sqlalchemy.org/en/20/core/reflection.html#sqlalchemy.engine.reflection.Inspector # noqa: E501 +# Engine: https://docs.sqlalchemy.org/en/20/core/connections.html#sqlalchemy.engine.Engine # noqa: E501 +# Dialect: https://docs.sqlalchemy.org/en/14/core/internals.html#sqlalchemy.engine.Dialect # noqa: E501 +# ... get_columns() +# ... type_descriptor(), converts generic SQLA type to dialect-specific type. +# DefaultDialect: https://docs.sqlalchemy.org/en/14/core/internals.html#sqlalchemy.engine.default.DefaultDialect # noqa: E501 + +# I can't find a generic method. See discussion above: there isn't one.
diff --git a/docs/docs_requirements.txt b/docs/docs_requirements.txt index eb486f6..9ffc9b2 100644 --- a/docs/docs_requirements.txt +++ b/docs/docs_requirements.txt @@ -4,8 +4,8 @@ colander cryptography deform dogpile.cache==0.9.2 -# CRATE is on 3.2 -Django<4.0 +# CRATE is on 4.2 +Django>=4.2,<5.0 libChEBIpy pdfkit pyramid==1.10.8 diff --git a/docs/source/autodoc/_index.rst b/docs/source/autodoc/_index.rst index ba08f67..4adf8f6 100644 --- a/docs/source/autodoc/_index.rst +++ b/docs/source/autodoc/_index.rst @@ -163,9 +163,13 @@ Automatic documentation of source code sqlalchemy/sqlfunc.py.rst sqlalchemy/sqlserver.py.rst sqlalchemy/table_identity.py.rst + sqlalchemy/tests/core_query_tests.py.rst sqlalchemy/tests/dump_tests.py.rst + sqlalchemy/tests/insert_on_duplicate_tests.py.rst sqlalchemy/tests/merge_db_tests.py.rst sqlalchemy/tests/orm_inspect_tests.py.rst + sqlalchemy/tests/orm_query_tests.py.rst + sqlalchemy/tests/orm_schema_tests.py.rst sqlalchemy/tests/schema_tests.py.rst stringfunc.py.rst subproc.py.rst diff --git a/docs/source/autodoc/sqlalchemy/tests/core_query_tests.py.rst b/docs/source/autodoc/sqlalchemy/tests/core_query_tests.py.rst new file mode 100644 index 0000000..5d3a0b1 --- /dev/null +++ b/docs/source/autodoc/sqlalchemy/tests/core_query_tests.py.rst @@ -0,0 +1,25 @@ +.. docs/source/autodoc/sqlalchemy/tests/core_query_tests.py.rst + +.. THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT. + + +.. Copyright (C) 2009-2020 Rudolf Cardinal (rudolf@pobox.com). + . + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + . + https://www.apache.org/licenses/LICENSE-2.0 + . + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. + + +cardinal_pythonlib.sqlalchemy.tests.core_query_tests +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: cardinal_pythonlib.sqlalchemy.tests.core_query_tests + :members: diff --git a/docs/source/autodoc/sqlalchemy/tests/insert_on_duplicate_tests.py.rst b/docs/source/autodoc/sqlalchemy/tests/insert_on_duplicate_tests.py.rst new file mode 100644 index 0000000..1f6c643 --- /dev/null +++ b/docs/source/autodoc/sqlalchemy/tests/insert_on_duplicate_tests.py.rst @@ -0,0 +1,25 @@ +.. docs/source/autodoc/sqlalchemy/tests/insert_on_duplicate_tests.py.rst + +.. THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT. + + +.. Copyright (C) 2009-2020 Rudolf Cardinal (rudolf@pobox.com). + . + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + . + https://www.apache.org/licenses/LICENSE-2.0 + . + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +cardinal_pythonlib.sqlalchemy.tests.insert_on_duplicate_tests +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: cardinal_pythonlib.sqlalchemy.tests.insert_on_duplicate_tests + :members: diff --git a/docs/source/autodoc/sqlalchemy/tests/orm_query_tests.py.rst b/docs/source/autodoc/sqlalchemy/tests/orm_query_tests.py.rst new file mode 100644 index 0000000..efab644 --- /dev/null +++ b/docs/source/autodoc/sqlalchemy/tests/orm_query_tests.py.rst @@ -0,0 +1,25 @@ +.. docs/source/autodoc/sqlalchemy/tests/orm_query_tests.py.rst + +.. THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT. + + +.. 
Copyright (C) 2009-2020 Rudolf Cardinal (rudolf@pobox.com). + . + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + . + https://www.apache.org/licenses/LICENSE-2.0 + . + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +cardinal_pythonlib.sqlalchemy.tests.orm_query_tests +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: cardinal_pythonlib.sqlalchemy.tests.orm_query_tests + :members: diff --git a/docs/source/autodoc/sqlalchemy/tests/orm_schema_tests.py.rst b/docs/source/autodoc/sqlalchemy/tests/orm_schema_tests.py.rst new file mode 100644 index 0000000..06f7fc5 --- /dev/null +++ b/docs/source/autodoc/sqlalchemy/tests/orm_schema_tests.py.rst @@ -0,0 +1,25 @@ +.. docs/source/autodoc/sqlalchemy/tests/orm_schema_tests.py.rst + +.. THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT. + + +.. Copyright (C) 2009-2020 Rudolf Cardinal (rudolf@pobox.com). + . + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + . + https://www.apache.org/licenses/LICENSE-2.0 + . + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +cardinal_pythonlib.sqlalchemy.tests.orm_schema_tests +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
automodule:: cardinal_pythonlib.sqlalchemy.tests.orm_schema_tests + :members: diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index ea0f726..22f565f 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -796,6 +796,9 @@ Quick links: .. _changelog_2024: +2024 +~~~~ + **1.1.26 (2024-03-03)** - Fix ``AttributeError: 'Engine' object has no attribute 'schema_for_object'`` @@ -809,4 +812,37 @@ Quick links: - Replace ugettext_* calls removed in Django 4.0. https://docs.djangoproject.com/en/4.2/releases/4.0/#features-removed-in-4-0 -**1.1.28 (in progress)** +.. _changelog_2025: + +2025 +~~~~ + +**2.0.0 (2025-01-07)** + +- Update for SQLAlchemy 2. + + ADDED: + + - cardinal_pythonlib.sqlalchemy.insert_on_duplicate.insert_with_upsert_if_supported + - cardinal_pythonlib.sqlalchemy.core_query.get_rows_fieldnames_from_select + + REMOVED: + + - cardinal_pythonlib.sqlalchemy.insert_on_duplicate.InsertOnDuplicate + + Use insert_with_upsert_if_supported() instead. + + - cardinal_pythonlib.sqlalchemy.orm_query.get_rows_fieldnames_from_query + + This will now raise NotImplementedError. Use + get_rows_fieldnames_from_select() instead. This reflects a core change in + SQLAlchemy 2, moving towards the use of select() statements for all + queries. + + SHOULDN'T BE NOTICEABLE: + + - cardinal_pythonlib.sqlalchemy.orm_query.CountStarSpecializedQuery has + changed type. But operation is as before, assuming all you did with it + was apply filters (if required) and execute. + + - Multiple internal changes to support SQLAlchemy 2. 
diff --git a/docs/source/conf.py b/docs/source/conf.py index 37489b6..9254629 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -199,7 +199,7 @@ # -- Options for autodoc extension ------------------------------------------- autoclass_content = "both" -# https://stackoverflow.com/questions/5599254/how-to-use-sphinxs-autodoc-to-document-a-classs-init-self-method # noqa +# https://stackoverflow.com/questions/5599254/how-to-use-sphinxs-autodoc-to-document-a-classs-init-self-method # noqa: E501 # ============================================================================= diff --git a/setup.py b/setup.py index dc053e3..29bd07b 100644 --- a/setup.py +++ b/setup.py @@ -66,7 +66,7 @@ "rich-argparse>=0.5.0", # colourful help "scipy", "semantic-version", - "SQLAlchemy>=1.4,<2.0", + "SQLAlchemy>=1.4,<3.0", "sqlparse", ]