diff --git a/target_snowflake/connector.py b/target_snowflake/connector.py
index 42f1bac..2848fa3 100644
--- a/target_snowflake/connector.py
+++ b/target_snowflake/connector.py
@@ -1,6 +1,7 @@
 from operator import contains, eq
 from typing import Dict, List, Sequence, Tuple, Union, cast
 
+import snowflake.sqlalchemy.custom_types as sct
 import sqlalchemy
 from singer_sdk import typing as th
 from singer_sdk.connectors import SQLConnector
@@ -35,6 +36,34 @@ def evaluate_typemaps(type_maps, compare_value, unmatched_value):
     return unmatched_value
 
 
+def _jsonschema_type_check(jsonschema_type: dict, type_check: Tuple[str]) -> bool:
+    """Return True if the jsonschema_type supports the provided type.
+
+    Args:
+        jsonschema_type: The type dict.
+        type_check: A tuple of type strings to look for.
+
+    Returns:
+        True if the schema supports the type.
+    """
+    if "type" in jsonschema_type:
+        if isinstance(jsonschema_type["type"], (list, tuple)):
+            for schema_type in jsonschema_type["type"]:
+                if schema_type in type_check:
+                    return True
+        else:
+            if jsonschema_type.get("type") in type_check:  # noqa: PLR5501
+                return True
+
+    # TODO: remove following release of https://github.com/meltano/sdk/issues/1774
+    if any(
+        _jsonschema_type_check(t, type_check)
+        for t in jsonschema_type.get("anyOf", ())
+    ):
+        return True
+
+    return False
+
+
 class SnowflakeConnector(SQLConnector):
     """Snowflake Target Connector.
 
@@ -43,7 +72,7 @@ class SnowflakeConnector(SQLConnector):
 
     allow_column_add: bool = True  # Whether ADD COLUMN is supported.
     allow_column_rename: bool = True  # Whether RENAME COLUMN is supported.
-    allow_column_alter: bool = False  # Whether altering column types is supported.
+    allow_column_alter: bool = True  # Whether altering column types is supported.
     allow_merge_upsert: bool = False  # Whether MERGE UPSERT is supported.
     allow_temp_tables: bool = True  # Whether temp tables are supported.
 
@@ -70,10 +99,33 @@ def get_table_columns(
         if full_table_name in self.table_cache:
             return self.table_cache[full_table_name]
         else:
-            parsed_columns = super().get_table_columns(full_table_name, column_names)
+            _, schema_name, table_name = self.parse_full_table_name(full_table_name)
+            inspector = sqlalchemy.inspect(self._engine)
+            columns = inspector.get_columns(table_name, schema_name)
+
+            parsed_columns = {
+                col_meta["name"]: sqlalchemy.Column(
+                    col_meta["name"],
+                    self._convert_type(col_meta["type"]),
+                    nullable=col_meta.get("nullable", False),
+                )
+                for col_meta in columns
+                if not column_names
+                or col_meta["name"].casefold() in {col.casefold() for col in column_names}
+            }
             self.table_cache[full_table_name] = parsed_columns
             return parsed_columns
 
+    def _convert_type(self, sql_type):
+        if isinstance(sql_type, sct.TIMESTAMP_NTZ):
+            return TIMESTAMP_NTZ
+        elif isinstance(sql_type, sct.NUMBER):
+            return NUMBER
+        elif isinstance(sql_type, sct.VARIANT):
+            return VARIANT
+        else:
+            return sql_type
+
     def get_sqlalchemy_url(self, config: dict) -> str:
         """Generates a SQLAlchemy URL for Snowflake.
@@ -204,12 +256,12 @@ def to_sql_type(jsonschema_type: dict) -> sqlalchemy.types.TypeEngine:
         TypeMap(eq, sqlalchemy.types.VARCHAR(maxlength), None),
     ]
     type_maps = [
-        TypeMap(th._jsonschema_type_check, NUMBER(), ("integer",)),
-        TypeMap(th._jsonschema_type_check, VARIANT(), ("object",)),
-        TypeMap(th._jsonschema_type_check, VARIANT(), ("array",)),
+        TypeMap(_jsonschema_type_check, NUMBER(), ("integer",)),
+        TypeMap(_jsonschema_type_check, VARIANT(), ("object",)),
+        TypeMap(_jsonschema_type_check, VARIANT(), ("array",)),
     ]
     # apply type maps
-    if th._jsonschema_type_check(jsonschema_type, ("string",)):
+    if _jsonschema_type_check(jsonschema_type, ("string",)):
         datelike_type = th.get_datelike_property_type(jsonschema_type)
         target_type = evaluate_typemaps(string_submaps, datelike_type, target_type)
     else:
diff --git a/tests/core.py b/tests/core.py
index 37b8a1e..4837532 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -391,6 +391,33 @@ def validate(self) -> None:
         row = result.first()
         assert len(row) == 12
 
+
+class SnowflakeTargetExistingTableAlter(SnowflakeTargetExistingTable):
+
+    name = "existing_table_alter"
+    # The existing table's COL_INT is created as STRING while the stream
+    # declares it an integer, forcing the target to ALTER the column type.
+
+    def setup(self) -> None:
+        connector = self.target.default_sink_class.connector_class(self.target.config)
+        table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{self.name}".upper()
+        connector.connection.execute(
+            f"""
+            CREATE OR REPLACE TABLE {table} (
+                ID VARCHAR(16777216),
+                COL_STR VARCHAR(16777216),
+                COL_TS TIMESTAMP_NTZ(9),
+                COL_INT STRING,
+                COL_BOOL BOOLEAN,
+                COL_VARIANT VARIANT,
+                _SDC_BATCHED_AT TIMESTAMP_NTZ(9),
+                _SDC_DELETED_AT VARCHAR(16777216),
+                _SDC_EXTRACTED_AT TIMESTAMP_NTZ(9),
+                _SDC_RECEIVED_AT TIMESTAMP_NTZ(9),
+                _SDC_SEQUENCE NUMBER(38,0),
+                _SDC_TABLE_VERSION NUMBER(38,0),
+                PRIMARY KEY (ID)
+            )
+            """
+        )
+
 
 target_tests = TestSuite(
     kind="target",
@@ -417,5 +444,6 @@ def validate(self) -> None:
         SnowflakeTargetReservedWordsNoKeyProps,
         SnowflakeTargetColonsInColName,
         SnowflakeTargetExistingTable,
+        SnowflakeTargetExistingTableAlter,
     ],
 )
diff --git a/tests/target_test_streams/existing_table_alter.singer b/tests/target_test_streams/existing_table_alter.singer
new file mode 100644
index 0000000..4fc7e2e
--- /dev/null
+++ b/tests/target_test_streams/existing_table_alter.singer
@@ -0,0 +1,2 @@
+{ "type": "SCHEMA", "stream": "existing_table_alter", "schema": { "properties": { "id": { "type": [ "string", "null" ] }, "col_str": { "type": [ "string", "null" ] }, "col_ts": { "format": "date-time", "type": [ "string", "null" ] }, "col_int": { "type": "integer" }, "col_bool": { "type": [ "boolean", "null" ] }, "col_variant": {"type": "object"} }, "type": "object" }, "key_properties": [ "id" ], "bookmark_properties": [ "col_ts" ] }
+{ "type": "RECORD", "stream": "existing_table_alter", "record": { "id": "123", "col_str": "foo", "col_ts": "2023-06-13 11:50:04.072", "col_int": 5, "col_bool": true, "col_variant": {"key": "val"} }, "time_extracted": "2023-06-14T18:08:23.074716+00:00" }
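Reviewer note: below is a condensed, standalone sketch (logic equivalent to the `_jsonschema_type_check` helper vendored into the connector above) showing why the helper is copied locally instead of using `th._jsonschema_type_check`: the local copy also recurses into `anyOf` unions, to be dropped once meltano/sdk#1774 ships. The schema dicts in the asserts are illustrative inputs for this note, not part of the change:

```python
from typing import Tuple


def jsonschema_type_check(jsonschema_type: dict, type_check: Tuple[str, ...]) -> bool:
    """Return True if the JSON Schema type dict supports any of the given types."""
    schema_type = jsonschema_type.get("type")
    if isinstance(schema_type, (list, tuple)):
        # Nullable properties arrive as type lists, e.g. ["string", "null"].
        if any(t in type_check for t in schema_type):
            return True
    elif schema_type in type_check:
        return True
    # Recurse into anyOf unions, the case the upstream SDK helper misses.
    return any(
        jsonschema_type_check(t, type_check)
        for t in jsonschema_type.get("anyOf", ())
    )


# Plain and nullable types, mirroring the test stream above:
assert jsonschema_type_check({"type": ["string", "null"]}, ("string",))
assert jsonschema_type_check({"type": "integer"}, ("integer",))
# An anyOf union, which the upstream helper would not match:
assert jsonschema_type_check({"anyOf": [{"type": "object"}, {"type": "null"}]}, ("object",))
assert not jsonschema_type_check({"type": "object"}, ("integer",))
```

With this behavior, a property typed only through `anyOf` still maps to VARIANT/NUMBER in `to_sql_type`, which is what lets the alter test above round-trip `col_variant` and `col_int` correctly.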