Skip to content

Commit

Permalink
mraba/underscore_column_id: use _ as column identifier
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-mraba committed Oct 22, 2024
1 parent 43c6b56 commit 6f4c15d
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/snowflake/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import itertools
import operator
import re
import string

from sqlalchemy import exc as sa_exc
from sqlalchemy import inspect, sql
Expand Down Expand Up @@ -106,7 +107,7 @@
AUTOCOMMIT_REGEXP = re.compile(
r"\s*(?:UPDATE|INSERT|DELETE|MERGE|COPY)", re.I | re.UNICODE
)

ILLEGAL_INITIAL_CHARACTERS = frozenset({d for d in string.digits}.union({"_", "$"}))

"""
Overwrite methods to handle Snowflake BCR change:
Expand Down Expand Up @@ -431,6 +432,7 @@ def _join_left_to_right(

class SnowflakeIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = {x.lower() for x in RESERVED_WORDS}
illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS

def __init__(self, dialect, **kw):
quote = '"'
Expand Down
17 changes: 16 additions & 1 deletion tests/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from sqlalchemy import Integer, String, and_, func, select
from sqlalchemy import Integer, String, and_, func, insert, select
from sqlalchemy.schema import DropColumnComment, DropTableComment
from sqlalchemy.sql import column, quoted_name, table
from sqlalchemy.testing.assertions import AssertsCompiledSQL
Expand Down Expand Up @@ -33,6 +33,21 @@ def test_now_func(self):
dialect="snowflake",
)

def test_underscore_as_valid_identifier(self):
_table = table(
"table_1745924",
column("ca", Integer),
column("cb", String),
column("_", String),
)

stmt = insert(_table).values(ca=1, cb="test", _="test_")
self.assert_compile(
stmt,
'INSERT INTO table_1745924 (ca, cb, "_") VALUES (%(ca)s, %(cb)s, %(_)s)',
dialect="snowflake",
)

def test_multi_table_delete(self):
statement = table1.delete().where(table1.c.id == table2.c.id)
self.assert_compile(
Expand Down
23 changes: 23 additions & 0 deletions tests/test_quote.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,26 @@ def test_table_name_with_reserved_words(engine_testaccount, db_parameters):
finally:
insert_table.drop(engine_testaccount)
return insert_table


def test_table_column_as_underscore(engine_testaccount):
metadata = MetaData()
test_table_name = "table_1745924"
insert_table = Table(
test_table_name,
metadata,
Column("ca", Integer),
Column("cb", String),
Column("_", String),
)
metadata.create_all(engine_testaccount)
try:
inspector = inspect(engine_testaccount)
columns_in_insert = inspector.get_columns(test_table_name)
assert len(columns_in_insert) == 3
assert columns_in_insert[0]["name"] == "ca"
assert columns_in_insert[1]["name"] == "cb"
assert columns_in_insert[2]["name"] == "_"
finally:
insert_table.drop(engine_testaccount)
return insert_table
46 changes: 46 additions & 0 deletions tests/test_quote_identifiers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.

from sqlalchemy import (
Column,
Integer,
MetaData,
String,
Table,
create_engine,
insert,
select,
)

from snowflake.sqlalchemy import URL

from .parameters import CONNECTION_PARAMETERS


def test_insert_with_identifier():
metadata = MetaData()
table = Table(
"table_1745924",
metadata,
Column("ca", Integer),
Column("cb", String),
Column("_", String),
)

engine = create_engine(URL(**CONNECTION_PARAMETERS))

try:
metadata.create_all(engine)

with engine.connect() as connection:
connection.execute(insert(table).values(ca=1, cb="test", _="test_"))
connection.execute(
insert(table).values({"ca": 2, "cb": "test", "_": "test_"})
)
result = connection.execute(select(table)).fetchall()
assert result == [
(1, "test", "test_"),
(2, "test", "test_"),
]
finally:
metadata.drop_all(engine)

0 comments on commit 6f4c15d

Please sign in to comment.