From af5457afd267052803a905c7f57add5d3fbf828f Mon Sep 17 00:00:00 2001 From: Gabriel Venegas Castro Date: Mon, 16 Dec 2024 08:51:09 -0600 Subject: [PATCH] SNOW-1776332 Add support for OBJECT (#559) * SNOW-1776332 Add support for OBJECT * Updated description.md * Add missing @pytest.mark.requires_external_volume * Tuple validation in OBJECT class --- DESCRIPTION.md | 1 + pyproject.toml | 3 + src/snowflake/sqlalchemy/base.py | 13 +- src/snowflake/sqlalchemy/custom_types.py | 22 +- .../sqlalchemy/parser/custom_type_parser.py | 37 +- src/snowflake/sqlalchemy/snowdialect.py | 21 +- .../test_structured_datatypes.ambr | 111 ++++- tests/test_structured_datatypes.py | 392 +++++++++++++----- tests/test_unit_structured_types.py | 8 +- 9 files changed, 489 insertions(+), 119 deletions(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index bed7670b..d5872bb9 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -13,6 +13,7 @@ Source code is also available at: - Fix quoting of `_` as column name - Fix index columns was not being reflected - Fix index reflection cache not working + - Add support for structured OBJECT datatype - v1.7.1(December 02, 2024) - Add support for partition by to copy into diff --git a/pyproject.toml b/pyproject.toml index 84e64faf..b0ae04c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ path = "src/snowflake/sqlalchemy/version.py" development = [ "pre-commit", "pytest", + "setuptools", "pytest-cov", "pytest-timeout", "pytest-rerunfailures", @@ -74,6 +75,8 @@ exclude = ["/.github"] packages = ["src/snowflake"] [tool.hatch.envs.default] +path = ".venv" +type = "virtual" extra-dependencies = ["SQLAlchemy>=1.4.19,<2.1.0"] features = ["development", "pandas"] python = "3.8" diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 4c632e7a..3fef7709 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -1099,7 +1099,18 @@ def visit_ARRAY(self, type_, **kw): return "ARRAY" def visit_OBJECT(self, type_, **kw): - return "OBJECT" + if type_.is_semi_structured: + return "OBJECT" + else: + contents = [] + for key in type_.items_types: + + row_text = f"{key} {type_.items_types[key][0].compile()}" + # Type and not null is specified + if len(type_.items_types[key]) > 1: + row_text += f"{' NOT NULL' if type_.items_types[key][1] else ''}" + contents.append(row_text) + return "OBJECT" if contents == [] else f"OBJECT({', '.join(contents)})" def visit_BLOB(self, type_, **kw): return "BINARY" diff --git a/src/snowflake/sqlalchemy/custom_types.py b/src/snowflake/sqlalchemy/custom_types.py index f2c950dd..ce7ad592 100644 --- a/src/snowflake/sqlalchemy/custom_types.py +++ b/src/snowflake/sqlalchemy/custom_types.py @@ -1,9 +1,11 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # +from typing import Tuple, Union import sqlalchemy.types as sqltypes import sqlalchemy.util as util +from sqlalchemy.types import TypeEngine TEXT = sqltypes.VARCHAR CHARACTER = sqltypes.CHAR @@ -57,9 +59,27 @@ def __init__( super().__init__() -class OBJECT(SnowflakeType): +class OBJECT(StructuredType): __visit_name__ = "OBJECT" + def __init__(self, **items_types: Union[TypeEngine, Tuple[TypeEngine, bool]]): + for key, value in items_types.items(): + if not isinstance(value, tuple): + items_types[key] = (value, False) + + self.items_types = items_types + self.is_semi_structured = len(items_types) == 0 + super().__init__() + + def __repr__(self): + quote_char = "'" + return "OBJECT(%s)" % ", ".join( + [ + f"{repr(key).strip(quote_char)}={repr(value)}" + for key, value in self.items_types.items() + ] + ) + class ARRAY(SnowflakeType): __visit_name__ = "ARRAY" diff --git a/src/snowflake/sqlalchemy/parser/custom_type_parser.py b/src/snowflake/sqlalchemy/parser/custom_type_parser.py index dada612d..1e99ba56 100644 --- a/src/snowflake/sqlalchemy/parser/custom_type_parser.py +++ b/src/snowflake/sqlalchemy/parser/custom_type_parser.py @@ -49,11 +49,10 @@ "DECIMAL": DECIMAL, "DOUBLE": DOUBLE, "FIXED": DECIMAL, - "FLOAT": FLOAT, # Snowflake FLOAT datatype doesn't has parameters + "FLOAT": FLOAT, # Snowflake FLOAT datatype doesn't have parameters "INT": INTEGER, "INTEGER": INTEGER, "NUMBER": _CUSTOM_DECIMAL, - # 'OBJECT': ? "REAL": REAL, "BYTEINT": SMALLINT, "SMALLINT": SMALLINT, @@ -76,18 +75,19 @@ } -def extract_parameters(text: str) -> list: +def tokenize_parameters(text: str, character_for_strip=",") -> list: """ Extracts parameters from a comma-separated string, handling parentheses. :param text: A string with comma-separated parameters, which may include parentheses. + :param character_for_strip: A character to strip the text. + :return: A list of parameters as strings. :example: For input `"a, (b, c), d"`, the output is `['a', '(b, c)', 'd']`. """ - output_parameters = [] parameter = "" open_parenthesis = 0 @@ -98,9 +98,9 @@ def extract_parameters(text: str) -> list: elif c == ")": open_parenthesis -= 1 - if open_parenthesis > 0 or c != ",": + if open_parenthesis > 0 or c != character_for_strip: parameter += c - elif c == ",": + elif c == character_for_strip: output_parameters.append(parameter.strip(" ")) parameter = "" if parameter != "": @@ -138,14 +138,17 @@ def parse_type(type_text: str) -> TypeEngine: parse_type("VARCHAR(255)") String(length=255) """ + index = type_text.find("(") type_name = type_text[:index] if index != -1 else type_text + parameters = ( - extract_parameters(type_text[index + 1 : -1]) if type_name != type_text else [] + tokenize_parameters(type_text[index + 1 : -1]) if type_name != type_text else [] ) col_type_class = ischema_names.get(type_name, None) col_type_kw = {} + if col_type_class is None: col_type_class = NullType else: @@ -155,6 +158,8 @@ def parse_type(type_text: str) -> TypeEngine: col_type_kw = __parse_type_with_length_parameters(parameters) elif issubclass(col_type_class, MAP): col_type_kw = __parse_map_type_parameters(parameters) + elif issubclass(col_type_class, OBJECT): + col_type_kw = __parse_object_type_parameters(parameters) if col_type_kw is None: col_type_class = NullType col_type_kw = {} @@ -162,6 +167,24 @@ def parse_type(type_text: str) -> TypeEngine: return col_type_class(**col_type_kw) +def __parse_object_type_parameters(parameters): + object_rows = {} + for parameter in parameters: + parameter_parts = tokenize_parameters(parameter, " ") + if len(parameter_parts) >= 2: + key = parameter_parts[0] + value_type = parse_type(parameter_parts[1]) + if isinstance(value_type, NullType): + return None + not_null = ( + len(parameter_parts) == 4 + and parameter_parts[2] == "NOT" + and parameter_parts[3] == "NULL" + ) + object_rows[key] = (value_type, not_null) + return object_rows + + def __parse_map_type_parameters(parameters): if len(parameters) != 2: return None diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 935794d9..dd5e4375 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -1,7 +1,6 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # - import operator import re from collections import defaultdict @@ -9,7 +8,7 @@ from typing import Any, Collection, Optional from urllib.parse import unquote_plus -import sqlalchemy.types as sqltypes +import sqlalchemy.sql.sqltypes as sqltypes from sqlalchemy import event as sa_vnt from sqlalchemy import exc as sa_exc from sqlalchemy import util as sa_util @@ -17,7 +16,8 @@ from sqlalchemy.schema import Table from sqlalchemy.sql import text from sqlalchemy.sql.elements import quoted_name -from sqlalchemy.types import FLOAT, Date, DateTime, Float, NullType, Time +from sqlalchemy.sql.sqltypes import NullType +from sqlalchemy.types import FLOAT, Date, DateTime, Float, Time from snowflake.connector import errors as sf_errors from snowflake.connector.connection import DEFAULT_CONFIGURATION @@ -33,7 +33,7 @@ SnowflakeTypeCompiler, ) from .custom_types import ( - MAP, + StructuredType, _CUSTOM_Date, _CUSTOM_DateTime, _CUSTOM_Float, @@ -466,6 +466,14 @@ def _get_schema_columns(self, connection, schema, **kw): connection, full_schema_name, **kw ) schema_name = self.denormalize_name(schema) + + iceberg_table_names = self.get_table_names_with_prefix( + connection, + schema=schema_name, + prefix=CustomTablePrefix.ICEBERG.name, + info_cache=kw.get("info_cache", None), + ) + result = connection.execute( text( """ @@ -526,7 +534,10 @@ def _get_schema_columns(self, connection, schema, **kw): col_type_kw["scale"] = numeric_scale elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)): col_type_kw["length"] = character_maximum_length - elif issubclass(col_type, MAP): + elif ( + issubclass(col_type, StructuredType) + and table_name in iceberg_table_names + ): if (schema_name, table_name) not in full_columns_descriptions: full_columns_descriptions[(schema_name, table_name)] = ( self.table_columns_as_dict( diff --git a/tests/__snapshots__/test_structured_datatypes.ambr b/tests/__snapshots__/test_structured_datatypes.ambr index 0325a946..714f5d57 100644 --- a/tests/__snapshots__/test_structured_datatypes.ambr +++ b/tests/__snapshots__/test_structured_datatypes.ambr @@ -5,6 +5,15 @@ # name: test_compile_table_with_double_map 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL, MAP(DECIMAL, VARCHAR)), \tPRIMARY KEY ("Id"))' # --- +# name: test_compile_table_with_structured_data_type[structured_type0] + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tPRIMARY KEY ("Id"))' +# --- +# name: test_compile_table_with_structured_data_type[structured_type1] + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tPRIMARY KEY ("Id"))' +# --- +# name: test_compile_table_with_structured_data_type[structured_type2] + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tPRIMARY KEY ("Id"))' +# --- # name: test_insert_map list([ (1, '{\n "100": "item1",\n "200": "item2"\n}'), @@ -16,6 +25,43 @@ Invalid expression [CAST(OBJECT_CONSTRUCT('100', 'item1', '200', 'item2') AS MAP(NUMBER(10,0), VARCHAR(16777216)))] in VALUES clause ''' # --- +# name: test_insert_structured_object + list([ + (1, '{\n "key1": "item1",\n "key2": 15\n}'), + ]) +# --- +# name: test_insert_structured_object_orm + ''' + 002014 (22000): SQL compilation error: + Invalid expression [CAST(OBJECT_CONSTRUCT('key1', 1, 'key2', 'item1') AS OBJECT(key1 NUMBER(10,0), key2 VARCHAR(16777216)))] in VALUES clause + ''' +# --- +# name: test_inspect_structured_data_types[structured_type0-MAP] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'structured_type_col', + 'nullable': True, + 'primary_key': False, + 'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216)), + }), + ]) +# --- # name: test_inspect_structured_data_types[structured_type0] list([ dict({ @@ -42,6 +88,32 @@ }), ]) # --- +# name: test_inspect_structured_data_types[structured_type1-MAP] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'structured_type_col', + 'nullable': True, + 'primary_key': False, + 'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216))), + }), + ]) +# --- # name: test_inspect_structured_data_types[structured_type1] list([ dict({ @@ -68,11 +140,40 @@ }), ]) # --- +# name: test_inspect_structured_data_types[structured_type2-OBJECT] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'structured_type_col', + 'nullable': True, + 'primary_key': False, + 'type': OBJECT(key1=(VARCHAR(length=16777216), False), key2=(_CUSTOM_DECIMAL(precision=10, scale=0), False)), + }), + ]) +# --- # name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))] - "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tmap_id MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" # --- # name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), VARCHAR)] - "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tmap_id MAP(DECIMAL(10, 0), VARCHAR(16777216)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col MAP(DECIMAL(10, 0), VARCHAR(16777216)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" +# --- +# name: test_reflect_structured_data_types[OBJECT(key1 VARCHAR, key2 NUMBER(10, 0))] + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" # --- # name: test_select_map_orm list([ @@ -88,3 +189,9 @@ list([ ]) # --- +# name: test_select_structured_object_orm + list([ + (1, '{\n "key1": "value2",\n "key2": 2\n}'), + (2, '{\n "key1": "value1",\n "key2": 1\n}'), + ]) +# --- diff --git a/tests/test_structured_datatypes.py b/tests/test_structured_datatypes.py index 4ea0892b..d6beb3e9 100644 --- a/tests/test_structured_datatypes.py +++ b/tests/test_structured_datatypes.py @@ -1,7 +1,6 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # - import pytest from sqlalchemy import ( Column, @@ -19,17 +18,27 @@ from sqlalchemy.sql.ddl import CreateTable from snowflake.sqlalchemy import NUMBER, IcebergTable, SnowflakeTable -from snowflake.sqlalchemy.custom_types import MAP, TEXT +from snowflake.sqlalchemy.custom_types import MAP, OBJECT, TEXT from snowflake.sqlalchemy.exc import StructuredTypeNotSupportedInTableColumnsError -def test_compile_table_with_cluster_by_with_expression(sql_compiler, snapshot): +@pytest.mark.parametrize( + "structured_type", + [ + MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), TEXT(16777216))), + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + OBJECT(key1=TEXT(16777216), key2=NUMBER(10, 0)), + ], +) +def test_compile_table_with_structured_data_type( + sql_compiler, snapshot, structured_type +): metadata = MetaData() user_table = Table( "clustered_user", metadata, Column("Id", Integer, primary_key=True), - Column("name", MAP(NUMBER(), TEXT())), + Column("name", structured_type), ) create_table = CreateTable(user_table) @@ -38,35 +47,152 @@ def test_compile_table_with_cluster_by_with_expression(sql_compiler, snapshot): @pytest.mark.requires_external_volume -def test_create_table_structured_datatypes( - engine_testaccount, external_volume, base_location +def test_insert_map(engine_testaccount, external_volume, base_location, snapshot): + metadata = MetaData() + table_name = "test_insert_map" + test_map = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("map_id", MAP(NUMBER(10, 0), TEXT(16777216))), + external_volume=external_volume, + base_location=base_location, + ) + """ + Test inserting data into a table with a MAP column type. + + Args: + engine_testaccount: The SQLAlchemy engine connected to the test account. + external_volume: The external volume to use for the table. + base_location: The base location for the table. + snapshot: The snapshot object for assertion. + """ + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + slt = select( + 1, + cast( + text("{'100':'item1', '200':'item2'}"), + MAP(NUMBER(10, 0), TEXT(16777216)), + ), + ) + ins = test_map.insert().from_select(["id", "map_id"], slt) + conn.execute(ins) + + results = conn.execute(test_map.select()) + data = results.fetchmany() + results.close() + snapshot.assert_match(data) + finally: + test_map.drop(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_insert_map_orm( + sql_compiler, external_volume, base_location, engine_testaccount, snapshot ): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __tablename__ = "test_iceberg_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return IcebergTable(name, metadata, *arg, **kw) + + __table_args__ = { + "external_volume": external_volume, + "base_location": base_location, + } + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + map_id = Column(MAP(NUMBER(10, 0), TEXT(16777216))) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + cast_expr = cast( + text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT(16777216)) + ) + instance = TestIcebergTableOrm(id=0, map_id=cast_expr) + session.add(instance) + with pytest.raises(exc.ProgrammingError) as programming_error: + session.commit() + # TODO: Support variant in insert statement + assert str(programming_error.value.orig) == snapshot + finally: + Base.metadata.drop_all(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_select_map_orm(engine_testaccount, external_volume, base_location, snapshot): metadata = MetaData() - table_name = "test_map0" + table_name = "test_select_map_orm" test_map = IcebergTable( table_name, metadata, Column("id", Integer, primary_key=True), - Column("map_id", MAP(NUMBER(10, 0), TEXT())), + Column("map_id", MAP(NUMBER(10, 0), TEXT(16777216))), external_volume=external_volume, base_location=base_location, ) metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + slt1 = select( + 2, + cast( + text("{'100':'item1', '200':'item2'}"), + MAP(NUMBER(10, 0), TEXT(16777216)), + ), + ) + slt2 = select( + 1, + cast( + text("{'100':'item1', '200':'item2'}"), + MAP(NUMBER(10, 0), TEXT(16777216)), + ), + ).union_all(slt1) + ins = test_map.insert().from_select(["id", "map_id"], slt2) + conn.execute(ins) + conn.commit() + + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __table__ = test_map + + def __repr__(self): + return f"({self.id!r}, {self.map_id!r})" + try: - assert test_map is not None + data = session.query(TestIcebergTableOrm).all() + snapshot.assert_match(data) finally: test_map.drop(engine_testaccount) @pytest.mark.requires_external_volume -def test_insert_map(engine_testaccount, external_volume, base_location, snapshot): +def test_insert_structured_object( + engine_testaccount, external_volume, base_location, snapshot +): metadata = MetaData() - table_name = "test_insert_map" + table_name = "test_insert_structured_object" test_map = IcebergTable( table_name, metadata, Column("id", Integer, primary_key=True), - Column("map_id", MAP(NUMBER(10, 0), TEXT())), + Column( + "object_col", + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), external_volume=external_volume, base_location=base_location, ) @@ -77,10 +203,11 @@ def test_insert_map(engine_testaccount, external_volume, base_location, snapshot slt = select( 1, cast( - text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT()) + text("{'key1':'item1', 'key2': 15}"), + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), ), ) - ins = test_map.insert().from_select(["id", "map_id"], slt) + ins = test_map.insert().from_select(["id", "object_col"], slt) conn.execute(ins) results = conn.execute(test_map.select()) @@ -91,16 +218,125 @@ def test_insert_map(engine_testaccount, external_volume, base_location, snapshot test_map.drop(engine_testaccount) +@pytest.mark.requires_external_volume +def test_insert_structured_object_orm( + sql_compiler, external_volume, base_location, engine_testaccount, snapshot +): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __tablename__ = "test_iceberg_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return IcebergTable(name, metadata, *arg, **kw) + + __table_args__ = { + "external_volume": external_volume, + "base_location": base_location, + } + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + object_col = Column( + OBJECT(key1=(NUMBER(10, 0), False), key2=(TEXT(16777216), False)) + ) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + cast_expr = cast( + text("{ 'key1' : 1, 'key2' : 'item1' }"), + OBJECT(key1=(NUMBER(10, 0), False), key2=(TEXT(16777216), False)), + ) + instance = TestIcebergTableOrm(id=0, object_col=cast_expr) + session.add(instance) + with pytest.raises(exc.ProgrammingError) as programming_error: + session.commit() + # TODO: Support variant in insert statement + assert str(programming_error.value.orig) == snapshot + finally: + Base.metadata.drop_all(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_select_structured_object_orm( + engine_testaccount, external_volume, base_location, snapshot +): + metadata = MetaData() + table_name = "test_select_structured_object_orm" + iceberg_table = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column( + "structured_obj_col", + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + first_select = select( + 2, + cast( + text("{'key1': 'value1', 'key2': 1}"), + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), + ) + second_select = select( + 1, + cast( + text("{'key1': 'value2', 'key2': 2}"), + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), + ).union_all(first_select) + insert_statement = iceberg_table.insert().from_select( + ["id", "structured_obj_col"], second_select + ) + conn.execute(insert_statement) + conn.commit() + + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __table__ = iceberg_table + + def __repr__(self): + return f"({self.id!r}, {self.structured_obj_col!r})" + + try: + data = session.query(TestIcebergTableOrm).all() + snapshot.assert_match(data) + finally: + iceberg_table.drop(engine_testaccount) + + @pytest.mark.requires_external_volume @pytest.mark.parametrize( - "structured_type", + "structured_type, expected_type", [ - MAP(NUMBER(10, 0), TEXT()), - MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), TEXT())), + (MAP(NUMBER(10, 0), TEXT(16777216)), MAP), + (MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), TEXT(16777216))), MAP), + ( + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + OBJECT, + ), ], ) def test_inspect_structured_data_types( - engine_testaccount, external_volume, base_location, snapshot, structured_type + engine_testaccount, + external_volume, + base_location, + snapshot, + structured_type, + expected_type, ): metadata = MetaData() table_name = "test_st_types" @@ -108,7 +344,7 @@ def test_inspect_structured_data_types( table_name, metadata, Column("id", Integer, primary_key=True), - Column("map_id", structured_type), + Column("structured_type_col", structured_type), external_volume=external_volume, base_location=base_location, ) @@ -119,7 +355,7 @@ def test_inspect_structured_data_types( columns = inspecter.get_columns(table_name) assert isinstance(columns[0]["type"], NUMBER) - assert isinstance(columns[1]["type"], MAP) + assert isinstance(columns[1]["type"], expected_type) assert columns == snapshot finally: @@ -132,6 +368,7 @@ def test_inspect_structured_data_types( [ "MAP(NUMBER(10, 0), VARCHAR)", "MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))", + "OBJECT(key1 VARCHAR, key2 NUMBER(10, 0))", ], ) def test_reflect_structured_data_types( @@ -147,7 +384,7 @@ def test_reflect_structured_data_types( create_table_sql = f""" CREATE OR REPLACE ICEBERG TABLE {table_name} ( id number(38,0) primary key, - map_id {structured_type}) + structured_type_col {structured_type}) CATALOG = 'SNOWFLAKE' EXTERNAL_VOLUME = '{external_volume}' BASE_LOCATION = '{base_location}'; @@ -174,47 +411,43 @@ def test_reflect_structured_data_types( @pytest.mark.requires_external_volume -def test_insert_map_orm( - sql_compiler, external_volume, base_location, engine_testaccount, snapshot +def test_create_table_structured_datatypes( + engine_testaccount, external_volume, base_location ): - Base = declarative_base() - session = Session(bind=engine_testaccount) - - class TestIcebergTableOrm(Base): - __tablename__ = "test_iceberg_table_orm" - - @classmethod - def __table_cls__(cls, name, metadata, *arg, **kw): - return IcebergTable(name, metadata, *arg, **kw) - - __table_args__ = { - "external_volume": external_volume, - "base_location": base_location, - } - - id = Column(Integer, Sequence("user_id_seq"), primary_key=True) - map_id = Column(MAP(NUMBER(10, 0), TEXT())) - - def __repr__(self): - return f"({self.id!r}, {self.name!r})" - - Base.metadata.create_all(engine_testaccount) - + metadata = MetaData() + table_name = "test_structured0" + test_structured_dt = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("map_id", MAP(NUMBER(10, 0), TEXT(16777216))), + Column( + "object_col", + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) try: - cast_expr = cast( - text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT()) - ) - instance = TestIcebergTableOrm(id=0, map_id=cast_expr) - session.add(instance) - with pytest.raises(exc.ProgrammingError) as programming_error: - session.commit() - # TODO: Support variant in insert statement - assert str(programming_error.value.orig) == snapshot + assert test_structured_dt is not None finally: - Base.metadata.drop_all(engine_testaccount) + test_structured_dt.drop(engine_testaccount) -def test_snowflake_tables_with_structured_types(sql_compiler): +@pytest.mark.parametrize( + "structured_type_col", + [ + Column("name", MAP(NUMBER(10, 0), TEXT(16777216))), + Column( + "object_col", + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), + ], +) +def test_structured_type_not_supported_in_table_columns_error( + sql_compiler, structured_type_col +): metadata = MetaData() with pytest.raises( StructuredTypeNotSupportedInTableColumnsError @@ -223,49 +456,6 @@ def test_snowflake_tables_with_structured_types(sql_compiler): "clustered_user", metadata, Column("Id", Integer, primary_key=True), - Column("name", MAP(NUMBER(10, 0), TEXT())), + structured_type_col, ) assert programming_error is not None - - -@pytest.mark.requires_external_volume -def test_select_map_orm(engine_testaccount, external_volume, base_location, snapshot): - metadata = MetaData() - table_name = "test_select_map_orm" - test_map = IcebergTable( - table_name, - metadata, - Column("id", Integer, primary_key=True), - Column("map_id", MAP(NUMBER(10, 0), TEXT())), - external_volume=external_volume, - base_location=base_location, - ) - metadata.create_all(engine_testaccount) - - with engine_testaccount.connect() as conn: - slt1 = select( - 2, - cast(text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT())), - ) - slt2 = select( - 1, - cast(text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT())), - ).union_all(slt1) - ins = test_map.insert().from_select(["id", "map_id"], slt2) - conn.execute(ins) - conn.commit() - - Base = declarative_base() - session = Session(bind=engine_testaccount) - - class TestIcebergTableOrm(Base): - __table__ = test_map - - def __repr__(self): - return f"({self.id!r}, {self.map_id!r})" - - try: - data = session.query(TestIcebergTableOrm).all() - snapshot.assert_match(data) - finally: - test_map.drop(engine_testaccount) diff --git a/tests/test_unit_structured_types.py b/tests/test_unit_structured_types.py index c7bcd6ef..474ebde4 100644 --- a/tests/test_unit_structured_types.py +++ b/tests/test_unit_structured_types.py @@ -6,8 +6,8 @@ from snowflake.sqlalchemy import NUMBER from snowflake.sqlalchemy.custom_types import MAP, TEXT from src.snowflake.sqlalchemy.parser.custom_type_parser import ( - extract_parameters, parse_type, + tokenize_parameters, ) @@ -18,7 +18,7 @@ def test_compile_map_with_not_null(snapshot): def test_extract_parameters(): example = "a, b(c, d, f), d" - assert extract_parameters(example) == ["a", "b(c, d, f)", "d"] + assert tokenize_parameters(example) == ["a", "b(c, d, f)", "d"] @pytest.mark.parametrize( @@ -64,6 +64,10 @@ def test_extract_parameters(): ), ("MAP(DECIMAL(10, 0), VARIANT)", "MAP(DECIMAL(10, 0), VARIANT)"), ("OBJECT", "OBJECT"), + ( + "OBJECT(a DECIMAL(10, 0) NOT NULL, b DECIMAL(10, 0), c VARCHAR NOT NULL)", + "OBJECT(a DECIMAL(10, 0) NOT NULL, b DECIMAL(10, 0), c VARCHAR NOT NULL)", + ), ("ARRAY", "ARRAY"), ("GEOGRAPHY", "GEOGRAPHY"), ("GEOMETRY", "GEOMETRY"),