diff --git a/universal_transfer_operator/example_dags/example_universal_transfer_operator.py b/universal_transfer_operator/example_dags/example_universal_transfer_operator.py index 5382c8292..611002c5e 100644 --- a/universal_transfer_operator/example_dags/example_universal_transfer_operator.py +++ b/universal_transfer_operator/example_dags/example_universal_transfer_operator.py @@ -3,7 +3,7 @@ from airflow import DAG -from universal_transfer_operator.constants import TransferMode +from universal_transfer_operator.constants import FileType, TransferMode from universal_transfer_operator.datasets.file.base import File from universal_transfer_operator.datasets.table import Metadata, Table from universal_transfer_operator.integrations.fivetran import Connector, Destination, FiveTranOptions, Group @@ -30,6 +30,75 @@ ), ) + transfer_non_native_s3_to_sqlite = UniversalTransferOperator( + task_id="transfer_non_native_s3_to_sqlite", + source_dataset=File( + path="s3://astro-sdk-test/uto/csv_files/", conn_id="aws_default", filetype=FileType.CSV + ), + destination_dataset=Table(name="uto_s3_table", conn_id="sqlite_default"), + ) + + transfer_non_native_gs_to_sqlite = UniversalTransferOperator( + task_id="transfer_non_native_gs_to_sqlite", + source_dataset=File( + path="gs://uto-test/uto/csv_files/", conn_id="google_cloud_default", filetype=FileType.CSV + ), + destination_dataset=Table(name="uto_gs_table", conn_id="sqlite_default"), + ) + + transfer_non_native_s3_to_snowflake = UniversalTransferOperator( + task_id="transfer_non_native_s3_to_snowflake", + source_dataset=File( + path="s3://astro-sdk-test/uto/csv_files/", conn_id="aws_default", filetype=FileType.CSV + ), + destination_dataset=Table(name="uto_s3_table", conn_id="snowflake_default"), + ) + + transfer_non_native_gs_to_snowflake = UniversalTransferOperator( + task_id="transfer_non_native_gs_to_snowflake", + source_dataset=File( + path="gs://uto-test/uto/csv_files/", conn_id="google_cloud_default", filetype=FileType.CSV + ), + destination_dataset=Table(name="uto_gs_table", conn_id="snowflake_default"), + ) + + transfer_non_native_gs_to_bigquery = UniversalTransferOperator( + task_id="transfer_non_native_gs_to_bigquery", + source_dataset=File(path="gs://uto-test/uto/homes_main.csv", conn_id="google_cloud_default"), + destination_dataset=Table( + name="uto_gs_to_bigquery_table", conn_id="google_cloud_default", metadata=Metadata(schema="astro") + ), + ) + + transfer_non_native_s3_to_bigquery = UniversalTransferOperator( + task_id="transfer_non_native_s3_to_bigquery", + source_dataset=File( + path="s3://astro-sdk-test/uto/csv_files/", conn_id="aws_default", filetype=FileType.CSV + ), + destination_dataset=Table( + name="uto_s3_to_bigquery_table", conn_id="google_cloud_default", metadata=Metadata(schema="astro") + ), + ) + + transfer_non_native_bigquery_to_snowflake = UniversalTransferOperator( + task_id="transfer_non_native_bigquery_to_snowflake", + source_dataset=Table( + name="uto_s3_to_bigquery_table", conn_id="google_cloud_default", metadata=Metadata(schema="astro") + ), + destination_dataset=Table( + name="uto_bigquery_to_snowflake_table", + conn_id="snowflake_default", + ), + ) + + transfer_non_native_bigquery_to_sqlite = UniversalTransferOperator( + task_id="transfer_non_native_bigquery_to_sqlite", + source_dataset=Table( + name="uto_s3_to_bigquery_table", conn_id="google_cloud_default", metadata=Metadata(schema="astro") + ), + destination_dataset=Table(name="uto_bigquery_to_sqlite_table", conn_id="sqlite_default"), + ) + 
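These new example tasks rely on the provider lookup reworked later in this diff (data_providers/__init__.py): the mapping is now keyed by the pair of Airflow connection type and dataset class, so a google_cloud_platform connection resolves to the GCS filesystem provider for a File but to the BigQuery database provider for a Table, and a File with an empty conn_id falls back to the local filesystem provider. A minimal sketch of that tuple-keyed lookup follows; the stand-in classes, shortened module names and the resolve_provider_module helper are illustrative only, and the real resolution happens inside create_dataprovider.

```python
from __future__ import annotations

# Simplified sketch of the (connection type, dataset class) keyed lookup added in
# data_providers/__init__.py. Classes and module names here are stand-ins.


class File:  # stand-in for universal_transfer_operator.datasets.file.base.File
    ...


class Table:  # stand-in for universal_transfer_operator.datasets.table.Table
    ...


PROVIDER_MODULES = {
    ("s3", File): "data_providers.filesystem.aws.s3",
    ("google_cloud_platform", File): "data_providers.filesystem.google.cloud.gcs",
    ("google_cloud_platform", Table): "data_providers.database.google.bigquery",
    ("snowflake", Table): "data_providers.database.snowflake",
    ("sqlite", Table): "data_providers.database.sqlite",
    (None, File): "data_providers.filesystem.local",  # empty conn_id -> local filesystem
}


def resolve_provider_module(conn_type: str | None, dataset) -> str:
    """Pick the provider module for a dataset, mirroring what create_dataprovider now does."""
    return PROVIDER_MODULES[(conn_type, type(dataset))]


print(resolve_provider_module("sqlite", Table()))  # data_providers.database.sqlite
print(resolve_provider_module(None, File()))       # data_providers.filesystem.local
```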
transfer_fivetran_with_connector_id = UniversalTransferOperator( task_id="transfer_fivetran_with_connector_id", source_dataset=File(path="s3://astro-sdk-test/uto/", conn_id="aws_default"), diff --git a/universal_transfer_operator/pyproject.toml b/universal_transfer_operator/pyproject.toml index 7a4d9028a..eed175647 100644 --- a/universal_transfer_operator/pyproject.toml +++ b/universal_transfer_operator/pyproject.toml @@ -54,7 +54,10 @@ google = [ snowflake = [ "apache-airflow-providers-snowflake", "snowflake-sqlalchemy>=1.2.0", - "snowflake-connector-python[pandas]", + "snowflake-connector-python[pandas]<3.0.0", + # pinning snowflake-connector-python[pandas]<3.0.0 due to a conflict in snowflake-connector-python/pyarrow/google + # packages and pandas-gbq/google packages which is forcing pandas-gbq of version 0.13.2 installed, which is not + # compatible with pandas 1.5.3 ] amazon = [ @@ -72,7 +75,10 @@ all = [ "apache-airflow-providers-google>=6.4.0", "apache-airflow-providers-snowflake", "smart-open[all]>=5.2.1", - "snowflake-connector-python[pandas]", + "snowflake-connector-python[pandas]<3.0.0", + # pinning snowflake-connector-python[pandas]<3.0.0 due to a conflict in snowflake-connector-python/pyarrow/google + # packages and pandas-gbq/google packages which is forcing pandas-gbq of version 0.13.2 installed, which is not + # compatible with pandas 1.5.3 "snowflake-sqlalchemy>=1.2.0", "sqlalchemy-bigquery>=1.3.0", "s3fs", diff --git a/universal_transfer_operator/src/universal_transfer_operator/constants.py b/universal_transfer_operator/src/universal_transfer_operator/constants.py index b8fb5a938..647d145d6 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/constants.py +++ b/universal_transfer_operator/src/universal_transfer_operator/constants.py @@ -99,3 +99,4 @@ def __repr__(self): LoadExistStrategy = Literal["replace", "append"] DEFAULT_CHUNK_SIZE = 1000000 ColumnCapitalization = Literal["upper", "lower", "original"] +DEFAULT_SCHEMA = "tmp_transfers" diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/__init__.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/__init__.py index 9c50929e4..775f62c06 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/data_providers/__init__.py +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/__init__.py @@ -5,12 +5,40 @@ from universal_transfer_operator.constants import TransferMode from universal_transfer_operator.data_providers.base import DataProviders from universal_transfer_operator.datasets.base import Dataset +from universal_transfer_operator.datasets.file.base import File +from universal_transfer_operator.datasets.table import Table from universal_transfer_operator.utils import TransferParameters, get_class_name -DATASET_CONN_ID_TO_DATAPROVIDER_MAPPING = dict.fromkeys( - ["s3", "aws"], "universal_transfer_operator.data_providers.filesystem.aws.s3" -) | dict.fromkeys( - ["gs", "google_cloud_platform"], "universal_transfer_operator.data_providers.filesystem.google.cloud.gcs" +DATASET_CONN_ID_TO_DATAPROVIDER_MAPPING = ( + dict.fromkeys( + [("s3", File), ("aws", File)], "universal_transfer_operator.data_providers.filesystem.aws.s3" + ) + | dict.fromkeys( + [("gs", Table), ("google_cloud_platform", Table)], + "universal_transfer_operator.data_providers.database.google.bigquery", + ) + | dict.fromkeys( + [("gs", File), ("google_cloud_platform", File)], + 
"universal_transfer_operator.data_providers.filesystem.google.cloud.gcs", + ) + | dict.fromkeys( + [ + ("sqlite", Table), + ], + "universal_transfer_operator.data_providers.database.sqlite", + ) + | dict.fromkeys( + [ + ("snowflake", Table), + ], + "universal_transfer_operator.data_providers.database.snowflake", + ) + | dict.fromkeys( + [ + (None, File), + ], + "universal_transfer_operator.data_providers.filesystem.local", + ) ) @@ -19,8 +47,12 @@ def create_dataprovider( transfer_params: TransferParameters = None, transfer_mode: TransferMode = TransferMode.NONNATIVE, ) -> DataProviders: - conn_type = BaseHook.get_connection(dataset.conn_id).conn_type - module_path = DATASET_CONN_ID_TO_DATAPROVIDER_MAPPING[conn_type] + print(dataset) + if dataset.conn_id != "": + conn_type = BaseHook.get_connection(dataset.conn_id).conn_type + else: + conn_type = None + module_path = DATASET_CONN_ID_TO_DATAPROVIDER_MAPPING[(conn_type, type(dataset))] module = importlib.import_module(module_path) class_name = get_class_name(module_ref=module, suffix="DataProvider") data_provider: DataProviders = getattr(module, class_name)( diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/base.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/base.py index 8b50053f9..c7e95d1be 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/data_providers/base.py +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/base.py @@ -97,3 +97,11 @@ def openlineage_dataset_name(self) -> str: https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md """ raise NotImplementedError + + @property + def openlineage_dataset_uri(self) -> str: + """ + Returns the open lineage dataset uri as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/__init__.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/base.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/base.py new file mode 100644 index 000000000..7ef85f092 --- /dev/null +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/base.py @@ -0,0 +1,676 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pandas as pd +import sqlalchemy + +if TYPE_CHECKING: # pragma: no cover + from sqlalchemy.engine.cursor import CursorResult + +import warnings + +import attr +from airflow.hooks.dbapi import DbApiHook +from pandas.io.sql import SQLDatabase +from sqlalchemy.sql import ClauseElement +from sqlalchemy.sql.schema import Table as SqlaTable + +from universal_transfer_operator.constants import ( + DEFAULT_CHUNK_SIZE, + ColumnCapitalization, + LoadExistStrategy, + Location, +) +from universal_transfer_operator.data_providers.base import DataProviders +from universal_transfer_operator.data_providers.filesystem import resolve_file_path_pattern +from universal_transfer_operator.data_providers.filesystem.base import FileStream +from universal_transfer_operator.datasets.base import Dataset +from universal_transfer_operator.datasets.dataframe.pandas import PandasDataframe +from universal_transfer_operator.datasets.file.base import File +from 
universal_transfer_operator.datasets.table import Metadata, Table +from universal_transfer_operator.settings import LOAD_TABLE_AUTODETECT_ROWS_COUNT, SCHEMA +from universal_transfer_operator.universal_transfer_operator import TransferParameters +from universal_transfer_operator.utils import get_dataset_connection_type + + +class DatabaseDataProvider(DataProviders): + """DatabaseProviders represent all the DataProviders interactions with Databases.""" + + _create_schema_statement: str = "CREATE SCHEMA IF NOT EXISTS {}" + _drop_table_statement: str = "DROP TABLE IF EXISTS {}" + _create_table_statement: str = "CREATE TABLE IF NOT EXISTS {} AS {}" + # Used to normalize the ndjson when appending fields in nested fields. + # Example - + # ndjson - {'a': {'b': 'val'}} + # the col names generated is 'a.b'. char '.' maybe an illegal char in some db's col name. + # Contains the illegal char and there replacement, where the value in + # illegal_column_name_chars[0] will be replaced by value in illegal_column_name_chars_replacement[0] + illegal_column_name_chars: list[str] = [] + illegal_column_name_chars_replacement: list[str] = [] + # In run_raw_sql operator decides if we want to return results directly or process them by handler provided + IGNORE_HANDLER_IN_RUN_RAW_SQL: bool = False + NATIVE_PATHS: dict[Any, Any] = {} + DEFAULT_SCHEMA = SCHEMA + + def __init__( + self, + dataset: Table, + transfer_mode, + transfer_params: TransferParameters = attr.field( + factory=TransferParameters, + converter=lambda val: TransferParameters(**val) if isinstance(val, dict) else val, + ), + ): + self.dataset = dataset + self.transfer_params = transfer_params + self.transfer_mode = transfer_mode + self.transfer_mapping = {} + self.LOAD_DATA_NATIVELY_FROM_SOURCE: dict = {} + super().__init__( + dataset=self.dataset, transfer_mode=self.transfer_mode, transfer_params=self.transfer_params + ) + + def __repr__(self): + return f'{self.__class__.__name__}(conn_id="{self.dataset.conn_id})' + + @property + def sql_type(self): + raise NotImplementedError + + @property + def hook(self) -> DbApiHook: + """Return an instance of the database-specific Airflow hook.""" + raise NotImplementedError + + @property + def connection(self) -> sqlalchemy.engine.base.Connection: + """Return a Sqlalchemy connection object for the given database.""" + return self.sqlalchemy_engine.connect() + + @property + def sqlalchemy_engine(self) -> sqlalchemy.engine.base.Engine: + """Return Sqlalchemy engine.""" + return self.hook.get_sqlalchemy_engine() # type: ignore[no-any-return] + + @property + def transport_params(self) -> dict | None: # skipcq: PYL-R0201 + """Get credentials required by smart open to access files""" + return None + + def run_sql( + self, + sql: str | ClauseElement = "", + parameters: dict | None = None, + **kwargs, + ) -> CursorResult: + """ + Return the results to running a SQL statement. + + Whenever possible, this method should be implemented using Airflow Hooks, + since this will simplify the integration with Async operators. 
+ + :param sql: Contains SQL query to be run against database + :param parameters: Optional parameters to be used to render the query + :param autocommit: Optional autocommit flag + """ + if parameters is None: + parameters = {} + + if "sql_statement" in kwargs: # pragma: no cover + warnings.warn( + "`sql_statement` is deprecated and will be removed in future release" + "Please use `sql` param instead.", + DeprecationWarning, + stacklevel=2, + ) + sql = kwargs.get("sql_statement") # type: ignore + + # We need to autocommit=True to make sure the query runs. This is done exclusively for SnowflakeDatabase's + # truncate method to reflect changes. + if isinstance(sql, str): + result = self.connection.execute( + sqlalchemy.text(sql).execution_options(autocommit=True), parameters + ) + else: + result = self.connection.execute(sql, parameters) + return result + + def columns_exist(self, table: Table, columns: list[str]) -> bool: + """ + Check that a list of columns exist in the given table. + + :param table: The table to check in. + :param columns: The columns to check. + + :returns: whether the columns exist in the table or not. + """ + sqla_table = self.get_sqla_table(table) + return all( + any(sqla_column.name == column for sqla_column in sqla_table.columns) for column in columns + ) + + def get_sqla_table(self, table: Table) -> SqlaTable: + """ + Return SQLAlchemy table instance + + :param table: Astro Table to be converted to SQLAlchemy table instance + """ + return SqlaTable( + table.name, + table.sqlalchemy_metadata, + autoload_with=self.sqlalchemy_engine, + extend_existing=True, + ) + + def table_exists(self, table: Table) -> bool: + """ + Check if a table exists in the database. + + :param table: Details of the table we want to check that exists + """ + table_qualified_name = self.get_table_qualified_name(table) + inspector = sqlalchemy.inspect(self.sqlalchemy_engine) + return bool(inspector.dialect.has_table(self.connection, table_qualified_name)) + + def check_if_transfer_supported(self, source_dataset: Table) -> bool: + """ + Checks if the transfer is supported from source to destination based on source_dataset. + + :param source_dataset: Table present in the source location + """ + source_connection_type = get_dataset_connection_type(source_dataset) + return Location(source_connection_type) in self.transfer_mapping + + def read(self): + """ ""Read the dataset and write to local reference location""" + raise NotImplementedError + + def write(self, source_ref: FileStream): + """ + Write the data from local reference location to the dataset. + + :param source_ref: Stream of data to be loaded into output table. 
+ """ + return self.load_file_to_table( + input_file=source_ref.actual_file, + output_table=self.dataset, + ) + + def load_data_from_source_natively(self, source_dataset: Table, destination_dataset: Dataset) -> None: + """ + Loads data from source dataset to the destination using data provider + """ + if not self.check_if_transfer_supported(source_dataset=source_dataset): + raise ValueError("Transfer not supported yet.") + + source_connection_type = get_dataset_connection_type(source_dataset) + method_name = self.LOAD_DATA_NATIVELY_FROM_SOURCE.get(source_connection_type) + if method_name: + transfer_method = self.__getattribute__(method_name) + return transfer_method( + source_dataset=source_dataset, + destination_dataset=destination_dataset, + ) + else: + raise ValueError(f"No transfer performed from {source_connection_type} to S3.") + + @property + def openlineage_dataset_namespace(self) -> str: + """ + Returns the open lineage dataset namespace as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError + + @property + def openlineage_dataset_name(self) -> str: + """ + Returns the open lineage dataset name as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError + + @property + def openlineage_dataset_uri(self) -> str: + """ + Returns the open lineage dataset uri as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError + + # --------------------------------------------------------- + # Table metadata + # --------------------------------------------------------- + @staticmethod + def get_table_qualified_name(table: Table) -> str: # skipcq: PYL-R0201 + """ + Return table qualified name. This is Database-specific. + For instance, in Sqlite this is the table name. In Snowflake, however, it is the database, schema and table + + :param table: The table we want to retrieve the qualified name for. + """ + # Initially this method belonged to the Table class. + # However, in order to have an agnostic table class implementation, + # we are keeping all methods which vary depending on the database within the Database class. + if table.metadata and table.metadata.schema: + qualified_name = f"{table.metadata.schema}.{table.name}" + else: + qualified_name = table.name + return qualified_name + + @property + def default_metadata(self) -> Metadata: + """ + Extract the metadata available within the Airflow connection associated with + self.dataset.conn_id. + + :return: a Metadata instance + """ + raise NotImplementedError + + def populate_table_metadata(self, table: Table) -> Table: + """ + Given a table, check if the table has metadata. + If the metadata is missing, and the database has metadata, assign it to the table. + If the table schema was not defined by the end, retrieve the user-defined schema. + This method performs the changes in-place and also returns the table. 
+ + :param table: Table to potentially have their metadata changed + :return table: Return the modified table + """ + if table.metadata and table.metadata.is_empty() and self.default_metadata: + table.metadata = self.default_metadata + if not table.metadata.schema: + table.metadata.schema = self.DEFAULT_SCHEMA + return table + + # --------------------------------------------------------- + # Table creation & deletion methods + # --------------------------------------------------------- + def create_table_using_columns(self, table: Table) -> None: + """ + Create a SQL table using the table columns. + + :param table: The table to be created. + """ + if not table.columns: + raise ValueError("To use this method, table.columns must be defined") + + metadata = table.sqlalchemy_metadata + sqlalchemy_table = sqlalchemy.Table(table.name, metadata, *table.columns) + metadata.create_all(self.sqlalchemy_engine, tables=[sqlalchemy_table]) + + def is_native_autodetect_schema_available( # skipcq: PYL-R0201 + self, file: File # skipcq: PYL-W0613 + ) -> bool: + """ + Check if native auto detection of schema is available. + + :param file: File used to check the file type of to decide + whether there is a native auto detection available for it. + """ + return False + + def create_table_using_native_schema_autodetection( + self, + table: Table, + file: File, + ) -> None: + """ + Create a SQL table, automatically inferring the schema using the given file via native database support. + + :param table: The table to be created. + :param file: File used to infer the new table columns. + """ + raise NotImplementedError("Missing implementation of native schema autodetection.") + + def create_table_using_schema_autodetection( + self, + table: Table, + file: File | None = None, + dataframe: pd.DataFrame | None = None, + columns_names_capitalization: ColumnCapitalization = "original", # skipcq + ) -> None: + """ + Create a SQL table, automatically inferring the schema using the given file. + + :param table: The table to be created. + :param file: File used to infer the new table columns. + :param dataframe: Dataframe used to infer the new table columns if there is no file + :param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase + in the resulting dataframe + """ + if file is None: + if dataframe is None: + raise ValueError( + "File or Dataframe is required for creating table using schema autodetection" + ) + source_dataframe = dataframe + else: + source_dataframe = file.export_to_dataframe(nrows=LOAD_TABLE_AUTODETECT_ROWS_COUNT) + + db = SQLDatabase(engine=self.sqlalchemy_engine) + db.prep_table( + source_dataframe, + table.name.lower(), + schema=table.metadata.schema, + if_exists="replace", + index=False, + ) + + def create_table( + self, + table: Table, + file: File | None = None, + dataframe: pd.DataFrame | None = None, + columns_names_capitalization: ColumnCapitalization = "original", + use_native_support: bool = True, + ) -> None: + """ + Create a table either using its explicitly defined columns or inferring + it's columns from a given file. + + :param table: The table to be created + :param file: (optional) File used to infer the table columns. 
+ :param dataframe: (optional) Dataframe used to infer the new table columns if there is no file + :param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase + in the resulting dataframe + """ + if table.columns: + self.create_table_using_columns(table) + elif use_native_support and file and self.is_native_autodetect_schema_available(file): + self.create_table_using_native_schema_autodetection(table, file) + else: + self.create_table_using_schema_autodetection( + table, + file=file, + dataframe=dataframe, + columns_names_capitalization=columns_names_capitalization, + ) + + def create_table_from_select_statement( + self, + statement: str, + target_table: Table, + parameters: dict | None = None, + ) -> None: + """ + Export the result rows of a query statement into another table. + + :param statement: SQL query statement + :param target_table: Destination table where results will be recorded. + :param parameters: (Optional) parameters to be used to render the SQL query + """ + statement = self._create_table_statement.format( + self.get_table_qualified_name(target_table), statement + ) + self.run_sql(statement, parameters) + + def drop_table(self, table: Table) -> None: + """ + Delete a SQL table, if it exists. + + :param table: The table to be deleted. + """ + statement = self._drop_table_statement.format(self.get_table_qualified_name(table)) + self.run_sql(statement) + + # --------------------------------------------------------- + # Table load methods + # --------------------------------------------------------- + + def create_schema_and_table_if_needed( + self, + table: Table, + file: File, + normalize_config: dict | None = None, + columns_names_capitalization: ColumnCapitalization = "original", + if_exists: LoadExistStrategy = "replace", + use_native_support: bool = True, + ): + """ + Checks if the autodetect schema exists for native support else creates the schema and table + :param table: Table to create + :param file: File path and conn_id for object stores + :param normalize_config: pandas json_normalize params config + :param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase + :param if_exists: Overwrite file if exists + :param use_native_support: Use native support for data transfer if available on the destination + """ + is_schema_autodetection_supported = self.check_schema_autodetection_is_supported(source_file=file) + is_file_pattern_based_schema_autodetection_supported = ( + self.check_file_pattern_based_schema_autodetection_is_supported(source_file=file) + ) + if if_exists == "replace": + self.drop_table(table) + if use_native_support and is_schema_autodetection_supported and not file.is_pattern(): + return + if ( + use_native_support + and file.is_pattern() + and is_schema_autodetection_supported + and is_file_pattern_based_schema_autodetection_supported + ): + return + + self.create_schema_if_needed(table.metadata.schema) + if if_exists == "replace" or not self.table_exists(table): + files = resolve_file_path_pattern( + file, + normalize_config=normalize_config, + filetype=file.type.name, + transfer_params=self.transfer_params, + transfer_mode=self.transfer_mode, + ) + self.create_table( + table, + # We only use the first file for inferring the table schema + files[0], + columns_names_capitalization=columns_names_capitalization, + use_native_support=use_native_support, + ) + + def fetch_all_rows(self, table: Table, row_limit: int = -1) -> list: + """ + Fetches all rows for a table and 
returns as a list. This is needed because some + databases have different cursors that require different methods to fetch rows + + :param row_limit: Limit the number of rows returned, by default return all rows. + :param table: The table metadata needed to fetch the rows + :return: a list of rows + """ + statement = f"SELECT * FROM {self.get_table_qualified_name(table)}" # skipcq: BAN-B608 + if row_limit > -1: + statement = statement + f" LIMIT {row_limit}" # skipcq: BAN-B608 + response = self.run_sql(statement) + return response.fetchall() # type: ignore + + def load_file_to_table( + self, + input_file: File, + output_table: Table, + normalize_config: dict | None = None, + if_exists: LoadExistStrategy = "replace", + chunk_size: int = DEFAULT_CHUNK_SIZE, + columns_names_capitalization: ColumnCapitalization = "original", + **kwargs, + ): + """ + Load content of multiple files in output_table. + Multiple files are sourced from the file path, which can also be path pattern. + + :param input_file: File path and conn_id for object stores + :param output_table: Table to create + :param if_exists: Overwrite file if exists + :param chunk_size: Specify the number of records in each batch to be written at a time + :param use_native_support: Use native support for data transfer if available on the destination + :param normalize_config: pandas json_normalize params config + :param native_support_kwargs: kwargs to be used by method involved in native support flow + :param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase + in the resulting dataframe + :param enable_native_fallback: Use enable_native_fallback=True to fall back to default transfer + """ + normalize_config = normalize_config or {} + + self.create_schema_and_table_if_needed( + file=input_file, + table=output_table, + columns_names_capitalization=columns_names_capitalization, + if_exists=if_exists, + normalize_config=normalize_config, + ) + self.load_file_to_table_using_pandas( + input_file=input_file, + output_table=output_table, + normalize_config=normalize_config, + if_exists="append", + chunk_size=chunk_size, + ) + + def load_file_to_table_using_pandas( + self, + input_file: File, + output_table: Table, + normalize_config: dict | None = None, + if_exists: LoadExistStrategy = "replace", + chunk_size: int = DEFAULT_CHUNK_SIZE, + ): + input_files = resolve_file_path_pattern( + file=input_file, + normalize_config=normalize_config, + filetype=input_file.type.name, + transfer_params=self.transfer_params, + transfer_mode=self.transfer_mode, + ) + + for file in input_files: + self.load_pandas_dataframe_to_table( + self.get_dataframe_from_file(file), + output_table, + chunk_size=chunk_size, + if_exists=if_exists, + ) + + def load_pandas_dataframe_to_table( + self, + source_dataframe: pd.DataFrame, + target_table: Table, + if_exists: LoadExistStrategy = "replace", + chunk_size: int = DEFAULT_CHUNK_SIZE, + ) -> None: + """ + Create a table with the dataframe's contents. + If the table already exists, append or replace the content, depending on the value of `if_exists`. + + :param source_dataframe: Local or remote filepath + :param target_table: Table in which the file will be loaded + :param if_exists: Strategy to be used in case the target table already exists. + :param chunk_size: Specify the number of rows in each batch to be written at a time. 
+ """ + self._assert_not_empty_df(source_dataframe) + + source_dataframe.to_sql( + self.get_table_qualified_name(target_table), + con=self.sqlalchemy_engine, + if_exists=if_exists, + chunksize=chunk_size, + method="multi", + index=False, + ) + + @staticmethod + def _assert_not_empty_df(df): + """Raise error if dataframe empty + + param df: A dataframe + """ + if df.empty: + raise ValueError("Can't load empty dataframe") + + @staticmethod + def get_dataframe_from_file(file: File): + """ + Get pandas dataframe file. We need export_to_dataframe() for Biqqery,Snowflake and Redshift except for Postgres. + For postgres we are overriding this method and using export_to_dataframe_via_byte_stream(). + export_to_dataframe_via_byte_stream copies files in a buffer and then use that buffer to ingest data. + With this approach we have significant performance boost for postgres. + + :param file: File path and conn_id for object stores + """ + + return file.export_to_dataframe() + + def check_schema_autodetection_is_supported( # skipcq: PYL-R0201 + self, source_file: File # skipcq: PYL-W0613 + ) -> bool: + """ + Checks if schema autodetection is handled natively by the database. Return False by default. + + :param source_file: File from which we need to transfer data + """ + return False + + def check_file_pattern_based_schema_autodetection_is_supported( # skipcq: PYL-R0201 + self, source_file: File # skipcq: PYL-W0613 + ) -> bool: + """ + Checks if schema autodetection is handled natively by the database for file + patterns and prefixes. Return False by default. + + :param source_file: File from which we need to transfer data + """ + return False + + def row_count(self, table: Table): + """ + Returns the number of rows in a table. + + :param table: table to count + :return: The number of rows in the table + """ + result = self.run_sql( + f"select count(*) from {self.get_table_qualified_name(table)}" # skipcq: BAN-B608 + ).scalar() + return result + + # --------------------------------------------------------- + # Schema Management + # --------------------------------------------------------- + + def create_schema_if_needed(self, schema: str | None) -> None: + """ + This function checks if the expected schema exists in the database. If the schema does not exist, + it will attempt to create it. + + :param schema: DB Schema - a namespace that contains named objects like (tables, functions, etc) + """ + # We check if the schema exists first because snowflake will fail on a create schema query even if it + # doesn't actually create a schema. + if schema and not self.schema_exists(schema): + statement = self._create_schema_statement.format(schema) + self.run_sql(statement) + + def schema_exists(self, schema: str) -> bool: + """ + Checks if a schema exists in the database + + :param schema: DB Schema - a namespace that contains named objects like (tables, functions, etc) + """ + raise NotImplementedError + + # --------------------------------------------------------- + # Extract methods + # --------------------------------------------------------- + + def export_table_to_pandas_dataframe(self) -> pd.DataFrame: + """ + Copy the content of a table to an in-memory Pandas dataframe. 
+ """ + + if not self.table_exists(self.dataset): + raise ValueError(f"The table {self.dataset.name} does not exist") + + sqla_table = self.get_sqla_table(self.dataset) + df = pd.read_sql(sql=sqla_table.select(), con=self.sqlalchemy_engine) + return PandasDataframe.from_pandas_df(df) diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/google/__init__.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/google/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/google/bigquery.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/google/bigquery.py new file mode 100644 index 000000000..17fde5ba2 --- /dev/null +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/google/bigquery.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from tempfile import NamedTemporaryFile + +import attr +import pandas as pd +from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook +from google.api_core.exceptions import ( + NotFound as GoogleNotFound, +) +from sqlalchemy import create_engine +from sqlalchemy.engine.base import Engine + +from universal_transfer_operator.constants import DEFAULT_CHUNK_SIZE, LoadExistStrategy +from universal_transfer_operator.data_providers.database.base import DatabaseDataProvider, FileStream +from universal_transfer_operator.datasets.file.base import File +from universal_transfer_operator.datasets.table import Metadata, Table +from universal_transfer_operator.settings import BIGQUERY_SCHEMA, BIGQUERY_SCHEMA_LOCATION +from universal_transfer_operator.universal_transfer_operator import TransferParameters + + +class BigqueryDataProvider(DatabaseDataProvider): + """BigqueryDataProvider represents all the DataProviders interactions with Bigquery Databases.""" + + DEFAULT_SCHEMA = BIGQUERY_SCHEMA + + illegal_column_name_chars: list[str] = ["."] + illegal_column_name_chars_replacement: list[str] = ["_"] + + _create_schema_statement: str = "CREATE SCHEMA IF NOT EXISTS {} OPTIONS (location='{}')" + + def __init__( + self, + dataset: Table, + transfer_mode, + transfer_params: TransferParameters = attr.field( + factory=TransferParameters, + converter=lambda val: TransferParameters(**val) if isinstance(val, dict) else val, + ), + ): + self.dataset = dataset + self.transfer_params = transfer_params + self.transfer_mode = transfer_mode + self.transfer_mapping = {} + self.LOAD_DATA_NATIVELY_FROM_SOURCE: dict = {} + super().__init__( + dataset=self.dataset, transfer_mode=self.transfer_mode, transfer_params=self.transfer_params + ) + + def __repr__(self): + return f'{self.__class__.__name__}(conn_id="{self.dataset.conn_id}")' + + @property + def sql_type(self) -> str: + return "bigquery" + + @property + def hook(self) -> BigQueryHook: + """Retrieve Airflow hook to interface with the BigQuery database.""" + return BigQueryHook( + gcp_conn_id=self.dataset.conn_id, use_legacy_sql=False, location=BIGQUERY_SCHEMA_LOCATION + ) + + @property + def sqlalchemy_engine(self) -> Engine: + """Return SQLAlchemy engine.""" + uri = self.hook.get_uri() + with self.hook.provide_gcp_credential_file_as_context(): + return create_engine(uri) + + @property + def default_metadata(self) -> Metadata: + """ + Fill in default metadata values for table objects addressing bigquery databases + """ + return Metadata( + schema=self.DEFAULT_SCHEMA, +
database=self.hook.project_id, + ) # type: ignore + + def read(self): + """Read the dataset and write to local reference location""" + + with NamedTemporaryFile(mode="w", suffix=".parquet", delete=False) as tmp_file: + df = self.export_table_to_pandas_dataframe() + df.to_parquet(tmp_file.name) + local_temp_file = FileStream( + remote_obj_buffer=tmp_file.file, + actual_filename=tmp_file.name, + actual_file=File(path=tmp_file.name), + ) + yield local_temp_file + + def write(self, source_ref: FileStream): + """ + Write the data from local reference location to the dataset + + :param source_ref: Stream of data to be loaded into snowflake table. + """ + return self.load_file_to_table( + input_file=source_ref.actual_file, + output_table=self.dataset, + ) + + # --------------------------------------------------------- + # Table metadata + # --------------------------------------------------------- + + def schema_exists(self, schema: str) -> bool: + """ + Checks if a dataset exists in the BigQuery + + :param schema: Bigquery namespace + """ + try: + self.hook.get_dataset(dataset_id=schema) + except GoogleNotFound: + # google.api_core.exceptions throws when a resource is not found + return False + return True + + def _get_schema_location(self, schema: str | None = None) -> str: + """ + Get region where the schema is created + + :param schema: Bigquery namespace + """ + if schema is None: + return "" + try: + dataset = self.hook.get_dataset(dataset_id=schema) + return str(dataset.location) + except GoogleNotFound: + # google.api_core.exceptions throws when a resource is not found + return "" + + def load_pandas_dataframe_to_table( + self, + source_dataframe: pd.DataFrame, + target_table: Table, + if_exists: LoadExistStrategy = "replace", + chunk_size: int = DEFAULT_CHUNK_SIZE, + ) -> None: + """ + Create a table with the dataframe's contents. + If the table already exists, append or replace the content, depending on the value of `if_exists`. + + :param source_dataframe: Local or remote filepath + :param target_table: Table in which the file will be loaded + :param if_exists: Strategy to be used in case the target table already exists. + :param chunk_size: Specify the number of rows in each batch to be written at a time. + """ + self._assert_not_empty_df(source_dataframe) + + try: + creds = self.hook._get_credentials() # skipcq PYL-W021 + except AttributeError: + # Details: https://github.com/astronomer/astro-sdk/issues/703 + creds = self.hook.get_credentials() + source_dataframe.to_gbq( + self.get_table_qualified_name(target_table), + if_exists=if_exists, + chunksize=chunk_size, + project_id=self.hook.project_id, + credentials=creds, + ) + + def create_schema_if_needed(self, schema: str | None) -> None: + """ + This function checks if the expected schema exists in the database. If the schema does not exist, + it will attempt to create it. + + :param schema: DB Schema - a namespace that contains named objects like (tables, functions, etc) + """ + # We check if the schema exists first because BigQuery will fail on a create schema query even if it + # doesn't actually create a schema. 
+ if schema and not self.schema_exists(schema): + table_schema = self.dataset.metadata.schema if self.dataset and self.dataset.metadata else None + table_location = self._get_schema_location(table_schema) + + location = table_location or BIGQUERY_SCHEMA_LOCATION + statement = self._create_schema_statement.format(schema, location) + self.run_sql(statement) + + def truncate_table(self, table): + """Truncate table""" + self.run_sql(f"TRUNCATE {self.get_table_qualified_name(table)}") + + @property + def openlineage_dataset_name(self) -> str: + """ + Returns the open lineage dataset name as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + Example: db_name.schema_name.table_name + """ + dataset = self.dataset.metadata.database or self.dataset.metadata.schema + return f"{self.hook.project_id}.{dataset}.{self.dataset.name}" + + @property + def openlineage_dataset_namespace(self) -> str: + """ + Returns the open lineage dataset namespace as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + Example: bigquery + """ + return self.sql_type + + @property + def openlineage_dataset_uri(self) -> str: + """ + Returns the open lineage dataset uri as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + return f"{self.openlineage_dataset_namespace}{self.openlineage_dataset_name}" diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/snowflake.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/snowflake.py new file mode 100644 index 000000000..47eb6ac1e --- /dev/null +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/snowflake.py @@ -0,0 +1,270 @@ +from __future__ import annotations + +from typing import Sequence + +import attr +import pandas as pd +from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook +from snowflake.connector import pandas_tools + +from universal_transfer_operator.constants import DEFAULT_CHUNK_SIZE, ColumnCapitalization, LoadExistStrategy +from universal_transfer_operator.data_providers.database.base import DatabaseDataProvider, FileStream +from universal_transfer_operator.datasets.file.base import File +from universal_transfer_operator.datasets.table import Metadata, Table +from universal_transfer_operator.settings import LOAD_TABLE_AUTODETECT_ROWS_COUNT, SNOWFLAKE_SCHEMA +from universal_transfer_operator.universal_transfer_operator import TransferParameters + + +class SnowflakeDataProvider(DatabaseDataProvider): + """SnowflakeDataProvider represents all the DataProviders interactions with Snowflake Databases.""" + + DEFAULT_SCHEMA = SNOWFLAKE_SCHEMA + + def __init__( + self, + dataset: Table, + transfer_mode, + transfer_params: TransferParameters = attr.field( + factory=TransferParameters, + converter=lambda val: TransferParameters(**val) if isinstance(val, dict) else val, + ), + ): + self.dataset = dataset + self.transfer_params = transfer_params + self.transfer_mode = transfer_mode + self.transfer_mapping = {} + self.LOAD_DATA_NATIVELY_FROM_SOURCE: dict = {} + super().__init__( + dataset=self.dataset, transfer_mode=self.transfer_mode, transfer_params=self.transfer_params + ) + + def __repr__(self): + return f'{self.__class__.__name__}(conn_id="{self.dataset.conn_id}")' + + @property + def sql_type(self) -> str: + return "snowflake" + + @property + def hook(self) -> SnowflakeHook: + """Retrieve Airflow hook to interface with the Snowflake database.""" + kwargs = 
{} + _hook = SnowflakeHook(snowflake_conn_id=self.dataset.conn_id) + if self.dataset and self.dataset.metadata: + if _hook.database is None and self.dataset.metadata.database: + kwargs.update({"database": self.dataset.metadata.database}) + if _hook.schema is None and self.dataset.metadata.schema: + kwargs.update({"schema": self.dataset.metadata.schema}) + return SnowflakeHook(snowflake_conn_id=self.dataset.conn_id, **kwargs) + + @property + def default_metadata(self) -> Metadata: + """ + Fill in default metadata values for table objects addressing snowflake databases + """ + connection = self.hook.get_conn() + return Metadata( # type: ignore + schema=connection.schema, + database=connection.database, + ) + + def read(self): + """ ""Read the dataset and write to local reference location""" + raise NotImplementedError + + def write(self, source_ref: FileStream): + """ + Write the data from local reference location to the dataset + + :param source_ref: Stream of data to be loaded into snowflake table. + """ + return self.load_file_to_table( + input_file=source_ref.actual_file, + output_table=self.dataset, + ) + + # --------------------------------------------------------- + # Table metadata + # --------------------------------------------------------- + @staticmethod + def get_table_qualified_name(table: Table) -> str: # skipcq: PYL-R0201 + """ + Return table qualified name. In Snowflake, it is the database, schema and table + + :param table: The table we want to retrieve the qualified name for. + """ + qualified_name_lists = [ + table.metadata.database, + table.metadata.schema, + table.name, + ] + qualified_name = ".".join(name for name in qualified_name_lists if name) + return qualified_name + + def schema_exists(self, schema: str) -> bool: + """ + Checks if a schema exists in the database + + :param schema: DB Schema - a namespace that contains named objects like (tables, functions, etc) + """ + + # Below code is added due to breaking change in apache-airflow-providers-snowflake==3.2.0, + # we need to pass handler param to get the rows. But in version apache-airflow-providers-snowflake==3.1.0 + # if we pass the handler provider raises an exception AttributeError 'sfid'. + try: + schemas = self.hook.run( + "SELECT SCHEMA_NAME from information_schema.schemata WHERE LOWER(SCHEMA_NAME) = %(schema_name)s;", + parameters={"schema_name": schema.lower()}, + handler=lambda cur: cur.fetchall(), + ) + except AttributeError: + schemas = self.hook.run( + "SELECT SCHEMA_NAME from information_schema.schemata WHERE LOWER(SCHEMA_NAME) = %(schema_name)s;", + parameters={"schema_name": schema.lower()}, + ) + try: + # Handle case for apache-airflow-providers-snowflake<4.0.1 + created_schemas = [x["SCHEMA_NAME"] for x in schemas] + except TypeError: + # Handle case for apache-airflow-providers-snowflake>=4.0.1 + created_schemas = [x[0] for x in schemas] + return len(created_schemas) == 1 + + def create_table_using_schema_autodetection( + self, + table: Table, + file: File | None = None, + dataframe: pd.DataFrame | None = None, + columns_names_capitalization: ColumnCapitalization = "original", + ) -> None: # skipcq PYL-W0613 + """ + Create a SQL table, automatically inferring the schema using the given file. + Overriding default behaviour and not using the `prep_table` since it doesn't allow the adding quotes. + + :param table: The table to be created. + :param file: File used to infer the new table columns. 
+ :param dataframe: Dataframe used to infer the new table columns if there is no file + :param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase + in the resulting dataframe + """ + if file is None: + if dataframe is None: + raise ValueError( + "File or Dataframe is required for creating table using schema autodetection" + ) + source_dataframe = dataframe + else: + source_dataframe = file.export_to_dataframe(nrows=LOAD_TABLE_AUTODETECT_ROWS_COUNT) + + # We are changing the case of table name to ease out on the requirements to add quotes in raw queries. + # ToDO - Currently, we cannot to append using load_file to a table name which is having name in lower case. + pandas_tools.write_pandas( + conn=self.hook.get_conn(), + df=source_dataframe, + table_name=table.name.upper(), + schema=table.metadata.schema, + database=table.metadata.database, + chunk_size=DEFAULT_CHUNK_SIZE, + quote_identifiers=self.use_quotes(source_dataframe), + auto_create_table=True, + ) + # We are truncating since we only expect table to be created with required schema. + # Since this method is used by both native and pandas path we cannot skip this step. + self.truncate_table(table) + + def load_pandas_dataframe_to_table( + self, + source_dataframe: pd.DataFrame, + target_table: Table, + if_exists: LoadExistStrategy = "replace", + chunk_size: int = DEFAULT_CHUNK_SIZE, + ) -> None: + """ + Create a table with the dataframe's contents. + If the table already exists, append or replace the content, depending on the value of `if_exists`. + + :param source_dataframe: Local or remote filepath + :param target_table: Table in which the file will be loaded + :param if_exists: Strategy to be used in case the target table already exists. + :param chunk_size: Specify the number of rows in each batch to be written at a time. + """ + self._assert_not_empty_df(source_dataframe) + + auto_create_table = False + if not self.table_exists(target_table): + auto_create_table = True + elif if_exists == "replace": + self.create_table(target_table, dataframe=source_dataframe) + + # We are changing the case of table name to ease out on the requirements to add quotes in raw queries. + # ToDO - Currently, we cannot to append using load_file to a table name which is having name in lower case. + pandas_tools.write_pandas( + conn=self.hook.get_conn(), + df=source_dataframe, + table_name=target_table.name.upper(), + schema=target_table.metadata.schema, + database=target_table.metadata.database, + chunk_size=chunk_size, + quote_identifiers=self.use_quotes(source_dataframe), + auto_create_table=auto_create_table, + ) + + def truncate_table(self, table): + """Truncate table""" + self.run_sql(f"TRUNCATE {self.get_table_qualified_name(table)}") + + @classmethod + def use_quotes(cls, cols: Sequence[str]) -> bool: + """ + With snowflake identifier we have two cases, + + 1. When Upper/Mixed case col names are used + We are required to preserver the text casing of the col names. By adding the quotes around identifier. + 2. When lower case col names are used + We can use them as is + + This is done to be in sync with Snowflake SQLAlchemy dialect. + https://docs.snowflake.com/en/user-guide/sqlalchemy.html#object-name-case-handling + + Snowflake stores all case-insensitive object names in uppercase text. In contrast, SQLAlchemy considers all + lowercase object names to be case-insensitive. Snowflake SQLAlchemy converts the object name case during + schema-level communication (i.e. during table and index reflection). 
If you use uppercase object names, + SQLAlchemy assumes they are case-sensitive and encloses the names with quotes. This behavior will cause + mismatches against data dictionary data received from Snowflake, so unless identifier names have been truly + created as case sensitive using quotes (e.g. "TestDb"), all lowercase names should be used on the SQLAlchemy + side. + + :param cols: list of columns + """ + return any(col for col in cols if not col.islower() and not col.isupper()) + + @property + def openlineage_dataset_name(self) -> str: + """ + Returns the open lineage dataset name as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + Example: db_name.schema_name.table_name + """ + conn = self.hook.get_connection(self.dataset.conn_id) + conn_extra = conn.extra_dejson + schema = conn_extra.get("schema") or conn.schema + db = conn_extra.get("database") + return f"{db}.{schema}.{self.dataset.name}" + + @property + def openlineage_dataset_namespace(self) -> str: + """ + Returns the open lineage dataset namespace as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + Example: snowflake://ACCOUNT + """ + account = self.hook.get_connection(self.dataset.conn_id).extra_dejson.get("account") + return f"{self.sql_type}://{account}" + + @property + def openlineage_dataset_uri(self) -> str: + """ + Returns the open lineage dataset uri as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + return f"{self.openlineage_dataset_namespace()}{self.openlineage_dataset_name()}" diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/sqlite.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/sqlite.py new file mode 100644 index 000000000..4470b496a --- /dev/null +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/sqlite.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import socket + +import attr +from airflow.providers.sqlite.hooks.sqlite import SqliteHook +from sqlalchemy import MetaData as SqlaMetaData, create_engine +from sqlalchemy.engine.base import Engine +from sqlalchemy.sql.schema import Table as SqlaTable + +from universal_transfer_operator.data_providers.database.base import DatabaseDataProvider, FileStream +from universal_transfer_operator.datasets.table import Metadata, Table +from universal_transfer_operator.universal_transfer_operator import TransferParameters + + +class SqliteDataProvider(DatabaseDataProvider): + """SqliteDataProvider represent all the DataProviders interactions with Sqlite Databases.""" + + def __init__( + self, + dataset: Table, + transfer_mode, + transfer_params: TransferParameters = attr.field( + factory=TransferParameters, + converter=lambda val: TransferParameters(**val) if isinstance(val, dict) else val, + ), + ): + self.dataset = dataset + self.transfer_params = transfer_params + self.transfer_mode = transfer_mode + self.transfer_mapping = {} + self.LOAD_DATA_NATIVELY_FROM_SOURCE: dict = {} + super().__init__( + dataset=self.dataset, transfer_mode=self.transfer_mode, transfer_params=self.transfer_params + ) + + def __repr__(self): + return f'{self.__class__.__name__}(conn_id="{self.dataset.conn_id})' + + @property + def sql_type(self) -> str: + return "sqlite" + + @property + def hook(self) -> SqliteHook: + """Retrieve Airflow hook to interface with the Sqlite database.""" + return SqliteHook(sqlite_conn_id=self.dataset.conn_id) + + @property + def 
sqlalchemy_engine(self) -> Engine: + """Return SQAlchemy engine.""" + # Airflow uses sqlite3 library and not SqlAlchemy for SqliteHook + # and it only uses the hostname directly. + airflow_conn = self.hook.get_connection(self.dataset.conn_id) + return create_engine(f"sqlite:///{airflow_conn.host}") + + @property + def default_metadata(self) -> Metadata: + """Since Sqlite does not use Metadata, we return an empty Metadata instances.""" + return Metadata() + + def read(self): + """ ""Read the dataset and write to local reference location""" + raise NotImplementedError + + def write(self, source_ref: FileStream): + """Write the data from local reference location to the dataset""" + return self.load_file_to_table( + input_file=source_ref.actual_file, + output_table=self.dataset, + ) + + # --------------------------------------------------------- + # Table metadata + # --------------------------------------------------------- + @staticmethod + def get_table_qualified_name(table: Table) -> str: + """ + Return the table qualified name. + + :param table: The table we want to retrieve the qualified name for. + """ + return str(table.name) + + def populate_table_metadata(self, table: Table) -> Table: + """ + Since SQLite does not have a concept of databases or schemas, we just return the table as is, + without any modifications. + """ + table.conn_id = table.conn_id or self.dataset.conn_id + return table + + def create_schema_if_needed(self, schema: str | None) -> None: + """ + Since SQLite does not have schemas, we do not need to set a schema here. + """ + + def schema_exists(self, schema: str) -> bool: # skipcq PYL-W0613,PYL-R0201 + """ + Check if a schema exists. We return false for sqlite since sqlite does not have schemas + """ + return False + + def get_sqla_table(self, table: Table) -> SqlaTable: + """ + Return SQLAlchemy table instance + + :param table: Astro Table to be converted to SQLAlchemy table instance + """ + return SqlaTable(table.name, SqlaMetaData(), autoload_with=self.sqlalchemy_engine) + + @property + def openlineage_dataset_name(self) -> str: + """ + Returns the open lineage dataset name as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + Example: /tmp/local.db.table_name + """ + conn = self.hook.get_connection(self.dataset.conn_id) + return f"{conn.host}.{self.dataset.name}" + + @property + def openlineage_dataset_namespace(self) -> str: + """ + Returns the open lineage dataset namespace as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + Example: file://127.0.0.1:22 + """ + conn = self.hook.get_connection(self.dataset.conn_id) + port = conn.port or 22 + return f"file://{socket.gethostbyname(socket.gethostname())}:{port}" + + @property + def openlineage_dataset_uri(self) -> str: + """ + Returns the open lineage dataset uri as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + return f"{self.openlineage_dataset_namespace()}{self.openlineage_dataset_name()}" diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/__init__.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/__init__.py index e69de29bb..84460553f 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/__init__.py +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/__init__.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from 
universal_transfer_operator.constants import FileType, TransferMode +from universal_transfer_operator.data_providers import create_dataprovider +from universal_transfer_operator.datasets.file.base import File +from universal_transfer_operator.utils import TransferParameters + + +def resolve_file_path_pattern( + file: File, + filetype: FileType | None = None, + normalize_config: dict | None = None, + transfer_params: TransferParameters = None, + transfer_mode: TransferMode = TransferMode.NONNATIVE, +) -> list[File]: + """get file objects by resolving path_pattern from local/object stores + path_pattern can be + 1. local location - glob pattern + 2. s3/gcs location - prefix + + :param file: File dataset object + :param filetype: constant to provide an explicit file type + :param normalize_config: parameters in dict format of pandas json_normalize() function + :param transfer_params: kwargs to be used by method involved in transfer flow. + :param transfer_mode: Use transfer_mode TransferMode; native, non-native or thirdparty. + """ + location = create_dataprovider( + dataset=file, + transfer_params=transfer_params, + transfer_mode=transfer_mode, + ) + files = [] + for path in location.paths: + if not path.endswith("/"): + file = File( + path=path, + conn_id=file.conn_id, + filetype=filetype, + normalize_config=normalize_config, + ) + files.append(file) + if len(files) == 0: + raise FileNotFoundError(f"File(s) not found for path/pattern '{file.path}'") + return files diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/aws/s3.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/aws/s3.py index 75f149ba1..46f4a6bef 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/aws/s3.py +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/aws/s3.py @@ -174,3 +174,11 @@ def openlineage_dataset_name(self) -> str: https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md """ raise NotImplementedError + + @property + def openlineage_dataset_uri(self) -> str: + """ + Returns the open lineage dataset uri as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/base.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/base.py index 95787195b..519768612 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/base.py +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/base.py @@ -28,6 +28,7 @@ class TempFile: class FileStream: remote_obj_buffer: io.IOBase actual_filename: Path + actual_file: File class BaseFilesystemProviders(DataProviders): @@ -90,7 +91,9 @@ def read_using_smart_open(self): files = self.paths for file in files: yield FileStream( - remote_obj_buffer=self._convert_remote_file_to_byte_stream(file), actual_filename=file + remote_obj_buffer=self._convert_remote_file_to_byte_stream(file), + actual_filename=file, + actual_file=self.dataset, ) def _convert_remote_file_to_byte_stream(self, file: str) -> io.IOBase: @@ -181,3 +184,11 @@ def openlineage_dataset_name(self) -> str: https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md """ raise NotImplementedError + + @property + def openlineage_dataset_uri(self) -> str: + """ + Returns the open 
lineage dataset uri as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/google/cloud/gcs.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/google/cloud/gcs.py index 781ff6821..709baa504 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/google/cloud/gcs.py +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/google/cloud/gcs.py @@ -59,9 +59,10 @@ def transport_params(self) -> dict: def paths(self) -> list[str]: """Resolve GS file paths with prefix""" url = urlparse(self.dataset.path) + prefix = url.path[1:] prefixes = self.hook.list( bucket_name=self.bucket_name, # type: ignore - prefix=self.prefix, + prefix=prefix, delimiter=self.delimiter, ) paths = [urlunparse((url.scheme, url.netloc, keys, "", "", "")) for keys in prefixes] @@ -179,3 +180,11 @@ def openlineage_dataset_name(self) -> str: https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md """ raise NotImplementedError + + @property + def openlineage_dataset_uri(self) -> str: + """ + Returns the open lineage dataset uri as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + raise NotImplementedError diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/local.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/local.py new file mode 100644 index 000000000..3cbf86280 --- /dev/null +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/local.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import glob +import os +import pathlib +from urllib.parse import urlparse + +from universal_transfer_operator.data_providers.filesystem.base import BaseFilesystemProviders + + +class LocalDataProvider(BaseFilesystemProviders): + """Handler Local file path operations""" + + @property + def paths(self) -> list[str]: + """Resolve local filepath""" + url = urlparse(self.dataset.path) + path_object = pathlib.Path(url.path) + if path_object.is_dir(): + paths = [str(filepath) for filepath in path_object.rglob("*")] + else: + paths = glob.glob(url.path) + return paths + + def validate_conn(self): + """Override as conn_id is not always required for local location.""" + + @property + def size(self) -> int: + """Return the size in bytes of the given file. 
+ + :return: File size in bytes + """ + path = pathlib.Path(self.path) + return os.path.getsize(path) + + @property + def openlineage_dataset_namespace(self) -> str: + """ + Returns the open lineage dataset namespace as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + return os.path.basename(self.path) + + @property + def openlineage_dataset_name(self) -> str: + """ + Returns the open lineage dataset name as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + return urlparse(self.path).path diff --git a/universal_transfer_operator/src/universal_transfer_operator/datasets/file/base.py b/universal_transfer_operator/src/universal_transfer_operator/datasets/file/base.py index 2e8ee0bc9..a3e850ef7 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/datasets/file/base.py +++ b/universal_transfer_operator/src/universal_transfer_operator/datasets/file/base.py @@ -32,6 +32,12 @@ class File(Dataset): uri: str = field(init=False) extra: dict = field(init=True, factory=dict) + @property + def location(self): + from universal_transfer_operator.data_providers import create_dataprovider + + return create_dataprovider(dataset=self) + @property def size(self) -> int: """ diff --git a/universal_transfer_operator/src/universal_transfer_operator/datasets/table.py b/universal_transfer_operator/src/universal_transfer_operator/datasets/table.py index 68aedb807..2c19d87a1 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/datasets/table.py +++ b/universal_transfer_operator/src/universal_transfer_operator/datasets/table.py @@ -1,12 +1,17 @@ from __future__ import annotations -from urllib.parse import urlparse +import random +import string +from typing import Any from attr import define, field, fields_dict -from sqlalchemy import Column +from sqlalchemy import Column, MetaData from universal_transfer_operator.datasets.base import Dataset +MAX_TABLE_NAME_LENGTH = 62 +TEMP_PREFIX = "_tmp" + @define class Metadata: @@ -50,42 +55,110 @@ class Table(Dataset): uri: str = field(init=False) extra: dict = field(init=True, factory=dict) - @property - def sql_type(self): - raise NotImplementedError - def exists(self): """Check if the table exists or not""" raise NotImplementedError - def __str__(self) -> str: - return self.path + def _create_unique_table_name(self, prefix: str = "") -> str: + """ + If a table is instantiated without a name, create a unique table for it. + This new name should be compatible with all supported databases. + """ + schema_length = len((self.metadata and self.metadata.schema) or "") + 1 + prefix_length = len(prefix) + + unique_id = random.choice(string.ascii_lowercase) + "".join( + random.choice(string.ascii_lowercase + string.digits) + for _ in range(MAX_TABLE_NAME_LENGTH - schema_length - prefix_length) + ) + if prefix: + unique_id = f"{prefix}{unique_id}" - def __hash__(self) -> int: - return hash((self.path, self.conn_id)) + return unique_id - def dataset_scheme(self): + def create_similar_table(self) -> Table: """ - Return the scheme based on path + Create a new table with a unique name but with the same metadata. 
""" - parsed = urlparse(self.path) - return parsed.scheme + return Table( # type: ignore + name=self._create_unique_table_name(), + conn_id=self.conn_id, + metadata=self.metadata, + ) - def dataset_namespace(self): + @property + def sqlalchemy_metadata(self) -> MetaData: + """Return the Sqlalchemy metadata for the given table.""" + if self.metadata and self.metadata.schema: + alchemy_metadata = MetaData(schema=self.metadata.schema) + else: + alchemy_metadata = MetaData() + return alchemy_metadata + + @property + def row_count(self) -> Any: """ - The namespace of a dataset can be combined to form a URI (scheme:[//authority]path) + Return the row count of table. + """ + from universal_transfer_operator.data_providers import create_dataprovider + + database_provider = create_dataprovider(dataset=self) + return database_provider.row_count(self) - Namespace = scheme:[//authority] (the dataset) + @property + def sql_type(self) -> Any: + from universal_transfer_operator.data_providers import create_dataprovider + + if self.conn_id: + return create_dataprovider(dataset=self).sql_type + + def to_json(self): + return { + "class": "Table", + "name": self.name, + "metadata": { + "schema": self.metadata.schema, + "database": self.metadata.database, + }, + "temp": self.temp, + "conn_id": self.conn_id, + } + + @classmethod + def from_json(cls, obj: dict): + return Table( + name=obj["name"], + metadata=Metadata(**obj["metadata"]), + temp=obj["temp"], + conn_id=obj["conn_id"], + ) + + def openlineage_dataset_name(self) -> str: """ - parsed = urlparse(self.path) - namespace = f"{self.dataset_scheme()}://{parsed.netloc}" - return namespace + Returns the open lineage dataset name as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + from universal_transfer_operator.data_providers import create_dataprovider + + database_provider = create_dataprovider(dataset=self) + return database_provider.openlineage_dataset_name(table=self) - def dataset_name(self): + def openlineage_dataset_namespace(self) -> str: """ - The name of a dataset can be combined to form a URI (scheme:[//authority]path) + Returns the open lineage dataset namespace as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + from universal_transfer_operator.data_providers import create_dataprovider + + database_provider = create_dataprovider(dataset=self) + return database_provider.openlineage_dataset_namespace() - Name = path (the datasets) + def openlineage_dataset_uri(self) -> str: """ - parsed = urlparse(self.path) - return parsed.path if self.path else self.name + Returns the open lineage dataset uri as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + from universal_transfer_operator.data_providers import create_dataprovider + + database_provider = create_dataprovider(dataset=self) + return f"{database_provider.openlineage_dataset_uri(table=self)}" diff --git a/universal_transfer_operator/src/universal_transfer_operator/settings.py b/universal_transfer_operator/src/universal_transfer_operator/settings.py index 280d24138..57a4856a1 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/settings.py +++ b/universal_transfer_operator/src/universal_transfer_operator/settings.py @@ -4,6 +4,8 @@ from airflow.version import version as airflow_version from packaging.version import Version +from universal_transfer_operator.constants import DEFAULT_SCHEMA + # Section name for universal transfer operator configs in airflow.cfg SECTION_KEY = 
"universal_transfer_operator" @@ -23,3 +25,20 @@ # We only need PandasDataframe and other custom serialization and deserialization # if Airflow >= 2.5 and Pickling is not enabled and neither Custom XCom backend is used NEED_CUSTOM_SERIALIZATION = AIRFLOW_25_PLUS and IS_BASE_XCOM_BACKEND and not ENABLE_XCOM_PICKLING + +# Bigquery list of all the valid locations: https://cloud.google.com/bigquery/docs/locations +DEFAULT_BIGQUERY_SCHEMA_LOCATION = "us" +SCHEMA = conf.get(SECTION_KEY, "sql_schema", fallback=DEFAULT_SCHEMA) +POSTGRES_SCHEMA = conf.get(SECTION_KEY, "postgres_default_schema", fallback=SCHEMA) +BIGQUERY_SCHEMA = conf.get(SECTION_KEY, "bigquery_default_schema", fallback=SCHEMA) +SNOWFLAKE_SCHEMA = conf.get(SECTION_KEY, "snowflake_default_schema", fallback=SCHEMA) +REDSHIFT_SCHEMA = conf.get(SECTION_KEY, "redshift_default_schema", fallback=SCHEMA) +MSSQL_SCHEMA = conf.get(SECTION_KEY, "mssql_default_schema", fallback=SCHEMA) + +BIGQUERY_SCHEMA_LOCATION = conf.get( + SECTION_KEY, "bigquery_dataset_location", fallback=DEFAULT_BIGQUERY_SCHEMA_LOCATION +) + +LOAD_TABLE_AUTODETECT_ROWS_COUNT = conf.getint( + section=SECTION_KEY, key="load_table_autodetect_rows_count", fallback=1000 +) diff --git a/universal_transfer_operator/test/test_data_provider/test_data_provider.py b/universal_transfer_operator/test/test_data_provider/test_data_provider.py index e40e0f31a..e013460cb 100644 --- a/universal_transfer_operator/test/test_data_provider/test_data_provider.py +++ b/universal_transfer_operator/test/test_data_provider/test_data_provider.py @@ -3,6 +3,7 @@ from universal_transfer_operator.data_providers import create_dataprovider from universal_transfer_operator.data_providers.filesystem.aws.s3 import S3DataProvider from universal_transfer_operator.data_providers.filesystem.google.cloud.gcs import GCSDataProvider +from universal_transfer_operator.data_providers.filesystem.local import LocalDataProvider from universal_transfer_operator.datasets.file.base import File @@ -11,6 +12,7 @@ [ {"dataset": File("s3://astro-sdk-test/uto/", conn_id="aws_default"), "expected": S3DataProvider}, {"dataset": File("gs://uto-test/uto/", conn_id="google_cloud_default"), "expected": GCSDataProvider}, + {"dataset": File("/tmp/test", conn_id=""), "expected": LocalDataProvider}, ], ids=lambda d: d["dataset"].conn_id, ) diff --git a/universal_transfer_operator/test/test_filesystem/__init__.py b/universal_transfer_operator/test/test_filesystem/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/universal_transfer_operator/test/test_filesystem/test_local.py b/universal_transfer_operator/test/test_filesystem/test_local.py new file mode 100644 index 000000000..8021a5e2f --- /dev/null +++ b/universal_transfer_operator/test/test_filesystem/test_local.py @@ -0,0 +1,40 @@ +import os +import pathlib +import shutil +import uuid + +import pytest + +from universal_transfer_operator.data_providers.filesystem.local import LocalDataProvider +from universal_transfer_operator.datasets.file.base import File + +CWD = pathlib.Path(__file__).parent +DATA_DIR = str(CWD) + "/../../data/" + +LOCAL_FILEPATH = f"{CWD}/../../data/homes2.csv" +LOCAL_DIR = f"/tmp/{uuid.uuid4()}/" +LOCAL_DIR_FILE_1 = str(pathlib.Path(LOCAL_DIR, "file_1.txt")) +LOCAL_DIR_FILE_2 = str(pathlib.Path(LOCAL_DIR, "file_2.txt")) + + +@pytest.fixture() +def local_dir(): + """create temp dir""" + os.mkdir(LOCAL_DIR) + open(LOCAL_DIR_FILE_1, "a").close() + open(LOCAL_DIR_FILE_2, "a").close() + yield + shutil.rmtree(LOCAL_DIR) + + +def test_size(): + 
"""Test get_size() of for local file.""" + dataset = File(path=LOCAL_DIR_FILE_1) + assert LocalDataProvider(dataset).size == 65 + + +def test_get_paths_with_local_dir(local_dir): # skipcq: PYL-W0612 + """with local filepath""" + dataset = File(path=LOCAL_DIR_FILE_1) + location = LocalDataProvider(dataset) + assert sorted(location.paths) == [LOCAL_DIR_FILE_1]