Skip to content

Commit

Permalink
fix: Fixing test for literals due to change in sqlalchemy core tests (#…
Browse files Browse the repository at this point in the history
…384)

* fix: Fixing test for literals due to change in sqlalchemy core tests

* tests: remove editable install in tests

* One more literal test fix
  • Loading branch information
ankiaga authored Feb 1, 2024
1 parent 00561f8 commit 62cccc3
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 17 deletions.
1 change: 1 addition & 0 deletions .github/sync-repo-settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ branchProtectionRules:
- 'unit'
- 'compliance_tests_13'
- 'compliance_tests_14'
- 'compliance_tests_20'
- 'migration_tests'
- 'cla/google'
- 'Kokoro'
Expand Down
7 changes: 7 additions & 0 deletions google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
format_type,
)
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.sql import elements
from sqlalchemy import ForeignKeyConstraint, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
Expand Down Expand Up @@ -314,6 +315,12 @@ def render_literal_value(self, value, type_):
in string. Override the method to add additional escape before using it to
generate a SQL statement.
"""
if value is None and not type_.should_evaluate_none:
# issue #10535 - handle NULL in the compiler without placing
# this onto each type, except for "evaluate None" types
# (e.g. JSON)
return self.process(elements.Null._instance())

raw = ["\\", "'", '"', "\n", "\t", "\r"]
if isinstance(value, str) and any(single in value for single in raw):
value = 'r"""{}"""'.format(value)
Expand Down
12 changes: 6 additions & 6 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def compliance_test_13(session):
)

session.install("mock")
session.install("-e", ".[tracing]")
session.install(".[tracing]")
session.run("pip", "install", "sqlalchemy>=1.1.13,<=1.3.24", "--force-reinstall")
session.run("pip", "install", "opentelemetry-api<=1.10", "--force-reinstall")
session.run("pip", "install", "opentelemetry-sdk<=1.10", "--force-reinstall")
Expand Down Expand Up @@ -191,7 +191,7 @@ def compliance_test_14(session):
)

session.install("mock")
session.install("-e", ".[tracing]")
session.install(".[tracing]")
session.run("pip", "install", "sqlalchemy>=1.4,<2.0", "--force-reinstall")
session.run("python", "create_test_database.py")
session.run(
Expand Down Expand Up @@ -231,7 +231,7 @@ def compliance_test_20(session):
)

session.install("mock")
session.install("-e", ".[tracing]")
session.install(".[tracing]")
session.run("pip", "install", "opentelemetry-api<=1.10", "--force-reinstall")
session.run("python", "create_test_database.py")

Expand All @@ -257,7 +257,7 @@ def unit(session):
# Run SQLAlchemy dialect compliance test suite with OpenTelemetry.
session.install("pytest")
session.install("mock")
session.install("-e", ".")
session.install(".")
session.install("opentelemetry-api==1.1.0")
session.install("opentelemetry-sdk==1.1.0")
session.install("opentelemetry-instrumentation==0.20b0")
Expand Down Expand Up @@ -292,7 +292,7 @@ def _migration_test(session):
session.run("pip", "install", "sqlalchemy>=1.3.11,<2.0", "--force-reinstall")

session.install("pytest")
session.install("-e", ".")
session.install(".")
session.install("alembic")

session.run("python", "create_test_database.py")
Expand Down Expand Up @@ -360,7 +360,7 @@ def snippets(session):
session.install(
"git+https://github.com/googleapis/python-spanner.git#egg=google-cloud-spanner"
)
session.install("-e", ".")
session.install(".")
session.run("python", "create_test_database.py")
session.run(
"py.test",
Expand Down
63 changes: 63 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,72 @@

import pytest
from sqlalchemy.dialects import registry
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from sqlalchemy.sql.elements import literal

registry.register("spanner", "google.cloud.sqlalchemy_spanner", "SpannerDialect")

pytest.register_assert_rewrite("sqlalchemy.testing.assertions")

from sqlalchemy.testing.plugin.pytestplugin import * # noqa: E402, F401, F403


@pytest.fixture
def literal_round_trip_spanner(metadata, connection):
# for literal, we test the literal render in an INSERT
# into a typed column. we can then SELECT it back as its
# official type;

def run(
type_,
input_,
output,
filter_=None,
compare=None,
support_whereclause=True,
):
t = Table("t", metadata, Column("x", type_))
t.create(connection)

for value in input_:
ins = t.insert().values(x=literal(value, type_, literal_execute=True))
connection.execute(ins)

if support_whereclause:
if compare:
stmt = t.select().where(
t.c.x
== literal(
compare,
type_,
literal_execute=True,
),
t.c.x
== literal(
input_[0],
type_,
literal_execute=True,
),
)
else:
stmt = t.select().where(
t.c.x
== literal(
compare if compare is not None else input_[0],
type_,
literal_execute=True,
)
)
else:
stmt = t.select()

rows = connection.execute(stmt).all()
assert rows, "No rows returned"
for row in rows:
value = row[0]
if filter_ is not None:
value = filter_(value)
assert value in output

return run
56 changes: 45 additions & 11 deletions test/test_suite_20.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,10 @@
UnicodeTextTest as _UnicodeTextTest,
_UnicodeFixture as __UnicodeFixture,
) # noqa: F401, F403
from test._helpers import get_db_url, get_project
from test._helpers import (
get_db_url,
get_project,
)

config.test_schema = ""

Expand All @@ -162,7 +165,7 @@ class BooleanTest(_BooleanTest):
def test_render_literal_bool(self):
pass

def test_render_literal_bool_true(self, literal_round_trip):
def test_render_literal_bool_true(self, literal_round_trip_spanner):
"""
SPANNER OVERRIDE:
Expand All @@ -171,9 +174,9 @@ def test_render_literal_bool_true(self, literal_round_trip):
following insertions will fail with `Row [] already exists".
Overriding the test to avoid the same failure.
"""
literal_round_trip(Boolean(), [True], [True])
literal_round_trip_spanner(Boolean(), [True], [True])

def test_render_literal_bool_false(self, literal_round_trip):
def test_render_literal_bool_false(self, literal_round_trip_spanner):
"""
SPANNER OVERRIDE:
Expand All @@ -182,7 +185,7 @@ def test_render_literal_bool_false(self, literal_round_trip):
following insertions will fail with `Row [] already exists".
Overriding the test to avoid the same failure.
"""
literal_round_trip(Boolean(), [False], [False])
literal_round_trip_spanner(Boolean(), [False], [False])

@pytest.mark.skip("Not supported by Cloud Spanner")
def test_whereclause(self):
Expand Down Expand Up @@ -2003,6 +2006,9 @@ def test_huge_int_auto_accommodation(self, connection, intvalue):
intvalue,
)

def test_literal(self, literal_round_trip_spanner):
literal_round_trip_spanner(Integer, [5], [5])


class _UnicodeFixture(__UnicodeFixture):
@classmethod
Expand Down Expand Up @@ -2189,6 +2195,19 @@ def test_dont_truncate_rightside(
args[1],
)

def test_literal(self, literal_round_trip_spanner):
# note that in Python 3, this invokes the Unicode
# datatype for the literal part because all strings are unicode
literal_round_trip_spanner(String(40), ["some text"], ["some text"])

def test_literal_quoting(self, literal_round_trip_spanner):
data = """some 'text' hey "hi there" that's text"""
literal_round_trip_spanner(String(40), [data], [data])

