From 6f4c15de5c5ebd059ea79d706ba4b1890993786d Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Tue, 22 Oct 2024 14:04:56 +0200 Subject: [PATCH] mraba/underscore_column_id: use `_` as column identifier --- src/snowflake/sqlalchemy/base.py | 4 ++- tests/test_compiler.py | 17 +++++++++++- tests/test_quote.py | 23 ++++++++++++++++ tests/test_quote_identifiers.py | 46 ++++++++++++++++++++++++++++++++ 4 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 tests/test_quote_identifiers.py diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 56631728..bba910c9 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -5,6 +5,7 @@ import itertools import operator import re +import string from sqlalchemy import exc as sa_exc from sqlalchemy import inspect, sql @@ -106,7 +107,7 @@ AUTOCOMMIT_REGEXP = re.compile( r"\s*(?:UPDATE|INSERT|DELETE|MERGE|COPY)", re.I | re.UNICODE ) - +ILLEGAL_INITIAL_CHARACTERS = frozenset({d for d in string.digits}.union({"_", "$"})) """ Overwrite methods to handle Snowflake BCR change: @@ -431,6 +432,7 @@ def _join_left_to_right( class SnowflakeIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = {x.lower() for x in RESERVED_WORDS} + illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS def __init__(self, dialect, **kw): quote = '"' diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 40207b41..55451c2f 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -2,7 +2,7 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from sqlalchemy import Integer, String, and_, func, select +from sqlalchemy import Integer, String, and_, func, insert, select from sqlalchemy.schema import DropColumnComment, DropTableComment from sqlalchemy.sql import column, quoted_name, table from sqlalchemy.testing.assertions import AssertsCompiledSQL @@ -33,6 +33,21 @@ def test_now_func(self): dialect="snowflake", ) + def test_underscore_as_valid_identifier(self): + _table = table( + "table_1745924", + column("ca", Integer), + column("cb", String), + column("_", String), + ) + + stmt = insert(_table).values(ca=1, cb="test", _="test_") + self.assert_compile( + stmt, + 'INSERT INTO table_1745924 (ca, cb, "_") VALUES (%(ca)s, %(cb)s, %(_)s)', + dialect="snowflake", + ) + def test_multi_table_delete(self): statement = table1.delete().where(table1.c.id == table2.c.id) self.assert_compile( diff --git a/tests/test_quote.py b/tests/test_quote.py index ca6f36dd..0dd69059 100644 --- a/tests/test_quote.py +++ b/tests/test_quote.py @@ -38,3 +38,26 @@ def test_table_name_with_reserved_words(engine_testaccount, db_parameters): finally: insert_table.drop(engine_testaccount) return insert_table + + +def test_table_column_as_underscore(engine_testaccount): + metadata = MetaData() + test_table_name = "table_1745924" + insert_table = Table( + test_table_name, + metadata, + Column("ca", Integer), + Column("cb", String), + Column("_", String), + ) + metadata.create_all(engine_testaccount) + try: + inspector = inspect(engine_testaccount) + columns_in_insert = inspector.get_columns(test_table_name) + assert len(columns_in_insert) == 3 + assert columns_in_insert[0]["name"] == "ca" + assert columns_in_insert[1]["name"] == "cb" + assert columns_in_insert[2]["name"] == "_" + finally: + insert_table.drop(engine_testaccount) + return insert_table diff --git a/tests/test_quote_identifiers.py b/tests/test_quote_identifiers.py new file mode 100644 index 00000000..c78dbcaa --- /dev/null +++ b/tests/test_quote_identifiers.py @@ -0,0 +1,46 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +from sqlalchemy import ( + Column, + Integer, + MetaData, + String, + Table, + create_engine, + insert, + select, +) + +from snowflake.sqlalchemy import URL + +from .parameters import CONNECTION_PARAMETERS + + +def test_insert_with_identifier(): + metadata = MetaData() + table = Table( + "table_1745924", + metadata, + Column("ca", Integer), + Column("cb", String), + Column("_", String), + ) + + engine = create_engine(URL(**CONNECTION_PARAMETERS)) + + try: + metadata.create_all(engine) + + with engine.connect() as connection: + connection.execute(insert(table).values(ca=1, cb="test", _="test_")) + connection.execute( + insert(table).values({"ca": 2, "cb": "test", "_": "test_"}) + ) + result = connection.execute(select(table)).fetchall() + assert result == [ + (1, "test", "test_"), + (2, "test", "test_"), + ] + finally: + metadata.drop_all(engine)