diff --git a/target_snowflake/connector.py b/target_snowflake/connector.py
new file mode 100644
index 0000000..608baac
--- /dev/null
+++ b/target_snowflake/connector.py
@@ -0,0 +1,354 @@
+from operator import contains, eq
+from typing import Sequence, Tuple, cast
+
+import snowflake.sqlalchemy.custom_types as sct
+import sqlalchemy
+from singer_sdk import typing as th
+from singer_sdk.connectors import SQLConnector
+from snowflake.sqlalchemy import URL
+from sqlalchemy.engine import Engine
+from sqlalchemy.sql import text
+
+
+class TypeMap:
+    def __init__(self, operator, map_value, match_value=None):
+        self.operator = operator
+        self.map_value = map_value
+        self.match_value = match_value
+
+    def match(self, compare_value):
+        try:
+            if self.match_value:
+                return self.operator(compare_value, self.match_value)
+            return self.operator(compare_value)
+        except TypeError:
+            return False
+
+
+def evaluate_typemaps(type_maps, compare_value, unmatched_value):
+    for type_map in type_maps:
+        if type_map.match(compare_value):
+            return type_map.map_value
+    return unmatched_value
+
+
+class SnowflakeConnector(SQLConnector):
+    """Snowflake Target Connector.
+
+    This class handles all DDL and type conversions.
+    """
+
+    allow_column_add: bool = True  # Whether ADD COLUMN is supported.
+    allow_column_rename: bool = True  # Whether RENAME COLUMN is supported.
+    allow_column_alter: bool = True  # Whether altering column types is supported.
+    allow_merge_upsert: bool = False  # Whether MERGE UPSERT is supported.
+    allow_temp_tables: bool = True  # Whether temp tables are supported.
+
+    def get_sqlalchemy_url(self, config: dict) -> str:
+        """Generates a SQLAlchemy URL for Snowflake.
+
+        Args:
+            config: The configuration for the connector.
+
+        Returns:
+            A SQLAlchemy URL string.
+        """
+        params = {
+            "account": config["account"],
+            "user": config["user"],
+            "password": config["password"],
+            "database": config["database"],
+        }
+
+        for option in ["warehouse", "role"]:
+            if config.get(option):
+                params[option] = config.get(option)
+
+        return URL(**params)
+
+    def create_engine(self) -> Engine:
+        """Creates and returns a new engine. Do not call outside of _engine.
+
+        NOTE: Do not call this method. The only place that this method should
+        be called is inside the self._engine method. If you'd like to access
+        the engine on a connector, use self._engine.
+
+        This method exists solely so that tap/target developers can override it
+        on their subclass of SQLConnector to perform custom engine creation
+        logic.
+
+        Returns:
+            A new SQLAlchemy Engine.
+        """
+        return sqlalchemy.create_engine(
+            self.sqlalchemy_url,
+            connect_args={
+                "session_parameters": {
+                    "QUOTED_IDENTIFIERS_IGNORE_CASE": "TRUE",
+                }
+            },
+            echo=False,
+        )
+
+    @staticmethod
+    def get_column_alter_ddl(
+        table_name: str, column_name: str, column_type: sqlalchemy.types.TypeEngine
+    ) -> sqlalchemy.DDL:
+        """Get the alter column DDL statement.
+
+        Override this if your database uses a different syntax for altering columns.
+
+        Args:
+            table_name: Fully qualified table name of column to alter.
+            column_name: Column name to alter.
+            column_type: New column type.
+
+        Returns:
+            A sqlalchemy DDL instance.
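+
+        Example of the rendered statement (illustrative only; the table,
+        column, and type shown are hypothetical):
+
+            ALTER TABLE MY_DB.MY_SCHEMA.MY_TABLE ALTER COLUMN MY_COL SET DATA TYPE VARCHAR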
+ """ + return sqlalchemy.DDL( + "ALTER TABLE %(table_name)s ALTER COLUMN %(column_name)s SET DATA TYPE %(column_type)s", + { + "table_name": table_name, + "column_name": column_name, + "column_type": column_type, + }, + ) + + @staticmethod + def to_sql_type(jsonschema_type: dict) -> sqlalchemy.types.TypeEngine: + """Return a JSON Schema representation of the provided type. + + Uses custom Snowflake types from [snowflake-sqlalchemy](https://github.com/snowflakedb/snowflake-sqlalchemy/blob/main/src/snowflake/sqlalchemy/custom_types.py) + + Args: + jsonschema_type: The JSON Schema representation of the source type. + + Returns: + The SQLAlchemy type representation of the data type. + """ + # start with default implementation + target_type = SQLConnector.to_sql_type(jsonschema_type) + # snowflake max and default varchar length + # https://docs.snowflake.com/en/sql-reference/intro-summary-data-types.html + maxlength = jsonschema_type.get("maxLength", 16777216) + # define type maps + string_submaps = [ + TypeMap(eq, sct.TIMESTAMP_NTZ(), "date-time"), + TypeMap(contains, sqlalchemy.types.TIME(), "time"), + TypeMap(eq, sqlalchemy.types.DATE(), "date"), + TypeMap(eq, sqlalchemy.types.VARCHAR(maxlength), None), + ] + type_maps = [ + TypeMap(th._jsonschema_type_check, sct.NUMBER(), ("integer",)), + TypeMap(th._jsonschema_type_check, sct.VARIANT(), ("object",)), + TypeMap(th._jsonschema_type_check, sct.VARIANT(), ("array",)), + ] + # apply type maps + if th._jsonschema_type_check(jsonschema_type, ("string",)): + datelike_type = th.get_datelike_property_type(jsonschema_type) + target_type = evaluate_typemaps(string_submaps, datelike_type, target_type) + else: + target_type = evaluate_typemaps(type_maps, jsonschema_type, target_type) + + return cast(sqlalchemy.types.TypeEngine, target_type) + + def prepare_schema(self, schema_name: str) -> None: + """Create the target database schema. + + Args: + schema_name: The target schema name. 
+ """ + with self._connect() as conn: + conn.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}") + + # Custom SQL get methods + + def _get_put_statement(self, sync_id: str, file_uri: str) -> Tuple[text, dict]: + """Get Snowflake PUT statement.""" + return (text(f"put '{file_uri}' '@~/target-snowflake/{sync_id}'"), {}) + + def _get_merge_from_stage_statement( + self, full_table_name, schema, sync_id, file_format, key_properties + ): + """Get Snowflake MERGE statement.""" + + # convert from case in JSON to UPPER column name + column_selections = [ + f"$1:{property_name}::{self.to_sql_type(property_def)} as {property_name.upper()}" + for property_name, property_def in schema["properties"].items() + ] + + # use UPPER from here onwards + upper_properties = [col.upper() for col in schema["properties"].keys()] + upper_key_properties = [col.upper() for col in key_properties] + join_expr = " and ".join( + [f'd."{key}" = s."{key}"' for key in upper_key_properties] + ) + matched_clause = ", ".join( + [f'd."{col}" = s."{col}"' for col in upper_properties] + ) + not_matched_insert_cols = ", ".join(upper_properties) + not_matched_insert_values = ", ".join( + [f's."{col}"' for col in upper_properties] + ) + return ( + text( + f"merge into {full_table_name} d using " + + f"(select {', '.join(column_selections)} from '@~/target-snowflake/{sync_id}'" + + f"(file_format => {file_format})) s " + + f"on {join_expr} " + + f"when matched then update set {matched_clause} " + + f"when not matched then insert ({not_matched_insert_cols}) " + + f"values ({not_matched_insert_values})" + ), + {}, + ) + + def _get_copy_statement(self, full_table_name, schema, sync_id, file_format): + """Get Snowflake COPY statement.""" + # convert from case in JSON to UPPER column name + column_selections = [ + f"$1:{property_name}::{self.to_sql_type(property_def)} as {property_name.upper()}" + for property_name, property_def in schema["properties"].items() + ] + return ( + text( + f"copy into {full_table_name} from " + + f"(select {', '.join(column_selections)} from " + + f"'@~/target-snowflake/{sync_id}')" + + f"file_format = (format_name='{file_format}')" + ), + {}, + ) + + def _get_file_format_statement(self, file_format): + """Get Snowflake CREATE FILE FORMAT statement.""" + return ( + text( + f"create or replace file format {file_format}" + + "type = 'JSON' compression = 'GZIP'" + ), + {}, + ) + + def _get_drop_file_format_statement(self, file_format): + """Get Snowflake DROP FILE FORMAT statement.""" + return ( + text(f"drop file format if exists {file_format}"), + {}, + ) + + def _get_stage_files_remove_statement(self, sync_id): + """Get Snowflake REMOVE statement.""" + return ( + text(f"remove '@~/target-snowflake/{sync_id}/'"), + {}, + ) + + # Custom connector methods + + def put_batches_to_stage(self, sync_id: str, files: Sequence[str]) -> None: + """Upload a batch of records to Snowflake. + + Args: + sync_id: The sync ID for the batch. + files: The files containing records to upload. + """ + with self._connect() as conn: + for file_uri in files: + put_statement, kwargs = self._get_put_statement( + sync_id=sync_id, file_uri=file_uri + ) + conn.execute(put_statement, **kwargs) + + def create_file_format(self, file_format: str) -> None: + """Create a file format in the schema. + + Args: + file_format: The name of the file format. 
+ """ + with self._connect() as conn: + file_format_statement, kwargs = self._get_file_format_statement( + file_format=file_format + ) + self.logger.debug( + f"Creating file format with SQL: {file_format_statement!s}" + ) + conn.execute(file_format_statement, **kwargs) + + def merge_from_stage( + self, + full_table_name: str, + schema: dict, + sync_id: str, + file_format: str, + key_properties: Sequence[str], + ): + """Merge data from a stage into a table. + + Args: + sync_id: The sync ID for the batch. + schema: The schema of the data. + key_properties: The primary key properties of the data. + """ + with self._connect() as conn: + merge_statement, kwargs = self._get_merge_from_stage_statement( + full_table_name=full_table_name, + schema=schema, + sync_id=sync_id, + file_format=file_format, + key_properties=key_properties, + ) + self.logger.debug(f"Merging with SQL: {merge_statement!s}") + conn.execute(merge_statement, **kwargs) + + def copy_from_stage( + self, full_table_name: str, schema: dict, sync_id: str, file_format: str + ): + """Copy data from a stage into a table. + + Args: + full_table_name: The fully-qualified name of the table. + schema: The schema of the data. + sync_id: The sync ID for the batch. + file_format: The name of the file format. + """ + with self._connect() as conn: + copy_statement, kwargs = self._get_copy_statement( + full_table_name=full_table_name, + schema=schema, + sync_id=sync_id, + file_format=file_format, + ) + self.logger.debug(f"Copying with SQL: {copy_statement!s}") + conn.execute(copy_statement, **kwargs) + + def drop_file_format(self, file_format: str) -> None: + """Drop a file format in the schema. + + Args: + file_format: The name of the file format. + """ + with self._connect() as conn: + drop_statement, kwargs = self._get_drop_file_format_statement( + file_format=file_format + ) + self.logger.debug(f"Dropping file format with SQL: {drop_statement!s}") + conn.execute(drop_statement, **kwargs) + + def remove_staged_files(self, sync_id: str) -> None: + """Remove staged files. + + Args: + sync_id: The sync ID for the batch. 
+ """ + with self._connect() as conn: + remove_statement, kwargs = self._get_stage_files_remove_statement( + sync_id=sync_id + ) + self.logger.debug(f"Removing staged files with SQL: {remove_statement!s}") + conn.execute(remove_statement, **kwargs) diff --git a/target_snowflake/sinks.py b/target_snowflake/sinks.py index d1d6947..752a616 100644 --- a/target_snowflake/sinks.py +++ b/target_snowflake/sinks.py @@ -4,207 +4,111 @@ import gzip import json import os -from operator import contains, eq -from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, cast +import typing as t from urllib.parse import urlparse from uuid import uuid4 -import snowflake.sqlalchemy.custom_types as sct -import sqlalchemy -from singer_sdk import typing as th -from singer_sdk.batch import lazy_chunked_generator -from singer_sdk.connectors import SQLConnector -from singer_sdk.helpers._batch import BaseBatchFileEncoding, BatchConfig +from singer_sdk import PluginBase, SQLConnector +from singer_sdk.batch import JSONLinesBatcher +from singer_sdk.helpers._batch import BatchConfig from singer_sdk.helpers._typing import conform_record_data_types from singer_sdk.sinks import SQLSink -from snowflake.sqlalchemy import URL -from sqlalchemy.engine import Engine -from sqlalchemy.sql import text + +from target_snowflake.connector import SnowflakeConnector DEFAULT_BATCH_CONFIG = { - "encoding": { - "format": "jsonl", - "compression": "gzip" - }, - "storage": { - "root": "file://" - } + "encoding": {"format": "jsonl", "compression": "gzip"}, + "storage": {"root": "file://"}, } -class TypeMap: - def __init__(self, operator, map_value, match_value=None): - self.operator = operator - self.map_value = map_value - self.match_value = match_value - def match(self, compare_value): - try: - if self.match_value: - return self.operator(compare_value, self.match_value) - return self.operator(compare_value) - except TypeError: - return False +class SnowflakeSink(SQLSink): + """Snowflake target sink class.""" + connector_class = SnowflakeConnector -def evaluate_typemaps(type_maps, compare_value, unmatched_value): - for type_map in type_maps: - if type_map.match(compare_value): - return type_map.map_value - return unmatched_value + def __init__( + self, + target: PluginBase, + stream_name: str, + schema: dict, + key_properties: list[str] | None, + connector: SQLConnector | None = None, + ) -> None: + """Initialize Snowflake Sink.""" + self.target = target + super().__init__( + target=target, + stream_name=stream_name, + schema=schema, + key_properties=key_properties, + connector=connector, + ) + @property + def schema_name(self) -> t.Optional[str]: + schema = super().schema_name or self.config.get("schema") + return schema.upper() if schema else None -class SnowflakeConnector(SQLConnector): - """The connector for Snowflake. + @property + def database_name(self) -> t.Optional[str]: + db = super().database_name or self.config.get("database") + return db.upper() if db else None - This class handles all DDL and type conversions. - """ + @property + def table_name(self) -> str: + return super().table_name.upper() - allow_column_add: bool = True # Whether ADD COLUMN is supported. - allow_column_rename: bool = True # Whether RENAME COLUMN is supported. - allow_column_alter: bool = True # Whether altering column types is supported. - allow_merge_upsert: bool = False # Whether MERGE UPSERT is supported. - allow_temp_tables: bool = True # Whether temp tables are supported. 
-
-    column_cache = {}
+    def bulk_insert_records(
+        self,
+        full_table_name: str,
+        schema: dict,
+        records: t.Iterable[t.Dict[str, t.Any]],
+    ) -> t.Optional[int]:
+        """Bulk insert records to an existing destination table.
 
-    def column_exists(self, full_table_name: str, column_name: str) -> bool:
-        """Determine if the target table already exists.
+        The default implementation uses a generic SQLAlchemy bulk insert operation.
+        This method may optionally be overridden by developers in order to provide
+        faster, native bulk uploads.
 
         Args:
             full_table_name: the target table name.
-            column_name: the target column name.
+            schema: the JSON schema for the new table, to be used when inferring column
+                names.
+            records: the input records.
 
         Returns:
-            True if table exists, False if not.
+            The number of records inserted, or None if not detectable.
         """
-        if full_table_name not in self.column_cache:
-            self.column_cache[full_table_name] = self.get_table_columns(
-                full_table_name
+        # prepare records for serialization
+        processed_records = (
+            conform_record_data_types(
+                stream_name=self.stream_name,
+                record=rcd,
+                schema=schema,
+                level="RECURSIVE",
+                logger=self.logger,
             )
-        return column_name in self.column_cache[full_table_name]
-
-    def schema_exists(self, schema_name: str) -> bool:
-        """Determine if the target database schema already exists.
-
-        Args:
-            schema_name: The target database schema name.
-
-        Returns:
-            True if the database schema exists, False if not.
-        """
-        schema_names = sqlalchemy.inspect(self._engine).get_schema_names()
-        return schema_name.lower() in schema_names
-
-    def get_sqlalchemy_url(self, config: dict) -> str:
-        """Generates a SQLAlchemy URL for Snowflake.
-
-        Args:
-            config: The configuration for the connector.
-        """
-        params = {
-            "account": config["account"],
-            "user": config["user"],
-            "password": config["password"],
-            "database": config["database"],
-        }
-
-        for option in ["warehouse", "role"]:
-            if config.get(option):
-                params[option] = config.get(option)
-
-        return URL(**params)
-
-    def create_engine(self) -> Engine:
-        """Creates and returns a new engine. Do not call outside of _engine.
-
-        NOTE: Do not call this method. The only place that this method should
-        be called is inside the self._engine method. If you'd like to access
-        the engine on a connector, use self._engine.
-
-        This method exists solely so that tap/target developers can override it
-        on their subclass of SQLConnector to perform custom engine creation
-        logic.
-
-        Returns:
-            A new SQLAlchemy Engine.
-        """
-        return sqlalchemy.create_engine(
-            self.sqlalchemy_url,
-            connect_args={
-                "session_parameters": {
-                    "QUOTED_IDENTIFIERS_IGNORE_CASE": "TRUE",
-                }
-            },
-            echo=False,
+            for rcd in records
         )
 
-    @staticmethod
-    def get_column_alter_ddl(
-        table_name: str, column_name: str, column_type: sqlalchemy.types.TypeEngine
-    ) -> sqlalchemy.DDL:
-        """Get the alter column DDL statement.
-
-        Override this if your database uses a different syntax for altering columns.
-
-        Args:
-            table_name: Fully qualified table name of column to alter.
-            column_name: Column name to alter.
-            column_type: New column type string.
-
-        Returns:
-            A sqlalchemy DDL instance.
- """ - return sqlalchemy.DDL( - "ALTER TABLE %(table_name)s ALTER COLUMN %(column_name)s SET DATA TYPE %(column_type)s", - { - "table_name": table_name, - "column_name": column_name, - "column_type": column_type, - }, + # serialize to batch files and upload + # TODO: support other batchers + batcher = JSONLinesBatcher( + tap_name=self.target.name, + stream_name=self.stream_name, + batch_config=self.batch_config, ) + batches = batcher.get_batches(records=processed_records) + for files in batches: + self.insert_batch_files_via_internal_stage( + full_table_name=full_table_name, + files=files, + ) + # if records list, we can quickly return record count. + return len(records) if isinstance(records, list) else None - @staticmethod - def to_sql_type(jsonschema_type: dict) -> sqlalchemy.types.TypeEngine: - """Return a JSON Schema representation of the provided type. - - Uses custom Snowflake types from [snowflake-sqlalchemy](https://github.com/snowflakedb/snowflake-sqlalchemy/blob/main/src/snowflake/sqlalchemy/custom_types.py) - - Args: - jsonschema_type: The JSON Schema representation of the source type. - - Returns: - The SQLAlchemy type representation of the data type. - """ - # start with default implementation - target_type = SQLConnector.to_sql_type(jsonschema_type) - # snowflake max and default varchar length - # https://docs.snowflake.com/en/sql-reference/intro-summary-data-types.html - maxlength = jsonschema_type.get("maxLength", 16777216) - # define type maps - string_submaps = [ - TypeMap(eq, sct.TIMESTAMP_NTZ(), "date-time"), - TypeMap(contains, sqlalchemy.types.TIME(), "time"), - TypeMap(eq, sqlalchemy.types.DATE(), "date"), - TypeMap(eq, sqlalchemy.types.VARCHAR(maxlength), None), - ] - type_maps = [ - TypeMap(th._jsonschema_type_check, sct.NUMBER(), ("integer",)), - TypeMap(th._jsonschema_type_check, sct.VARIANT(), ("object",)), - TypeMap(th._jsonschema_type_check, sct.VARIANT(), ("array",)), - ] - # apply type maps - if th._jsonschema_type_check(jsonschema_type, ("string",)): - datelike_type = th.get_datelike_property_type(jsonschema_type) - target_type = evaluate_typemaps(string_submaps, datelike_type, target_type) - else: - target_type = evaluate_typemaps(type_maps, jsonschema_type, target_type) - - return cast(sqlalchemy.types.TypeEngine, target_type) - - -class SnowflakeSink(SQLSink): - """Snowflake target sink class.""" - - connector_class = SnowflakeConnector + # Custom methods to process batch files @property def batch_config(self) -> BatchConfig | None: @@ -215,88 +119,11 @@ def batch_config(self) -> BatchConfig | None: """ raw = self.config.get("batch_config", DEFAULT_BATCH_CONFIG) return BatchConfig.from_dict(raw) - - @property - def schema_name(self) -> Optional[str]: - schema = super().schema_name or self.config.get("schema") - return schema.upper() if schema else None - - @property - def database_name(self) -> Optional[str]: - db = super().database_name or self.config.get("database") - return db.upper() if db else None - - @property - def table_name(self) -> str: - return super().table_name.upper() - - def _get_put_statement(self, sync_id: str, file_uri: str) -> Tuple[text, dict]: - """Get Snowflake PUT statement.""" - return (text(f"put '{file_uri}' '@~/target-snowflake/{sync_id}'"), {}) - - def _get_merge_statement(self, full_table_name, schema, sync_id, file_format): - """Get Snowflake MERGE statement.""" - # convert from case in JSON to UPPER column name - column_selections = [ - f"$1:{property_name}::{self.connector.to_sql_type(property_def)} as 
-            for property_name, property_def in schema["properties"].items()
-        ]
-        # use UPPER from here onwards
-        upper_properties = [col.upper() for col in schema["properties"].keys()]
-        upper_key_properties = [col.upper() for col in self.key_properties]
-        join_expr = " and ".join(
-            [f'd."{key}" = s."{key}"' for key in upper_key_properties]
-        )
-        matched_clause = ", ".join(
-            [f'd."{col}" = s."{col}"' for col in upper_properties]
-        )
-        not_matched_insert_cols = ", ".join(upper_properties)
-        not_matched_insert_values = ", ".join(
-            [f's."{col}"' for col in upper_properties]
-        )
-        return (
-            text(
-                f"merge into {full_table_name} d using "
-                + f"(select {', '.join(column_selections)} from '@~/target-snowflake/{sync_id}'"
-                + f"(file_format => {file_format})) s "
-                + f"on {join_expr} "
-                + f"when matched then update set {matched_clause} "
-                + f"when not matched then insert ({not_matched_insert_cols}) "
-                + f"values ({not_matched_insert_values})"
-            ),
-            {},
-        )
-
-    def _get_copy_statement(self, full_table_name, schema, sync_id, file_format):
-        """Get Snowflake COPY statement."""
-        # convert from case in JSON to UPPER column name
-        column_selections = [
-            f"$1:{property_name}::{self.connector.to_sql_type(property_def)} as {property_name.upper()}"
-            for property_name, property_def in schema["properties"].items()
-        ]
-        return (
-            text(
-                f"copy into {full_table_name} from " +
-                f"(select {', '.join(column_selections)} from " +
-                f"'@~/target-snowflake/{sync_id}')" +
-                f"file_format = (format_name='{file_format}')"
-            ),
-            {},
-        )
-
-    def _get_file_format_statement(self, file_format):
-        return (
-            text(
-                f"create or replace file format {file_format}"
-                + "type = 'JSON' compression = 'GZIP'"
-            ),
-            {},
-        )
-
-    def process_batch_files(
+    def insert_batch_files_via_internal_stage(
         self,
-        encoding: BaseBatchFileEncoding,
-        files: Sequence[str],
+        full_table_name: str,
+        files: t.Sequence[str],
     ) -> None:
         """Process a batch file with the given batch context.
 
         Args:
-            encoding: The batch file encoding.
+            full_table_name: The fully-qualified name of the destination table.
             files: The batch files to process.
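+
+        Batch files are PUT to the Snowflake user stage under a per-sync
+        prefix of the form (illustrative):
+
+            @~/target-snowflake/<stream_name>-<uuid>/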
""" - self.logger.info(f"Processing batch of {len(files)} files") + self.logger.info("Processing batch of files.") try: sync_id = f"{self.stream_name}-{uuid4()}" file_format = f'{self.database_name}.{self.schema_name}."{sync_id}"' - # PUT batches to remote stage - for file_uri in files: - put_statement, kwargs = self._get_put_statement( - sync_id=sync_id, file_uri=file_uri - ) - self.connector.connection.execute(put_statement, **kwargs) - # create schema - self.logger.debug("Preparing target schema") - self.connector.connection.execute( - text( - f"create schema if not exists {self.database_name}.{self.schema_name}" - ) - ) - # create file format in new schema - file_format_statement, kwargs = self._get_file_format_statement( - file_format=file_format - ) - self.logger.debug(f"Creating file format with SQL: {file_format_statement}") - self.connector.connection.execute(file_format_statement, **kwargs) + self.connector.put_batches_to_stage(sync_id=sync_id, files=files) + self.connector.prepare_schema(schema_name=self.schema_name) + self.connector.create_file_format(file_format=file_format) + if self.key_properties: # merge into destination table - merge_statement, kwargs = self._get_merge_statement( - full_table_name=self.full_table_name, + self.connector.merge_from_stage( + full_table_name=full_table_name, schema=self.conform_schema(self.schema), sync_id=sync_id, file_format=file_format, + key_properties=self.key_properties, ) - self.logger.info(f"Merging batch with SQL: {merge_statement}") - self.connector.connection.execute(merge_statement, **kwargs) + else: - copy_statement, kwargs = self._get_copy_statement( - full_table_name=self.full_table_name, + self.connector.copy_from_stage( + full_table_name=full_table_name, schema=self.conform_schema(self.schema), sync_id=sync_id, file_format=file_format, ) - self.logger.info(f"Copying batch with SQL: {copy_statement}") - self.connector.connection.execute(copy_statement, **kwargs) finally: - # clean up file format self.logger.debug("Cleaning up after batch processing") - self.connector.connection.execute( - text(f"drop file format if exists {file_format}") - ) - # clean up staged files - self.connector.connection.execute( - text(f"remove '@~/target-snowflake/{sync_id}/'") - ) + self.connector.drop_file_format(file_format=file_format) + self.connector.remove_staged_files(sync_id=sync_id) # clean up local files if self.config.get("clean_up_batch_files"): for file_url in files: file_path = urlparse(file_url).path if os.path.exists(file_path): os.remove(file_path) - - def bulk_insert_records( - self, - full_table_name: str, - schema: dict, - records: Iterable[Dict[str, Any]], - ) -> Optional[int]: - """Bulk insert records to an existing destination table. - - The default implementation uses a generic SQLAlchemy bulk insert operation. - This method may optionally be overridden by developers in order to provide - faster, native bulk uploads. - - Args: - full_table_name: the target table name. - schema: the JSON schema for the new table, to be used when inferring column - names. - records: the input records. - - Returns: - True if table exists, False if not, None if unsure or undetectable. 
- """ - # prepare records for serialization - processed_records = ( - conform_record_data_types( - stream_name=self.stream_name, - record=rcd, - schema=schema, - level="RECURSIVE", - logger=self.logger, - ) - for rcd in records - ) - # serialize to batch files - encoding, files = self.get_batches( - batch_config=self.batch_config, records=processed_records - ) - self.process_batch_files(encoding=encoding, files=files) - # if records list, we can quickly return record count. - return len(records) if isinstance(records, list) else None - - # Copied and modified from `singer_sdk.streams.core.Stream` - def get_batches( - self, - batch_config: BatchConfig, - records: Iterable[Dict[str, Any]], - ) -> Iterable[tuple[BaseBatchFileEncoding, list[str]]]: - """Batch generator function. - - Developers are encouraged to override this method to customize batching - behavior for databases, bulk APIs, etc. - - Args: - batch_config: Batch config for this stream. - context: Stream partition or context dictionary. - - Yields: - A tuple of (encoding, manifest) for each batch. - """ - sync_id = f"target-snowflake--{self.stream_name}-{uuid4()}" - prefix = batch_config.storage.prefix or "" - file_urls = [] - for i, chunk in enumerate( - lazy_chunked_generator( - records, - self.MAX_SIZE_DEFAULT, # no point being larger than the sink max size, as thats the max number of records that will arrive - ), - start=1, - ): - filename = f"{prefix}{sync_id}-{i}.json.gz" - with batch_config.storage.fs() as fs: - with fs.open(filename, "wb") as f: - # TODO: Determine compression from config. - with gzip.GzipFile(fileobj=f, mode="wb") as gz: - gz.writelines( - (json.dumps(record) + "\n").encode() for record in chunk - ) - file_urls.append(fs.geturl(filename)) - - return batch_config.encoding, file_urls