From d78f0c07c1701fa9889350b9cee31ae188b7fd71 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Tue, 2 Jul 2024 15:24:48 +0200 Subject: [PATCH] Snow 1058245 SqlAlchemy 2.0 support (#469) SNOW-1058245-sqlalchemy-20-support: Add support for installation SQLAlchemy 2.0 --- .github/workflows/build_test.yml | 108 +++++++--- .github/workflows/create_req_files.yml | 6 +- .github/workflows/jira_close.yml | 2 +- .github/workflows/jira_comment.yml | 4 +- .github/workflows/jira_issue.yml | 4 +- .github/workflows/python-publish.yml | 2 +- .github/workflows/stale_issue_bot.yml | 2 +- DESCRIPTION.md | 6 +- pyproject.toml | 14 +- snyk/requirements.txt | 2 +- snyk/requiremtnts.txt | 2 + src/snowflake/sqlalchemy/base.py | 44 ++-- src/snowflake/sqlalchemy/compat.py | 36 ++++ src/snowflake/sqlalchemy/custom_commands.py | 3 +- src/snowflake/sqlalchemy/functions.py | 16 ++ src/snowflake/sqlalchemy/requirements.py | 16 ++ src/snowflake/sqlalchemy/snowdialect.py | 67 +++--- src/snowflake/sqlalchemy/util.py | 12 +- src/snowflake/sqlalchemy/version.py | 2 +- tests/conftest.py | 32 +-- tests/sqlalchemy_test_suite/conftest.py | 7 + tests/sqlalchemy_test_suite/test_suite.py | 4 + tests/sqlalchemy_test_suite/test_suite_20.py | 205 +++++++++++++++++++ tests/test_compiler.py | 2 +- tests/test_core.py | 85 +++----- tests/test_custom_functions.py | 25 +++ tests/test_orm.py | 42 ++-- tests/test_pandas.py | 11 +- tests/test_qmark.py | 4 +- tox.ini | 10 +- 30 files changed, 558 insertions(+), 217 deletions(-) create mode 100644 snyk/requiremtnts.txt create mode 100644 src/snowflake/sqlalchemy/compat.py create mode 100644 src/snowflake/sqlalchemy/functions.py create mode 100644 tests/sqlalchemy_test_suite/test_suite_20.py create mode 100644 tests/test_custom_functions.py diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index be19f1f1..3baa6a0d 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -33,8 +33,8 @@ jobs: python-version: 
'3.8' - name: Upgrade and install tools run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch python -m hatch env create default - name: Set PY run: echo "PY=$(hatch run gh-cache-sum)" >> $GITHUB_ENV @@ -49,6 +49,10 @@ jobs: name: Test package build and installation runs-on: ubuntu-latest needs: lint + strategy: + fail-fast: true + matrix: + hatch-env: [default, sa20] steps: - uses: actions/checkout@v4 with: @@ -59,15 +63,14 @@ jobs: python-version: '3.8' - name: Upgrade and install tools run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch - name: Build package run: | - python -m hatch clean - python -m hatch build + python -m hatch -e ${{ matrix.hatch-env }} build --clean - name: Install and check import run: | - python -m pip install dist/snowflake_sqlalchemy-*.whl + python -m uv pip install dist/snowflake_sqlalchemy-*.whl python -c "import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)" test-dialect: @@ -79,7 +82,7 @@ jobs: matrix: os: [ ubuntu-latest, - macos-latest, + macos-13, windows-latest, ] python-version: ["3.8"] @@ -98,8 +101,8 @@ jobs: python-version: ${{ matrix.python-version }} - name: Upgrade pip and prepare environment run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch python -m hatch env create default - name: Setup parameters file shell: bash @@ -125,7 +128,7 @@ jobs: matrix: os: [ ubuntu-latest, - macos-latest, + macos-13, windows-latest, ] python-version: ["3.8"] @@ -144,8 +147,8 @@ jobs: python-version: ${{ matrix.python-version }} - name: Upgrade pip and install hatch run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch python -m hatch env create default - name: Setup parameters file shell: bash 
@@ -162,8 +165,8 @@ jobs: path: | ./coverage.xml - test-dialect-run-v20: - name: Test dialect run v20 ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + test-dialect-v20: + name: Test dialect v20 ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} needs: [ lint, build-install ] runs-on: ${{ matrix.os }} strategy: @@ -171,7 +174,7 @@ jobs: matrix: os: [ ubuntu-latest, - macos-latest, + macos-13, windows-latest, ] python-version: ["3.8"] @@ -197,21 +200,67 @@ jobs: .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py - name: Upgrade pip and install hatch run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch python -m hatch env create default - name: Run tests - run: hatch run test-run_v20 + run: hatch run sa20:test-dialect - uses: actions/upload-artifact@v4 with: - name: coverage.xml_dialect-run-20-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + name: coverage.xml_dialect-v20-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + path: | + ./coverage.xml + + test-dialect-compatibility-v20: + name: Test dialect v20 compatibility ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + needs: lint + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ + ubuntu-latest, + macos-13, + windows-latest, + ] + python-version: ["3.8"] + cloud-provider: [ + aws, + azure, + gcp, + ] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Upgrade pip and install hatch + run: | + python -m pip install -U uv + python -m uv pip install -U hatch + python -m hatch env create default + - name: Setup parameters file + shell: bash + env: + PARAMETERS_SECRET: ${{ 
secrets.PARAMETERS_SECRET }} + run: | + gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ + .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Run tests + run: hatch run sa20:test-dialect-compatibility + - uses: actions/upload-artifact@v4 + with: + name: coverage.xml_dialect-v20-compatibility-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} path: | ./coverage.xml combine-coverage: name: Combine coverage if: ${{ success() || failure() }} - needs: [test-dialect, test-dialect-compatibility, test-dialect-run-v20] + needs: [test-dialect, test-dialect-compatibility, test-dialect-v20, test-dialect-compatibility-v20] runs-on: ubuntu-latest steps: - name: Set up Python @@ -220,8 +269,8 @@ jobs: python-version: "3.8" - name: Prepare environment run: | - pip install -U pip - pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch hatch env create default - uses: actions/checkout@v4 with: @@ -233,22 +282,15 @@ jobs: run: | hatch run coverage combine -a artifacts/coverage.xml_*/coverage.xml hatch run coverage report -m - hatch run coverage xml -o combined_coverage.xml - hatch run coverage html -d htmlcov - name: Store coverage reports uses: actions/upload-artifact@v4 with: - name: combined_coverage.xml - path: combined_coverage.xml - - name: Store htmlcov report - uses: actions/upload-artifact@v4 - with: - name: combined_htmlcov - path: htmlcov + name: coverage.xml + path: coverage.xml - name: Uplaod to codecov uses: codecov/codecov-action@v4 with: - file: combined_coverage.xml + file: coverage.xml env_vars: OS,PYTHON fail_ci_if_error: false flags: unittests diff --git a/.github/workflows/create_req_files.yml b/.github/workflows/create_req_files.yml index 618b3024..2cb7a371 100644 --- a/.github/workflows/create_req_files.yml +++ b/.github/workflows/create_req_files.yml @@ -21,10 +21,10 @@ jobs: - name: Display Python version run: python -c 
"import sys; print(sys.version)" - name: Upgrade setuptools, pip and wheel - run: python -m pip install -U setuptools pip wheel + run: python -m pip install -U setuptools pip wheel uv - name: Install Snowflake SQLAlchemy shell: bash - run: python -m pip install . + run: python -m uv pip install . - name: Generate reqs file name shell: bash run: echo "requirements_file=temp_requirement/requirements_$(python -c 'from sys import version_info;print(str(version_info.major)+str(version_info.minor))').reqs" >> $GITHUB_ENV @@ -34,7 +34,7 @@ jobs: mkdir temp_requirement echo "# Generated on: $(python --version)" >${{ env.requirements_file }} python -m pip freeze | grep -v snowflake-sqlalchemy 1>>${{ env.requirements_file }} 2>/dev/null - echo "snowflake-sqlalchemy==$(python -m pip show snowflake-sqlalchemy | grep ^Version | cut -d' ' -f2-)" >>${{ env.requirements_file }} + echo "snowflake-sqlalchemy==$(python -m uv pip show snowflake-sqlalchemy | grep ^Version | cut -d' ' -f2-)" >>${{ env.requirements_file }} id: create-reqs-file - name: Show created req file shell: bash diff --git a/.github/workflows/jira_close.yml b/.github/workflows/jira_close.yml index 5b170d75..7862f483 100644 --- a/.github/workflows/jira_close.yml +++ b/.github/workflows/jira_close.yml @@ -17,7 +17,7 @@ jobs: token: ${{ secrets.SNOWFLAKE_GITHUB_TOKEN }} # stored in GitHub secrets path: . 
- name: Jira login - uses: atlassian/gajira-login@master + uses: atlassian/gajira-login@v3 env: JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} diff --git a/.github/workflows/jira_comment.yml b/.github/workflows/jira_comment.yml index 954929fa..8533c14c 100644 --- a/.github/workflows/jira_comment.yml +++ b/.github/workflows/jira_comment.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Jira login - uses: atlassian/gajira-login@master + uses: atlassian/gajira-login@v3 env: JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} @@ -22,7 +22,7 @@ jobs: jira=$(echo -n $TITLE | awk '{print $1}' | sed -e 's/://') echo ::set-output name=jira::$jira - name: Comment on issue - uses: atlassian/gajira-comment@master + uses: atlassian/gajira-comment@v3 if: startsWith(steps.extract.outputs.jira, 'SNOW-') with: issue: "${{ steps.extract.outputs.jira }}" diff --git a/.github/workflows/jira_issue.yml b/.github/workflows/jira_issue.yml index 31b93aae..85c774ca 100644 --- a/.github/workflows/jira_issue.yml +++ b/.github/workflows/jira_issue.yml @@ -23,7 +23,7 @@ jobs: path: . 
- name: Login - uses: atlassian/gajira-login@v2.0.0 + uses: atlassian/gajira-login@v3 env: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} @@ -31,7 +31,7 @@ jobs: - name: Create JIRA Ticket id: create - uses: atlassian/gajira-create@v2.0.1 + uses: atlassian/gajira-create@v3 with: project: SNOW issuetype: Bug diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index ab4be45b..23116e7a 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -35,7 +35,7 @@ jobs: - name: Build package run: python -m build - name: Publish package - uses: pypa/gh-action-pypi-publish@release/v1 + uses: pypa/gh-action-pypi-publish@e53eb8b103ffcb59469888563dc324e3c8ba6f06 with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/stale_issue_bot.yml b/.github/workflows/stale_issue_bot.yml index 6d76e9f4..4ee56ff8 100644 --- a/.github/workflows/stale_issue_bot.yml +++ b/.github/workflows/stale_issue_bot.yml @@ -10,7 +10,7 @@ jobs: stale: runs-on: ubuntu-latest steps: - - uses: actions/stale@v7 + - uses: actions/stale@v9 with: close-issue-message: 'To clean up and re-prioritize bugs and feature requests we are closing all issues older than 6 months as of Apr 1, 2023. If there are any issues or feature requests that you would like us to address, please re-create them. 
For urgent issues, opening a support case with this link [Snowflake Community](https://community.snowflake.com/s/article/How-To-Submit-a-Support-Case-in-Snowflake-Lodge) is the fastest way to get a response' days-before-issue-stale: ${{ inputs.staleDays }} diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 2f228781..8b4dcd37 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,13 +9,17 @@ Source code is also available at: # Release Notes +- v1.6.0(Not released) + + - support for installing with SQLAlchemy 2.0.x + - v1.5.4 - Add ability to set ORDER / NOORDER sequence on columns with IDENTITY - v1.5.3(April 16, 2024) - - Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0 + - Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0 - v1.5.2(April 11, 2024) diff --git a/pyproject.toml b/pyproject.toml index 3f95df46..d2316a44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Application Frameworks", "Topic :: Software Development :: Libraries :: Python Modules", ] -dependencies = ["SQLAlchemy>=1.4.19,<2.0.0", "snowflake-connector-python<4.0.0"] +dependencies = ["SQLAlchemy>=1.4.19", "snowflake-connector-python<4.0.0"] [tool.hatch.version] path = "src/snowflake/sqlalchemy/version.py" @@ -73,8 +73,14 @@ exclude = ["/.github"] packages = ["src/snowflake"] [tool.hatch.envs.default] +extra-dependencies = ["SQLAlchemy>=1.4.19,<2.0.0"] features = ["development", "pandas"] python = "3.8" +installer = "uv" + +[tool.hatch.envs.sa20] +extra-dependencies = ["SQLAlchemy>=1.4.19,<=2.1.0"] +python = "3.8" [tool.hatch.envs.default.env-vars] COVERAGE_FILE = "coverage.xml" @@ -82,10 +88,10 @@ SQLACHEMY_WARN_20 = "1" [tool.hatch.envs.default.scripts] check = "pre-commit run --all-files" -test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite" -test-dialect-compatibility = "pytest 
-ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite" -test-run_v20 = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite --run_v20_sqlalchemy" +test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" +test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite tests/" gh-cache-sum = "python -VV | sha256sum | cut -d' ' -f1" +check-import = "python -c 'import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)'" [tool.ruff] line-length = 88 diff --git a/snyk/requirements.txt b/snyk/requirements.txt index 3a77e0f9..0166d751 100644 --- a/snyk/requirements.txt +++ b/snyk/requirements.txt @@ -1,2 +1,2 @@ -SQLAlchemy>=1.4.19,<2.0.0 +SQLAlchemy>=1.4.19 snowflake-connector-python<4.0.0 diff --git a/snyk/requiremtnts.txt b/snyk/requiremtnts.txt new file mode 100644 index 00000000..a92c527e --- /dev/null +++ b/snyk/requiremtnts.txt @@ -0,0 +1,2 @@ +snowflake-connector-python<4.0.0 +SQLAlchemy>=1.4.19,<2.1.0 diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index e008c92f..1aaa881e 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -13,13 +13,14 @@ from sqlalchemy.orm import context from sqlalchemy.orm.context import _MapperEntity from sqlalchemy.schema import Sequence, Table -from sqlalchemy.sql import compiler, expression +from sqlalchemy.sql import compiler, expression, functions from sqlalchemy.sql.base import CompileState from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.selectable import Lateral, SelectState -from sqlalchemy.util.compat import string_types +from .compat import IS_VERSION_20, args_reducer, string_types from .custom_commands import AWSBucket, 
AzureContainer, ExternalStage +from .functions import flatten from .util import ( _find_left_clause_to_join_from, _set_connection_interpolate_empty_sequences, @@ -324,17 +325,9 @@ def _join_determine_implicit_left_side( return left, replace_from_obj_index, use_entity_index + @args_reducer(positions_to_drop=(6, 7)) def _join_left_to_right( - self, - entities_collection, - left, - right, - onclause, - prop, - create_aliases, - aliased_generation, - outerjoin, - full, + self, entities_collection, left, right, onclause, prop, outerjoin, full ): """given raw "left", "right", "onclause" parameters consumed from a particular key within _join(), add a real ORMJoin object to @@ -364,7 +357,7 @@ def _join_left_to_right( use_entity_index, ) = self._join_place_explicit_left_side(entities_collection, left) - if left is right and not create_aliases: + if left is right: raise sa_exc.InvalidRequestError( "Can't construct a join from %s to %s, they " "are the same entity" % (left, right) @@ -373,9 +366,15 @@ def _join_left_to_right( # the right side as given often needs to be adapted. additionally # a lot of things can be wrong with it. 
handle all that and # get back the new effective "right" side - r_info, right, onclause = self._join_check_and_adapt_right_side( - left, right, onclause, prop, create_aliases, aliased_generation - ) + + if IS_VERSION_20: + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop + ) + else: + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop, False, False + ) if not r_info.is_selectable: extra_criteria = self._get_extra_criteria(r_info) @@ -979,24 +978,23 @@ def visit_identity_column(self, identity, **kw): def get_identity_options(self, identity_options): text = [] if identity_options.increment is not None: - text.append(f"INCREMENT BY {identity_options.increment:d}") + text.append("INCREMENT BY %d" % identity_options.increment) if identity_options.start is not None: - text.append(f"START WITH {identity_options.start:d}") + text.append("START WITH %d" % identity_options.start) if identity_options.minvalue is not None: - text.append(f"MINVALUE {identity_options.minvalue:d}") + text.append("MINVALUE %d" % identity_options.minvalue) if identity_options.maxvalue is not None: - text.append(f"MAXVALUE {identity_options.maxvalue:d}") + text.append("MAXVALUE %d" % identity_options.maxvalue) if identity_options.nominvalue is not None: text.append("NO MINVALUE") if identity_options.nomaxvalue is not None: text.append("NO MAXVALUE") if identity_options.cache is not None: - text.append(f"CACHE {identity_options.cache:d}") + text.append("CACHE %d" % identity_options.cache) if identity_options.cycle is not None: text.append("CYCLE" if identity_options.cycle else "NO CYCLE") if identity_options.order is not None: text.append("ORDER" if identity_options.order else "NOORDER") - return " ".join(text) @@ -1066,3 +1064,5 @@ def visit_GEOMETRY(self, type_, **kw): construct_arguments = [(Table, {"clusterby": None})] + +functions.register_function("flatten", flatten) diff --git 
a/src/snowflake/sqlalchemy/compat.py b/src/snowflake/sqlalchemy/compat.py new file mode 100644 index 00000000..9e97e574 --- /dev/null +++ b/src/snowflake/sqlalchemy/compat.py @@ -0,0 +1,36 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +from __future__ import annotations + +import functools +from typing import Callable + +from sqlalchemy import __version__ as SA_VERSION +from sqlalchemy import util + +string_types = (str,) +returns_unicode = util.symbol("RETURNS_UNICODE") + +IS_VERSION_20 = tuple(int(v) for v in SA_VERSION.split(".")) >= (2, 0, 0) + + +def args_reducer(positions_to_drop: tuple): + """Removes args at positions provided in tuple positions_to_drop. + + For example tuple (3, 5) will remove items at third and fifth position. + Keep in mind that on class methods first postion is cls or self. + """ + + def fn_wrapper(fn: Callable): + @functools.wraps(fn) + def wrapper(*args): + reduced_args = args + if not IS_VERSION_20: + reduced_args = tuple( + arg for idx, arg in enumerate(args) if idx not in positions_to_drop + ) + fn(*reduced_args) + + return wrapper + + return fn_wrapper diff --git a/src/snowflake/sqlalchemy/custom_commands.py b/src/snowflake/sqlalchemy/custom_commands.py index cec16673..15585bd5 100644 --- a/src/snowflake/sqlalchemy/custom_commands.py +++ b/src/snowflake/sqlalchemy/custom_commands.py @@ -10,7 +10,8 @@ from sqlalchemy.sql.dml import UpdateBase from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.roles import FromClauseRole -from sqlalchemy.util.compat import string_types + +from .compat import string_types NoneType = type(None) diff --git a/src/snowflake/sqlalchemy/functions.py b/src/snowflake/sqlalchemy/functions.py new file mode 100644 index 00000000..c08aa734 --- /dev/null +++ b/src/snowflake/sqlalchemy/functions.py @@ -0,0 +1,16 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+ +import warnings + +from sqlalchemy.sql import functions as sqlfunc + +FLATTEN_WARNING = "For backward compatibility params are not rendered." + + +class flatten(sqlfunc.GenericFunction): + name = "flatten" + + def __init__(self, *args, **kwargs): + warnings.warn(FLATTEN_WARNING, DeprecationWarning, stacklevel=2) + super().__init__(*args, **kwargs) diff --git a/src/snowflake/sqlalchemy/requirements.py b/src/snowflake/sqlalchemy/requirements.py index ea30a823..f2844804 100644 --- a/src/snowflake/sqlalchemy/requirements.py +++ b/src/snowflake/sqlalchemy/requirements.py @@ -289,9 +289,25 @@ def datetime_implicit_bound(self): # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. return exclusions.closed() + @property + def date_implicit_bound(self): + # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding + # parameters in string forms of timestamp values. + return exclusions.closed() + + @property + def time_implicit_bound(self): + # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding + # parameters in string forms of timestamp values. + return exclusions.closed() + @property def timestamp_microseconds_implicit_bound(self): # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding # parameters in string forms of timestamp values. # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. 
return exclusions.closed() + + @property + def array_type(self): + return exclusions.closed() diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 2e40d03c..04305a00 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -5,6 +5,7 @@ import operator from collections import defaultdict from functools import reduce +from typing import Any from urllib.parse import unquote_plus import sqlalchemy.types as sqltypes @@ -15,7 +16,6 @@ from sqlalchemy.schema import Table from sqlalchemy.sql import text from sqlalchemy.sql.elements import quoted_name -from sqlalchemy.sql.sqltypes import String from sqlalchemy.types import ( BIGINT, BINARY, @@ -40,6 +40,7 @@ from snowflake.connector import errors as sf_errors from snowflake.connector.connection import DEFAULT_CONFIGURATION from snowflake.connector.constants import UTF8 +from snowflake.sqlalchemy.compat import returns_unicode from .base import ( SnowflakeCompiler, @@ -63,7 +64,11 @@ _CUSTOM_Float, _CUSTOM_Time, ) -from .util import _update_connection_application_name, parse_url_boolean +from .util import ( + _update_connection_application_name, + parse_url_boolean, + parse_url_integer, +) colspecs = { Date: _CUSTOM_Date, @@ -134,7 +139,7 @@ class SnowflakeDialect(default.DefaultDialect): # unicode strings supports_unicode_statements = True supports_unicode_binds = True - returns_unicode_strings = String.RETURNS_UNICODE + returns_unicode_strings = returns_unicode description_encoding = None # No lastrowid support. 
See SNOW-11155 @@ -195,10 +200,34 @@ class SnowflakeDialect(default.DefaultDialect): @classmethod def dbapi(cls): + return cls.import_dbapi() + + @classmethod + def import_dbapi(cls): from snowflake import connector return connector + @staticmethod + def parse_query_param_type(name: str, value: Any) -> Any: + """Cast param value if possible to type defined in connector-python.""" + if not (maybe_type_configuration := DEFAULT_CONFIGURATION.get(name)): + return value + + _, expected_type = maybe_type_configuration + if not isinstance(expected_type, tuple): + expected_type = (expected_type,) + + if isinstance(value, expected_type): + return value + + elif bool in expected_type: + return parse_url_boolean(value) + elif int in expected_type: + return parse_url_integer(value) + else: + return value + def create_connect_args(self, url: URL): opts = url.translate_connect_args(username="user") if "database" in opts: @@ -235,47 +264,25 @@ def create_connect_args(self, url: URL): # URL sets the query parameter values as strings, we need to cast to expected types when necessary for name, value in query.items(): - maybe_type_configuration = DEFAULT_CONFIGURATION.get(name) - if ( - not maybe_type_configuration - ): # if the parameter is not found in the type mapping, pass it through as a string - opts[name] = value - continue - - (_, expected_type) = maybe_type_configuration - if not isinstance(expected_type, tuple): - expected_type = (expected_type,) - - if isinstance( - value, expected_type - ): # if the expected type is str, pass it through as a string - opts[name] = value - - elif ( - bool in expected_type - ): # if the expected type is bool, parse it and pass as a boolean - opts[name] = parse_url_boolean(value) - else: - # TODO: other types like int are stil passed through as string - # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/447 - opts[name] = value + opts[name] = self.parse_query_param_type(name, value) return ([], opts) - def has_table(self, 
connection, table_name, schema=None): + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): """ Checks if the table exists """ return self._has_object(connection, "TABLE", table_name, schema) - def has_sequence(self, connection, sequence_name, schema=None): + @reflection.cache + def has_sequence(self, connection, sequence_name, schema=None, **kw): """ Checks if the sequence exists """ return self._has_object(connection, "SEQUENCE", sequence_name, schema) def _has_object(self, connection, object_type, object_name, schema=None): - full_name = self._denormalize_quote_join(schema, object_name) try: results = connection.execute( diff --git a/src/snowflake/sqlalchemy/util.py b/src/snowflake/sqlalchemy/util.py index 32e07373..a1aefff9 100644 --- a/src/snowflake/sqlalchemy/util.py +++ b/src/snowflake/sqlalchemy/util.py @@ -7,7 +7,7 @@ from typing import Any from urllib.parse import quote_plus -from sqlalchemy import exc, inspection, sql, util +from sqlalchemy import exc, inspection, sql from sqlalchemy.exc import NoForeignKeysError from sqlalchemy.orm.interfaces import MapperProperty from sqlalchemy.orm.util import _ORMJoin as sa_orm_util_ORMJoin @@ -19,6 +19,7 @@ from snowflake.connector.compat import IS_STR from snowflake.connector.connection import SnowflakeConnection +from snowflake.sqlalchemy import compat from ._constants import ( APPLICATION_NAME, @@ -124,6 +125,13 @@ def parse_url_boolean(value: str) -> bool: raise ValueError(f"Invalid boolean value detected: '{value}'") +def parse_url_integer(value: str) -> int: + try: + return int(value) + except ValueError as e: + raise ValueError(f"Invalid int value detected: '{value}") from e + + # handle Snowflake BCR bcr-1057 # the BCR impacts sqlalchemy.orm.context.ORMSelectCompileState and sqlalchemy.sql.selectable.SelectState # which used the 'sqlalchemy.util.preloaded.sql_util.find_left_clause_to_join_from' method that @@ -212,7 +220,7 @@ def __init__( # then the "_joined_from_info" 
concept can go left_orm_info = getattr(left, "_joined_from_info", left_info) self._joined_from_info = right_info - if isinstance(onclause, util.string_types): + if isinstance(onclause, compat.string_types): onclause = getattr(left_orm_info.entity, onclause) # #### diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py index 61c9fc41..56509b7d 100644 --- a/src/snowflake/sqlalchemy/version.py +++ b/src/snowflake/sqlalchemy/version.py @@ -3,4 +3,4 @@ # # Update this for the versions # Don't change the forth version number from None -VERSION = "1.5.3" +VERSION = "1.6.0" diff --git a/tests/conftest.py b/tests/conftest.py index a9c2560a..d4dab3d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,21 +46,6 @@ TEST_SCHEMA = f"sqlalchemy_tests_{str(uuid.uuid4()).replace('-', '_')}" -def pytest_addoption(parser): - parser.addoption( - "--run_v20_sqlalchemy", - help="Use only 2.0 SQLAlchemy APIs, any legacy features (< 2.0) will not be supported." - "Turning on this option will set future flag to True on Engine and Session objects according to" - "the migration guide: https://docs.sqlalchemy.org/en/14/changelog/migration_20.html", - action="store_true", - ) - - -@pytest.fixture(scope="session") -def run_v20_sqlalchemy(pytestconfig): - return pytestconfig.option.run_v20_sqlalchemy - - @pytest.fixture(scope="session") def on_travis(): return os.getenv("TRAVIS", "").lower() == "true" @@ -160,20 +145,21 @@ def url_factory(**kwargs) -> URL: return URL(**url_params) -def get_engine(url: URL, run_v20_sqlalchemy=False, **engine_kwargs): +def get_engine(url: URL, **engine_kwargs): engine_params = { "poolclass": NullPool, - "future": run_v20_sqlalchemy, + "future": True, + "echo": True, } engine_params.update(engine_kwargs) - engine = create_engine(url, **engine_kwargs) + engine = create_engine(url, **engine_params) return engine @pytest.fixture() -def engine_testaccount(request, run_v20_sqlalchemy): +def engine_testaccount(request): url = 
url_factory() - engine = get_engine(url, run_v20_sqlalchemy=run_v20_sqlalchemy) + engine = get_engine(url) request.addfinalizer(engine.dispose) yield engine @@ -181,17 +167,17 @@ def engine_testaccount(request, run_v20_sqlalchemy): @pytest.fixture() def engine_testaccount_with_numpy(request): url = url_factory(numpy=True) - engine = get_engine(url, run_v20_sqlalchemy=run_v20_sqlalchemy) + engine = get_engine(url) request.addfinalizer(engine.dispose) yield engine @pytest.fixture() -def engine_testaccount_with_qmark(request, run_v20_sqlalchemy): +def engine_testaccount_with_qmark(request): snowflake.connector.paramstyle = "qmark" url = url_factory() - engine = get_engine(url, run_v20_sqlalchemy=run_v20_sqlalchemy) + engine = get_engine(url) request.addfinalizer(engine.dispose) yield engine diff --git a/tests/sqlalchemy_test_suite/conftest.py b/tests/sqlalchemy_test_suite/conftest.py index 31cd7c5c..f0464c7d 100644 --- a/tests/sqlalchemy_test_suite/conftest.py +++ b/tests/sqlalchemy_test_suite/conftest.py @@ -15,6 +15,7 @@ import snowflake.connector from snowflake.sqlalchemy import URL +from snowflake.sqlalchemy.compat import IS_VERSION_20 from ..conftest import get_db_parameters from ..util import random_string @@ -25,6 +26,12 @@ TEST_SCHEMA_2 = f"{TEST_SCHEMA}_2" +if IS_VERSION_20: + collect_ignore_glob = ["test_suite.py"] +else: + collect_ignore_glob = ["test_suite_20.py"] + + # patch sqlalchemy.testing.config.Confi.__init__ for schema name randomization # same schema name would result in conflict as we're running tests in parallel in the CI def config_patched__init__(self, db, db_opts, options, file_config): diff --git a/tests/sqlalchemy_test_suite/test_suite.py b/tests/sqlalchemy_test_suite/test_suite.py index d79e511e..643d1559 100644 --- a/tests/sqlalchemy_test_suite/test_suite.py +++ b/tests/sqlalchemy_test_suite/test_suite.py @@ -69,6 +69,10 @@ def test_empty_insert(self, connection): def test_empty_insert_multiple(self, connection): pass + 
@pytest.mark.skip("Snowflake does not support returning in insert.") + def test_no_results_for_non_returning_insert(self, connection, style, executemany): + pass + # 2. Patched Tests diff --git a/tests/sqlalchemy_test_suite/test_suite_20.py b/tests/sqlalchemy_test_suite/test_suite_20.py new file mode 100644 index 00000000..1f79c4e9 --- /dev/null +++ b/tests/sqlalchemy_test_suite/test_suite_20.py @@ -0,0 +1,205 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +from sqlalchemy import Integer, testing +from sqlalchemy.schema import Column, Sequence, Table +from sqlalchemy.testing import config +from sqlalchemy.testing.assertions import eq_ +from sqlalchemy.testing.suite import ( + BizarroCharacterFKResolutionTest as _BizarroCharacterFKResolutionTest, +) +from sqlalchemy.testing.suite import ( + CompositeKeyReflectionTest as _CompositeKeyReflectionTest, +) +from sqlalchemy.testing.suite import DateTimeHistoricTest as _DateTimeHistoricTest +from sqlalchemy.testing.suite import FetchLimitOffsetTest as _FetchLimitOffsetTest +from sqlalchemy.testing.suite import HasSequenceTest as _HasSequenceTest +from sqlalchemy.testing.suite import InsertBehaviorTest as _InsertBehaviorTest +from sqlalchemy.testing.suite import LikeFunctionsTest as _LikeFunctionsTest +from sqlalchemy.testing.suite import LongNameBlowoutTest as _LongNameBlowoutTest +from sqlalchemy.testing.suite import SimpleUpdateDeleteTest as _SimpleUpdateDeleteTest +from sqlalchemy.testing.suite import TimeMicrosecondsTest as _TimeMicrosecondsTest +from sqlalchemy.testing.suite import TrueDivTest as _TrueDivTest +from sqlalchemy.testing.suite import * # noqa + +# 1. 
Unsupported by snowflake db + +del ComponentReflectionTest # requires indexes not supported by snowflake +del HasIndexTest # requires indexes not supported by snowflake +del QuotedNameArgumentTest # requires indexes not supported by snowflake + + +class LongNameBlowoutTest(_LongNameBlowoutTest): + # The combination ("ix",) is removed due to Snowflake not supporting indexes + def ix(self, metadata, connection): + pytest.skip("ix requires index feature not supported by Snowflake") + + +class FetchLimitOffsetTest(_FetchLimitOffsetTest): + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_bound_offset(self, connection): + pass + + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_simple_limit_expr_offset(self, connection): + pass + + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_simple_offset(self, connection): + pass + + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_simple_offset_zero(self, connection): + pass + + +class InsertBehaviorTest(_InsertBehaviorTest): + @pytest.mark.skip( + "Snowflake does not support inserting empty values, the value may be a literal or an expression." + ) + def test_empty_insert(self, connection): + pass + + @pytest.mark.skip( + "Snowflake does not support inserting empty values, the value may be a literal or an expression." + ) + def test_empty_insert_multiple(self, connection): + pass + + @pytest.mark.skip("Snowflake does not support returning in insert.") + def test_no_results_for_non_returning_insert(self, connection, style, executemany): + pass + + +# road to 2.0 +class TrueDivTest(_TrueDivTest): + @pytest.mark.skip("`//` not supported") + def test_floordiv_integer_bound(self, connection): + """Snowflake does not provide `//` arithmetic operator.
+ + https://docs.snowflake.com/en/sql-reference/operators-arithmetic. + """ + pass + + @pytest.mark.skip("`//` not supported") + def test_floordiv_integer(self, connection, left, right, expected): + """Snowflake does not provide `//` arithmetic operator. + + https://docs.snowflake.com/en/sql-reference/operators-arithmetic. + """ + pass + + +class TimeMicrosecondsTest(_TimeMicrosecondsTest): + def __init__(self): + super().__init__() + + +class DateTimeHistoricTest(_DateTimeHistoricTest): + def __init__(self): + super().__init__() + + +# 2. Patched Tests + + +class HasSequenceTest(_HasSequenceTest): + # Override the define_tables method as snowflake does not support 'nomaxvalue'/'nominvalue' + @classmethod + def define_tables(cls, metadata): + Sequence("user_id_seq", metadata=metadata) + # Replace Sequence("other_seq") creation as in the original test suite, + # the Sequence created with 'nomaxvalue' and 'nominvalue' + # which snowflake does not support: + # Sequence("other_seq", metadata=metadata, nomaxvalue=True, nominvalue=True) + Sequence("other_seq", metadata=metadata) + if testing.requires.schemas.enabled: + Sequence("user_id_seq", schema=config.test_schema, metadata=metadata) + Sequence("schema_seq", schema=config.test_schema, metadata=metadata) + Table( + "user_id_table", + metadata, + Column("id", Integer, primary_key=True), + ) + + +class LikeFunctionsTest(_LikeFunctionsTest): + @testing.requires.regexp_match + @testing.combinations( + ("a.cde.*", {1, 5, 6, 9}), + ("abc.*", {1, 5, 6, 9, 10}), + ("^abc.*", {1, 5, 6, 9, 10}), + (".*9cde.*", {8}), + ("^a.*", set(range(1, 11))), + (".*(b|c).*", set(range(1, 11))), + ("^(b|c).*", set()), + ) + def test_regexp_match(self, text, expected): + super().test_regexp_match(text, expected) + + def test_not_regexp_match(self): + col = self.tables.some_table.c.data + self._test(~col.regexp_match("a.cde.*"), {2, 3, 4, 7, 8, 10}) + + +class SimpleUpdateDeleteTest(_SimpleUpdateDeleteTest): + def test_update(self, 
connection): + t = self.tables.plain_pk + r = connection.execute(t.update().where(t.c.id == 2), dict(data="d2_new")) + assert not r.is_insert + # snowflake returns a row with numbers of rows updated and number of multi-joined rows updated + assert r.returns_rows + assert r.rowcount == 1 + + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (2, "d2_new"), (3, "d3")], + ) + + def test_delete(self, connection): + t = self.tables.plain_pk + r = connection.execute(t.delete().where(t.c.id == 2)) + assert not r.is_insert + # snowflake returns a row with number of rows deleted + assert r.returns_rows + assert r.rowcount == 1 + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (3, "d3")], + ) + + +class CompositeKeyReflectionTest(_CompositeKeyReflectionTest): + @pytest.mark.xfail(reason="Fixing this would require behavior breaking change.") + def test_fk_column_order(self): + # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. + super().test_fk_column_order() + + @pytest.mark.xfail(reason="Fixing this would require behavior breaking change.") + def test_pk_column_order(self): + # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. 
+ super().test_pk_column_order() + + +class BizarroCharacterFKResolutionTest(_BizarroCharacterFKResolutionTest): + @testing.combinations( + ("id",), ("(3)",), ("col%p",), ("[brack]",), argnames="columnname" + ) + @testing.variation("use_composite", [True, False]) + @testing.combinations( + ("plain",), + ("(2)",), + ("[brackets]",), + argnames="tablename", + ) + def test_fk_ref(self, connection, metadata, use_composite, tablename, columnname): + super().test_fk_ref(connection, metadata, use_composite, tablename, columnname) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 0fd75c38..40207b41 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -5,7 +5,7 @@ from sqlalchemy import Integer, String, and_, func, select from sqlalchemy.schema import DropColumnComment, DropTableComment from sqlalchemy.sql import column, quoted_name, table -from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing.assertions import AssertsCompiledSQL from snowflake.sqlalchemy import snowdialect diff --git a/tests/test_core.py b/tests/test_core.py index 6c8d7416..179133c8 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -34,7 +34,7 @@ inspect, text, ) -from sqlalchemy.exc import DBAPIError, NoSuchTableError +from sqlalchemy.exc import DBAPIError, NoSuchTableError, OperationalError from sqlalchemy.sql import and_, not_, or_, select import snowflake.connector.errors @@ -406,16 +406,6 @@ def test_insert_tables(engine_testaccount): str(users.join(addresses)) == "users JOIN addresses ON " "users.id = addresses.user_id" ) - assert ( - str( - users.join( - addresses, - addresses.c.email_address.like(users.c.name + "%"), - ) - ) - == "users JOIN addresses " - "ON addresses.email_address LIKE users.name || :name_1" - ) s = select(users.c.fullname).select_from( users.join( @@ -444,7 +434,7 @@ def test_table_does_not_exist(engine_testaccount): """ meta = MetaData() with pytest.raises(NoSuchTableError): - Table("does_not_exist", meta, 
autoload=True, autoload_with=engine_testaccount) + Table("does_not_exist", meta, autoload_with=engine_testaccount) @pytest.mark.skip( @@ -470,9 +460,7 @@ def test_reflextion(engine_testaccount): ) try: meta = MetaData() - user_reflected = Table( - "user", meta, autoload=True, autoload_with=engine_testaccount - ) + user_reflected = Table("user", meta, autoload_with=engine_testaccount) assert user_reflected.c == ["user.id", "user.name", "user.fullname"] finally: conn.execute("DROP TABLE IF EXISTS user") @@ -1071,28 +1059,15 @@ def harass_inspector(): assert outcome -@pytest.mark.timeout(15) -def test_region(): - engine = create_engine( - URL( - user="testuser", - password="testpassword", - account="testaccount", - region="eu-central-1", - login_timeout=5, - ) - ) - try: - engine.connect() - pytest.fail("should not run") - except Exception as ex: - assert ex.orig.errno == 250001 - assert "Failed to connect to DB" in ex.orig.msg - assert "testaccount.eu-central-1.snowflakecomputing.com" in ex.orig.msg - - -@pytest.mark.timeout(15) -def test_azure(): +@pytest.mark.timeout(10) +@pytest.mark.parametrize( + "region", + ( + pytest.param("eu-central-1", id="region"), + pytest.param("east-us-2.azure", id="azure"), + ), +) +def test_connection_timeout_error(region): engine = create_engine( URL( user="testuser", @@ -1102,13 +1077,13 @@ def test_azure(): login_timeout=5, ) ) - try: + + with pytest.raises(OperationalError) as excinfo: engine.connect() - pytest.fail("should not run") - except Exception as ex: - assert ex.orig.errno == 250001 - assert "Failed to connect to DB" in ex.orig.msg - assert "testaccount.east-us-2.azure.snowflakecomputing.com" in ex.orig.msg + + assert excinfo.value.orig.errno == 250001 + assert "Could not connect to Snowflake backend" in excinfo.value.orig.msg + assert region not in excinfo.value.orig.msg def test_load_dialect(): @@ -1535,11 +1510,16 @@ def test_too_many_columns_detection(engine_testaccount, db_parameters): 
metadata.create_all(engine_testaccount) inspector = inspect(engine_testaccount) # Do test - original_execute = inspector.bind.execute + connection = inspector.bind.connect() + original_execute = connection.execute + + too_many_columns_was_raised = False def mock_helper(command, *args, **kwargs): - if "_get_schema_columns" in command: + if "_get_schema_columns" in command.text: # Creating exception exactly how SQLAlchemy does + nonlocal too_many_columns_was_raised + too_many_columns_was_raised = True raise DBAPIError.instance( """ SELECT /* sqlalchemy:_get_schema_columns */ @@ -1571,9 +1551,12 @@ def mock_helper(command, *args, **kwargs): else: return original_execute(command, *args, **kwargs) - with patch.object(inspector.bind, "execute", side_effect=mock_helper): - column_metadata = inspector.get_columns("users", db_parameters["schema"]) + with patch.object(engine_testaccount, "connect") as conn: + conn.return_value = connection + with patch.object(connection, "execute", side_effect=mock_helper): + column_metadata = inspector.get_columns("users", db_parameters["schema"]) assert len(column_metadata) == 4 + assert too_many_columns_was_raised # Clean up metadata.drop_all(engine_testaccount) @@ -1615,9 +1598,7 @@ def test_column_type_schema(engine_testaccount): """ ) - table_reflected = Table( - table_name, MetaData(), autoload=True, autoload_with=conn - ) + table_reflected = Table(table_name, MetaData(), autoload_with=conn) columns = table_reflected.columns assert ( len(columns) == len(ischema_names_baseline) - 1 @@ -1638,9 +1619,7 @@ def test_result_type_and_value(engine_testaccount): ) """ ) - table_reflected = Table( - table_name, MetaData(), autoload=True, autoload_with=conn - ) + table_reflected = Table(table_name, MetaData(), autoload_with=conn) current_date = date.today() current_utctime = datetime.utcnow() current_localtime = pytz.utc.localize(current_utctime, is_dst=False).astimezone( diff --git a/tests/test_custom_functions.py 
b/tests/test_custom_functions.py new file mode 100644 index 00000000..2a1e1cb5 --- /dev/null +++ b/tests/test_custom_functions.py @@ -0,0 +1,25 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +import pytest +from sqlalchemy import func + +from snowflake.sqlalchemy import snowdialect + + +def test_flatten_does_not_render_params(): + """This behavior is for backward compatibility. + + In previous version params were not rendered. + In future this behavior will change. + """ + flat = func.flatten("[1, 2]", outer=True) + res = flat.compile(dialect=snowdialect.dialect()) + + assert str(res) == "flatten(%(flatten_1)s)" + + +def test_flatten_emits_warning(): + expected_warning = "For backward compatibility params are not rendered." + with pytest.warns(DeprecationWarning, match=expected_warning): + func.flatten().compile(dialect=snowdialect.dialect()) diff --git a/tests/test_orm.py b/tests/test_orm.py index e485d737..f53cd708 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -20,7 +20,7 @@ from sqlalchemy.orm import Session, declarative_base, relationship -def test_basic_orm(engine_testaccount, run_v20_sqlalchemy): +def test_basic_orm(engine_testaccount): """ Tests declarative """ @@ -46,7 +46,6 @@ def __repr__(self): ed_user = User(name="ed", fullname="Edward Jones") session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(ed_user) our_user = session.query(User).filter_by(name="ed").first() @@ -56,7 +55,7 @@ def __repr__(self): Base.metadata.drop_all(engine_testaccount) -def test_orm_one_to_many_relationship(engine_testaccount, run_v20_sqlalchemy): +def test_orm_one_to_many_relationship(engine_testaccount): """ Tests One to Many relationship """ @@ -97,7 +96,6 @@ def __repr__(self): ] session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(jack) # cascade each Address into the Session as well session.commit() @@ -124,7 +122,7 @@ def __repr__(self): 
Base.metadata.drop_all(engine_testaccount) -def test_delete_cascade(engine_testaccount, run_v20_sqlalchemy): +def test_delete_cascade(engine_testaccount): """ Test delete cascade """ @@ -169,7 +167,6 @@ def __repr__(self): ] session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(jack) # cascade each Address into the Session as well session.commit() @@ -189,7 +186,7 @@ def __repr__(self): WIP """, ) -def test_orm_query(engine_testaccount, run_v20_sqlalchemy): +def test_orm_query(engine_testaccount): """ Tests ORM query """ @@ -210,7 +207,6 @@ def __repr__(self): # TODO: insert rows session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy # TODO: query.all() for name, fullname in session.query(User.name, User.fullname): @@ -220,7 +216,7 @@ def __repr__(self): # MultipleResultsFound if not one result -def test_schema_including_db(engine_testaccount, db_parameters, run_v20_sqlalchemy): +def test_schema_including_db(engine_testaccount, db_parameters): """ Test schema parameter including database separated by a dot. """ @@ -243,7 +239,6 @@ class User(Base): ed_user = User(name="ed", fullname="Edward Jones") session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(ed_user) ret_user = session.query(User.id, User.name).first() @@ -255,7 +250,7 @@ class User(Base): Base.metadata.drop_all(engine_testaccount) -def test_schema_including_dot(engine_testaccount, db_parameters, run_v20_sqlalchemy): +def test_schema_including_dot(engine_testaccount, db_parameters): """ Tests pseudo schema name including dot. 
""" @@ -276,7 +271,6 @@ class User(Base): fullname = Column(String) session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy query = session.query(User.id) assert str(query).startswith( 'SELECT {db}."{schema}.{schema}".{db}.users.id'.format( @@ -285,9 +279,7 @@ class User(Base): ) -def test_schema_translate_map( - engine_testaccount, db_parameters, sql_compiler, run_v20_sqlalchemy -): +def test_schema_translate_map(engine_testaccount, db_parameters): """ Test schema translate map execution option works replaces schema correctly """ @@ -310,7 +302,6 @@ class User(Base): schema_translate_map={schema_map: db_parameters["schema"]} ) as con: session = Session(bind=con) - session.future = run_v20_sqlalchemy with con.begin(): Base.metadata.create_all(con) try: @@ -367,18 +358,29 @@ class Department(Base): .select_from(Employee) .outerjoin(sub) ) - assert ( - str(query.compile(engine_testaccount)).replace("\n", "") - == "SELECT employees.employee_id, departments.department_id " + compiled_stmts = ( + # v1.x + "SELECT employees.employee_id, departments.department_id " "FROM departments, employees LEFT OUTER JOIN LATERAL " "(SELECT departments.department_id AS department_id, departments.name AS name " - "FROM departments) AS anon_1" + "FROM departments) AS anon_1", + # v2.x + "SELECT employees.employee_id, departments.department_id " + "FROM employees LEFT OUTER JOIN LATERAL " + "(SELECT departments.department_id AS department_id, departments.name AS name " + "FROM departments) AS anon_1, departments", ) + compiled_stmt = str(query.compile(engine_testaccount)).replace("\n", "") + assert compiled_stmt in compiled_stmts + with caplog.at_level(logging.DEBUG): assert [res for res in session.execute(query)] assert ( "SELECT employees.employee_id, departments.department_id FROM departments" in caplog.text + ) or ( + "SELECT employees.employee_id, departments.department_id FROM employees" + in caplog.text ) diff --git a/tests/test_pandas.py 
b/tests/test_pandas.py index ef64d65e..63cd6d0e 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -27,6 +27,7 @@ from snowflake.connector import ProgrammingError from snowflake.connector.pandas_tools import make_pd_writer, pd_writer +from snowflake.sqlalchemy.compat import IS_VERSION_20 def _create_users_addresses_tables(engine_testaccount, metadata): @@ -240,8 +241,8 @@ def test_timezone(db_parameters, engine_testaccount, engine_testaccount_with_num conn.exec_driver_sql(f"DROP TABLE {test_table_name};") -def test_pandas_writeback(engine_testaccount, run_v20_sqlalchemy): - if run_v20_sqlalchemy and sys.version_info < (3, 8): +def test_pandas_writeback(engine_testaccount): + if IS_VERSION_20 and sys.version_info < (3, 8): pytest.skip( "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlachemy 2.0, and pandas does not support Python 3.7 anymore." ) @@ -352,8 +353,8 @@ def test_pandas_invalid_make_pd_writer(engine_testaccount): ) -def test_percent_signs(engine_testaccount, run_v20_sqlalchemy): - if run_v20_sqlalchemy and sys.version_info < (3, 8): +def test_percent_signs(engine_testaccount): + if IS_VERSION_20 and sys.version_info < (3, 8): pytest.skip( "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlachemy 2.0, and pandas does not support Python 3.7 anymore." 
) @@ -376,7 +377,7 @@ def test_percent_signs(engine_testaccount, run_v20_sqlalchemy): not_like_sql = f"select * from {table_name} where c2 not like '%b%'" like_sql = f"select * from {table_name} where c2 like '%b%'" calculate_sql = "SELECT 1600 % 400 AS a, 1599 % 400 as b" - if run_v20_sqlalchemy: + if IS_VERSION_20: not_like_sql = sqlalchemy.text(not_like_sql) like_sql = sqlalchemy.text(like_sql) calculate_sql = sqlalchemy.text(calculate_sql) diff --git a/tests/test_qmark.py b/tests/test_qmark.py index f98fa7d3..3761181a 100644 --- a/tests/test_qmark.py +++ b/tests/test_qmark.py @@ -12,11 +12,11 @@ THIS_DIR = os.path.dirname(os.path.realpath(__file__)) -def test_qmark_bulk_insert(run_v20_sqlalchemy, engine_testaccount_with_qmark): +def test_qmark_bulk_insert(engine_testaccount_with_qmark): """ Bulk insert using qmark paramstyle """ - if run_v20_sqlalchemy and sys.version_info < (3, 8): + if sys.version_info < (3, 8): pytest.skip( "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlachemy 2.0, and pandas does not support Python 3.7 anymore." 
) diff --git a/tox.ini b/tox.ini index 0c1cb483..7f605627 100644 --- a/tox.ini +++ b/tox.ini @@ -34,7 +34,7 @@ passenv = setenv = COVERAGE_FILE = {env:COVERAGE_FILE:{toxworkdir}/.coverage.{envname}} SQLALCHEMY_WARN_20 = 1 - ci: SNOWFLAKE_PYTEST_OPTS = -vvv + ci: SNOWFLAKE_PYTEST_OPTS = -vvv --tb=long commands = pytest \ {env:SNOWFLAKE_PYTEST_OPTS:} \ --cov "snowflake.sqlalchemy" \ @@ -44,12 +44,6 @@ commands = pytest \ --cov "snowflake.sqlalchemy" --cov-append \ --junitxml {toxworkdir}/junit_{envname}.xml \ {posargs:tests/sqlalchemy_test_suite} - pytest \ - {env:SNOWFLAKE_PYTEST_OPTS:} \ - --cov "snowflake.sqlalchemy" --cov-append \ - --junitxml {toxworkdir}/junit_{envname}.xml \ - --run_v20_sqlalchemy \ - {posargs:tests} [testenv:.pkg_external] deps = build @@ -86,7 +80,7 @@ commands = pre-commit run --all-files python -c 'import pathlib; print("hint: run \{\} install to add checks as pre-commit hook".format(pathlib.Path(r"{envdir}") / "bin" / "pre-commit"))' [pytest] -addopts = -ra --strict-markers --ignore=tests/sqlalchemy_test_suite +addopts = -ra --ignore=tests/sqlalchemy_test_suite junit_family = legacy log_level = info markers =