def test_literal_backslashes(self, literal_round_trip_spanner):
data = r"backslash one \ backslash two \\ end"
literal_round_trip_spanner(String(40), [data], [data])


class TextTest(_TextTest):
@classmethod
Expand Down Expand Up @@ -2224,6 +2243,21 @@ def test_text_empty_strings(self, connection):
def test_text_null_strings(self, connection):
pass

def test_literal(self, literal_round_trip_spanner):
literal_round_trip_spanner(Text, ["some text"], ["some text"])

def test_literal_quoting(self, literal_round_trip_spanner):
data = """some 'text' hey "hi there" that's text"""
literal_round_trip_spanner(Text, [data], [data])

def test_literal_backslashes(self, literal_round_trip_spanner):
data = r"backslash one \ backslash two \\ end"
literal_round_trip_spanner(Text, [data], [data])

def test_literal_percentsigns(self, literal_round_trip_spanner):
data = r"percent % signs %% percent"
literal_round_trip_spanner(Text, [data], [data])


class NumericTest(_NumericTest):
@testing.fixture
Expand Down Expand Up @@ -2254,7 +2288,7 @@ def run(type_, input_, output, filter_=None, check_scale=False):
return run

@emits_warning(r".*does \*not\* support Decimal objects natively")
def test_render_literal_numeric(self, literal_round_trip):
def test_render_literal_numeric(self, literal_round_trip_spanner):
"""
SPANNER OVERRIDE:
Expand All @@ -2263,14 +2297,14 @@ def test_render_literal_numeric(self, literal_round_trip):
following insertions will fail with `Row [] already exists".
Overriding the test to avoid the same failure.
"""
literal_round_trip(
literal_round_trip_spanner(
Numeric(precision=8, scale=4),
[decimal.Decimal("15.7563")],
[decimal.Decimal("15.7563")],
)

@emits_warning(r".*does \*not\* support Decimal objects natively")
def test_render_literal_numeric_asfloat(self, literal_round_trip):
def test_render_literal_numeric_asfloat(self, literal_round_trip_spanner):
"""
SPANNER OVERRIDE:
Expand All @@ -2279,13 +2313,13 @@ def test_render_literal_numeric_asfloat(self, literal_round_trip):
following insertions will fail with `Row [] already exists".
Overriding the test to avoid the same failure.
"""
literal_round_trip(
literal_round_trip_spanner(
Numeric(precision=8, scale=4, asdecimal=False),
[decimal.Decimal("15.7563")],
[15.7563],
)

def test_render_literal_float(self, literal_round_trip):
def test_render_literal_float(self, literal_round_trip_spanner):
"""
SPANNER OVERRIDE:
Expand All @@ -2294,7 +2328,7 @@ def test_render_literal_float(self, literal_round_trip):
following insertions will fail with `Row [] already exists".
Overriding the test to avoid the same failure.
"""
literal_round_trip(
literal_round_trip_spanner(
Float(4),
[decimal.Decimal("15.7563")],
[15.7563],
Expand Down

0 comments on commit 62cccc3

Please sign in to comment.