diff --git a/src/astro/databases/snowflake.py b/src/astro/databases/snowflake.py
index 5e53960ef..d6c2fe0e8 100644
--- a/src/astro/databases/snowflake.py
+++ b/src/astro/databases/snowflake.py
@@ -1,17 +1,135 @@
 """Snowflake database implementation."""
-from typing import Dict, List, Tuple
+import logging
+import random
+import string
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Tuple
 
 import pandas as pd
 from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
 from pandas.io.sql import SQLDatabase
 from snowflake.connector import pandas_tools
-
-from astro.constants import DEFAULT_CHUNK_SIZE, LoadExistStrategy, MergeConflictStrategy
+from snowflake.connector.errors import ProgrammingError
+
+from astro.constants import (
+    DEFAULT_CHUNK_SIZE,
+    FileLocation,
+    FileType,
+    LoadExistStrategy,
+    MergeConflictStrategy,
+)
 from astro.databases.base import BaseDatabase
+from astro.files import File
 from astro.sql.table import Metadata, Table
 
 DEFAULT_CONN_ID = SnowflakeHook.default_conn_name
 
+ASTRO_SDK_TO_SNOWFLAKE_FILE_FORMAT_MAP = {
+    FileType.CSV: "CSV",
+    FileType.NDJSON: "JSON",
+    FileType.PARQUET: "PARQUET",
+}
+
+COPY_OPTIONS = {
+    FileType.CSV: "ON_ERROR=CONTINUE",
+    FileType.NDJSON: "MATCH_BY_COLUMN_NAME=CASE_INSENSITIVE",
+    FileType.PARQUET: "MATCH_BY_COLUMN_NAME=CASE_INSENSITIVE",
+}
+
+
+@dataclass
+class SnowflakeStage:
+    """
+    Dataclass which abstracts properties of a Snowflake Stage.
+
+    Snowflake stages are used for loading data from files into tables and for
+    unloading data from tables into files.
+
+    Example:
+
+    .. code-block:: python
+
+        snowflake_stage = SnowflakeStage(
+            name="stage_name",
+            url="gcs://bucket/prefix",
+            metadata=Metadata(database="SNOWFLAKE_DATABASE", schema="SNOWFLAKE_SCHEMA"),
+        )
+
+    .. seealso::
+        `Snowflake official documentation on stage creation
+        <https://docs.snowflake.com/en/sql-reference/sql/create-stage.html>`_
+    """
+
+    name: str = ""
+    _name: str = field(init=False, repr=False, default="")
+    url: str = ""
+    metadata: Metadata = field(default_factory=Metadata)
+
+    @staticmethod
+    def _create_unique_name() -> str:
+        """
+        Generate a valid Snowflake stage name.
+
+        :return: unique stage name
+        """
+        return (
+            "stage_"
+            + random.choice(string.ascii_lowercase)
+            + "".join(
+                random.choice(string.ascii_lowercase + string.digits) for _ in range(7)
+            )
+        )
+
+    def set_url_from_file(self, file: File) -> None:
+        """
+        Given a file to be loaded to or unloaded from Snowflake, identify its folder
+        and set it as self.url.
+
+        It is also responsible for adjusting any path-specific requirements for Snowflake.
+
+        :param file: File to be loaded to or unloaded from Snowflake
+        """
+        # the stage URL needs to be the folder where the files are
+        # https://docs.snowflake.com/en/sql-reference/sql/create-stage.html#external-stage-parameters-externalstageparams
+        url = file.path[: file.path.rfind("/") + 1]
+        self.url = url.replace("gs://", "gcs://")
+
+    @property  # type: ignore
+    def name(self) -> str:
+        """
+        Return either the user-defined name or an auto-generated one.
+
+        :return: stage name
+        :sphinx-autoapi-skip:
+        """
+        if not self._name:
+            self._name = self._create_unique_name()
+        return self._name
+
+    @name.setter
+    def name(self, value: str) -> None:
+        """
+        Set the stage name.
+
+        :param value: Stage name.
+        """
+        if not isinstance(value, property) and value != self._name:
+            self._name = value
+
+    @property
+    def qualified_name(self) -> str:
+        """
+        Return the stage qualified name. In Snowflake, it is made of the database,
+        schema and stage name.
+
+        :return: Snowflake stage qualified name (e.g. database.schema.stage_name)
+        """
+        qualified_name_lists = [
+            self.metadata.database,
+            self.metadata.schema,
+            self.name,
+        ]
+        qualified_name = ".".join(name for name in qualified_name_lists if name)
+        return qualified_name
+
 
 class SnowflakeDatabase(BaseDatabase):
     """
@@ -20,6 +138,7 @@ class SnowflakeDatabase(BaseDatabase):
     """
 
     def __init__(self, conn_id: str = DEFAULT_CONN_ID):
+        self.storage_integration: Optional[str] = None
         super().__init__(conn_id)
 
     @property
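A minimal sketch of how the `SnowflakeStage` dataclass added above behaves (bucket, prefix and metadata values are illustrative):

.. code-block:: python

    from astro.databases.snowflake import SnowflakeStage
    from astro.files import File
    from astro.sql.table import Metadata

    stage = SnowflakeStage(metadata=Metadata(database="MY_DB", schema="MY_SCHEMA"))

    # The name is lazily auto-generated on first access, e.g. "stage_k3f9a2bq"
    print(stage.name)

    # set_url_from_file() keeps only the folder of the file and rewrites gs:// to the
    # gcs:// scheme that Snowflake expects for external GCS stages
    stage.set_url_from_file(File("gs://my-bucket/some/prefix/data.csv"))
    assert stage.url == "gcs://my-bucket/some/prefix/"

    # qualified_name joins database, schema and stage name
    print(stage.qualified_name)  # e.g. "MY_DB.MY_SCHEMA.stage_k3f9a2bq"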
@@ -57,6 +176,121 @@ def get_table_qualified_name(table: Table) -> str:  # skipcq: PYL-R0201
         qualified_name = ".".join(name for name in qualified_name_lists if name)
         return qualified_name
 
+    # ---------------------------------------------------------
+    # Snowflake stage methods
+    # ---------------------------------------------------------
+
+    @staticmethod
+    def _create_stage_auth_sub_statement(
+        file: File, storage_integration: Optional[str] = None
+    ) -> str:
+        """
+        Create the authentication-related line for the Snowflake CREATE STAGE statement.
+        Raise an exception if the required credentials are not defined.
+
+        :param file: File to be copied from/to using the stage
+        :param storage_integration: Previously created Snowflake storage integration
+        :return: String containing the line to be used for authentication on the remote storage
+        """
+
+        if storage_integration is not None:
+            auth = f"storage_integration = {storage_integration};"
+        else:
+            if file.location.location_type == FileLocation.GS:
+                raise ValueError(
+                    "In order to create a stage for GCS, `storage_integration` is required."
+                )
+            elif file.location.location_type == FileLocation.S3:
+                aws = file.location.hook.get_credentials()
+                if aws.access_key and aws.secret_key:
+                    auth = f"credentials=(aws_key_id='{aws.access_key}' aws_secret_key='{aws.secret_key}');"
+                else:
+                    raise ValueError(
+                        "In order to create a stage for S3, one of the following is required: "
+                        "* `storage_integration`"
+                        "* AWS_KEY_ID and SECRET_KEY_ID"
+                    )
+        return auth
+
+    def create_stage(
+        self,
+        file: File,
+        storage_integration: Optional[str] = None,
+        metadata: Optional[Metadata] = None,
+    ) -> SnowflakeStage:
+        """
+        Creates a new named external stage to use for loading data from files into Snowflake
+        tables and unloading data from tables into files.
+
+        At the moment, the following ways of authenticating to the backend are supported:
+        * Google Cloud Storage (GCS): using a previously created storage_integration
+        * Amazon S3: one of the following:
+        (i) using storage_integration or
+        (ii) retrieving the AWS_KEY_ID and AWS_SECRET_KEY from the Airflow file connection
+
+        :param file: File to be copied from/to using the stage
+        :param storage_integration: Previously created Snowflake storage integration
+        :param metadata: Contains Snowflake database and schema information
+        :return: Stage created
+
+        .. seealso::
+            `Snowflake official documentation on stage creation
+            <https://docs.snowflake.com/en/sql-reference/sql/create-stage.html>`_
+        """
+        auth = self._create_stage_auth_sub_statement(
+            file=file, storage_integration=storage_integration
+        )
+
+        metadata = metadata or self.default_metadata
+        stage = SnowflakeStage(metadata=metadata)
+        stage.set_url_from_file(file)
+
+        fileformat = ASTRO_SDK_TO_SNOWFLAKE_FILE_FORMAT_MAP[file.type.name]
+        copy_options = COPY_OPTIONS[file.type.name]
+
+        sql_statement = "".join(
+            [
+                f"CREATE OR REPLACE STAGE {stage.qualified_name} URL='{stage.url}' ",
+                f"FILE_FORMAT=(TYPE={fileformat}, TRIM_SPACE=TRUE) ",
+                f"COPY_OPTIONS=({copy_options}) ",
+                auth,
+            ]
+        )
+
+        self.run_sql(sql_statement)
+
+        return stage
+
+    def stage_exists(self, stage: SnowflakeStage) -> bool:
+        """
+        Checks if a Snowflake stage exists.
+
+        :param stage: SnowflakeStage instance
+        :return: True/False
+        """
+        sql_statement = f"DESCRIBE STAGE {stage.qualified_name}"
+        try:
+            self.hook.run(sql_statement)
+        except ProgrammingError:
+            logging.error(
+                "Stage '%s' does not exist or not authorized.", stage.qualified_name
+            )
+            return False
+        return True
+
+    def drop_stage(self, stage: SnowflakeStage) -> None:
+        """
+        Runs the Snowflake query to drop the stage if it exists.
+
+        :param stage: Stage to be dropped
+        """
+        sql_statement = f"DROP STAGE IF EXISTS {stage.qualified_name};"
+        self.hook.run(sql_statement, autocommit=True)
+
+    # ---------------------------------------------------------
+    # Table load methods
+    # ---------------------------------------------------------
+
     def load_pandas_dataframe_to_table(
         self,
         source_dataframe: pd.DataFrame,
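Taken together, `create_stage`, `stage_exists` and `drop_stage` form the stage lifecycle used by the load path. A minimal sketch of how they compose, assuming a configured Snowflake connection and a previously created `gcs_integration` storage integration (connection id, bucket and integration name are illustrative); the statement in the comment mirrors the string assembled in `create_stage`:

.. code-block:: python

    from astro.databases.snowflake import SnowflakeDatabase
    from astro.files import File

    database = SnowflakeDatabase(conn_id="snowflake_conn")
    file = File("gs://my-bucket/exports/data.csv")

    # Roughly executes:
    #   CREATE OR REPLACE STAGE <db>.<schema>.stage_xxxxxxxx URL='gcs://my-bucket/exports/'
    #   FILE_FORMAT=(TYPE=CSV, TRIM_SPACE=TRUE) COPY_OPTIONS=(ON_ERROR=CONTINUE)
    #   storage_integration = gcs_integration;
    stage = database.create_stage(file=file, storage_integration="gcs_integration")

    assert database.stage_exists(stage)
    database.drop_stage(stage)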
@@ -125,12 +359,15 @@ def get_sqlalchemy_template_table_identifier_and_parameter(
         Since the table value is templated, there is a safety concern (e.g. SQL injection).
         We recommend looking into the documentation of the database and seeing what are the best practices.
-        This is the Snowflake documentation:
-        https://docs.snowflake.com/en/sql-reference/identifier-literal.html
+
         :param table: The table object we want to generate a safe table identifier for
         :param jinja_table_identifier: The name used within the Jinja template to represent this table
         :return: value to replace the table identifier in the query and the value that should be used to replace it
+
+        .. seealso::
+            `Snowflake official documentation on literals
+            <https://docs.snowflake.com/en/sql-reference/identifier-literal.html>`_
         """
         return (
             f"IDENTIFIER(:{jinja_table_identifier})",
@@ -273,9 +510,15 @@ def wrap_identifier(inp: str) -> str:
 
 def is_valid_snow_identifier(name: str) -> bool:
     """
-    Because Snowflake does not allow using `Identifier` for inserts or updates, we need to make reasonable attempts to
-    ensure that no one can perform a SQL injection using this method. The following method ensures that a string
-    follows the expected identifier syntax https://docs.snowflake.com/en/sql-reference/identifiers-syntax.html
+    Because Snowflake does not allow using `Identifier` for inserts or updates,
+    we need to make reasonable attempts to ensure that no one can perform a SQL
+    injection using this method.
+    The following method ensures that a string follows the expected identifier syntax.
+
+    .. seealso::
+        `Snowflake official documentation on identifier syntax
+        <https://docs.snowflake.com/en/sql-reference/identifiers-syntax.html>`_
+
     """
     if not 1 <= len(name) <= 255:
         return False
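Both docstring updates above point at the same safety mechanism: the table identifier is never interpolated into the SQL text. A short illustration of the intent (query and parameter names are hypothetical):

.. code-block:: python

    # Templated query authored by the user
    query = "SELECT * FROM {{input_table}}"

    # get_sqlalchemy_template_table_identifier_and_parameter() turns the placeholder
    # into a bind parameter wrapped in Snowflake's IDENTIFIER(), so the driver receives
    # the qualified table name as a value rather than as SQL text:
    rendered = "SELECT * FROM IDENTIFIER(:input_table)"
    parameters = {"input_table": "MY_DB.MY_SCHEMA.MY_TABLE"}

    # is_valid_snow_identifier() guards the code paths where identifiers still have to
    # be spliced into INSERT/UPDATE statements.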
diff --git a/src/astro/files/locations/amazon/s3.py b/src/astro/files/locations/amazon/s3.py
index 87292b477..9b1c5b8c7 100644
--- a/src/astro/files/locations/amazon/s3.py
+++ b/src/astro/files/locations/amazon/s3.py
@@ -13,6 +13,10 @@ class S3Location(BaseFileLocation):
 
     location_type = FileLocation.S3
 
+    @property
+    def hook(self) -> S3Hook:
+        return S3Hook(aws_conn_id=self.conn_id) if self.conn_id else S3Hook()
+
     @staticmethod
     def _parse_s3_env_var() -> Tuple[str, str]:
         """Return S3 ID/KEY pair from environment vars"""
@@ -23,8 +27,7 @@ def transport_params(self) -> Dict:
         """Structure s3fs credentials from Airflow connection.
         s3fs enables pandas to write to s3
         """
-        hook = S3Hook(aws_conn_id=self.conn_id) if self.conn_id else S3Hook()
-        session = hook.get_session()
+        session = self.hook.get_session()
         return {"client": session.client("s3")}
 
     @property
@@ -33,8 +36,7 @@ def paths(self) -> List[str]:
         url = urlparse(self.path)
         bucket_name = url.netloc
         prefix = url.path[1:]
-        hook = S3Hook(aws_conn_id=self.conn_id) if self.conn_id else S3Hook()
-        prefixes = hook.list_keys(bucket_name=bucket_name, prefix=prefix)
+        prefixes = self.hook.list_keys(bucket_name=bucket_name, prefix=prefix)
         paths = [
             urlunparse((url.scheme, url.netloc, keys, "", "", "")) for keys in prefixes
         ]
diff --git a/src/astro/files/locations/base.py b/src/astro/files/locations/base.py
index cd3202309..3099363be 100644
--- a/src/astro/files/locations/base.py
+++ b/src/astro/files/locations/base.py
@@ -25,6 +25,10 @@ def __init__(self, path: str, conn_id: Optional[str] = None):
         self.path = path
         self.conn_id = conn_id
 
+    @property
+    def hook(self):
+        raise NotImplementedError
+
     @property
     @abstractmethod
     def location_type(self):
diff --git a/src/astro/files/locations/google/gcs.py b/src/astro/files/locations/google/gcs.py
index 8b2a029be..810d7c633 100644
--- a/src/astro/files/locations/google/gcs.py
+++ b/src/astro/files/locations/google/gcs.py
@@ -12,11 +12,14 @@ class GCSLocation(BaseFileLocation):
 
     location_type = FileLocation.GS
 
+    @property
+    def hook(self) -> GCSHook:
+        return GCSHook(gcp_conn_id=self.conn_id) if self.conn_id else GCSHook()
+
     @property
     def transport_params(self) -> Dict:
         """get GCS credentials for storage"""
-        hook = GCSHook(gcp_conn_id=self.conn_id) if self.conn_id else GCSHook()
-        client = hook.get_conn()
+        client = self.hook.get_conn()
         return {"client": client}
 
     @property
@@ -25,8 +28,7 @@ def paths(self) -> List[str]:
         url = urlparse(self.path)
         bucket_name = url.netloc
         prefix = url.path[1:]
-        hook = GCSHook(gcp_conn_id=self.conn_id) if self.conn_id else GCSHook()
-        prefixes = hook.list(bucket_name=bucket_name, prefix=prefix)
+        prefixes = self.hook.list(bucket_name=bucket_name, prefix=prefix)
         paths = [
             urlunparse((url.scheme, url.netloc, keys, "", "", "")) for keys in prefixes
         ]
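The `hook` property added to the location classes above is what `SnowflakeDatabase._create_stage_auth_sub_statement` relies on when it calls `file.location.hook.get_credentials()`. A small sketch of the new access path (bucket and connection id are illustrative):

.. code-block:: python

    from astro.files import File

    file = File("s3://my-bucket/exports/data.csv", conn_id="aws_conn")

    # S3Location.hook builds an S3Hook from the file's Airflow connection,
    # falling back to the default AWS connection when conn_id is not set.
    credentials = file.location.hook.get_credentials()
    print(credentials.access_key, credentials.secret_key)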
diff --git a/src/astro/settings.py b/src/astro/settings.py
index ba86ffe84..5e93f1e6f 100644
--- a/src/astro/settings.py
+++ b/src/astro/settings.py
@@ -2,3 +2,16 @@
 
 DEFAULT_SCHEMA = "tmp_astro"
 SCHEMA = conf.get("astro_sdk", "sql_schema", fallback=DEFAULT_SCHEMA)
+
+# We are not defining a default value on purpose. S3 Snowflake stages can also
+# be created without a storage integration, by using the Airflow AWS connection
+# properties.
+SNOWFLAKE_STORAGE_INTEGRATION_AMAZON = conf.get(
+    section="astro_sdk", key="snowflake_storage_integration_amazon", fallback=None
+)
+
+SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE = conf.get(
+    section="astro_sdk",
+    key="snowflake_storage_integration_google",
+    fallback="gcs_int_python_sdk",
+)
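Both settings are read from the `astro_sdk` section of the Airflow configuration. A sketch of how they could be supplied through Airflow's standard `AIRFLOW__<SECTION>__<KEY>` environment variables (integration names are illustrative):

.. code-block:: python

    # Equivalent to setting the keys under [astro_sdk] in airflow.cfg, e.g.:
    #   AIRFLOW__ASTRO_SDK__SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE=my_gcs_integration
    #   AIRFLOW__ASTRO_SDK__SNOWFLAKE_STORAGE_INTEGRATION_AMAZON=my_s3_integration
    from astro.settings import (
        SNOWFLAKE_STORAGE_INTEGRATION_AMAZON,
        SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE,
    )

    # Without any configuration, the Google setting falls back to "gcs_int_python_sdk"
    # and the Amazon setting is None, in which case S3 stages use the credentials of
    # the Airflow AWS connection instead.
    print(SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE, SNOWFLAKE_STORAGE_INTEGRATION_AMAZON)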
diff --git a/tests/databases/test_snowflake.py b/tests/databases/test_snowflake.py
index 242fe9616..1e7197da1 100644
--- a/tests/databases/test_snowflake.py
+++ b/tests/databases/test_snowflake.py
@@ -1,18 +1,23 @@
 """Tests specific to the Sqlite Database implementation."""
 import os
 import pathlib
+from unittest.mock import patch
 
 import pandas as pd
 import pytest
 import sqlalchemy
 from sqlalchemy.exc import ProgrammingError
 
-from astro.constants import Database
+from astro.constants import Database, FileLocation, FileType
 from astro.databases import create_database
-from astro.databases.snowflake import SnowflakeDatabase
+from astro.databases.snowflake import SnowflakeDatabase, SnowflakeStage
 from astro.exceptions import NonExistentTableException
 from astro.files import File
-from astro.settings import SCHEMA
+from astro.settings import (
+    SCHEMA,
+    SNOWFLAKE_STORAGE_INTEGRATION_AMAZON,
+    SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE,
+)
 from astro.sql.table import Metadata, Table
 from astro.utils.load import copy_remote_file_to_local
 from tests.sql.operators import utils as test_utils
@@ -338,3 +343,85 @@ def test_create_table_from_select_statement(database_table_fixture):
     expected = pd.DataFrame([{"id": 1, "name": "First"}])
     test_utils.assert_dataframes_are_equal(df, expected)
     database.drop_table(target_table)
+
+
+def test_stage_set_name_after():
+    stage = SnowflakeStage()
+    stage.name = "abc"
+    assert stage.name == "abc"
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "remote_files_fixture",
+    [
+        {"provider": "google", "filetype": FileType.CSV},
+    ],
+    indirect=True,
+    ids=["google_csv"],
+)
+def test_stage_exists_false(remote_files_fixture):
+    file_fixture = File(remote_files_fixture[0])
+    database = SnowflakeDatabase(conn_id=CUSTOM_CONN_ID)
+    stage = SnowflakeStage(
+        name="inexistent-stage",
+        metadata=database.default_metadata,
+    )
+    stage.set_url_from_file(file_fixture)
+    assert not database.stage_exists(stage)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "remote_files_fixture",
+    [
+        {"provider": "google", "filetype": FileType.CSV},
+        {"provider": "google", "filetype": FileType.NDJSON},
+        {"provider": "google", "filetype": FileType.PARQUET},
+        {"provider": "amazon", "filetype": FileType.CSV},
+    ],
+    indirect=True,
+    ids=["google_csv", "google_ndjson", "google_parquet", "amazon_csv"],
+)
+def test_create_stage_succeeds(remote_files_fixture):
+    file_fixture = File(remote_files_fixture[0])
+    if file_fixture.location.location_type == FileLocation.GS:
+        storage_integration = SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE
+    else:
+        storage_integration = SNOWFLAKE_STORAGE_INTEGRATION_AMAZON
+
+    database = SnowflakeDatabase(conn_id=CUSTOM_CONN_ID)
+    stage = database.create_stage(
+        file=file_fixture, storage_integration=storage_integration
+    )
+    assert database.stage_exists(stage)
+    database.drop_stage(stage)
+
+
+def test_create_stage_google_fails_due_to_no_storage_integration():
+    database = SnowflakeDatabase(conn_id="fake-conn")
+    with pytest.raises(ValueError) as exc_info:
+        database.create_stage(file=File("gs://some-bucket/some-file.csv"))
+    expected_msg = (
+        "In order to create a stage for GCS, `storage_integration` is required."
+    )
+    assert exc_info.match(expected_msg)
+
+
+class MockCredentials:
+    access_key = None
+    secret_key = None
+
+
+@patch(
+    "astro.files.locations.amazon.s3.S3Hook.get_credentials",
+    return_value=MockCredentials(),
+)
+def test_create_stage_amazon_fails_due_to_no_credentials(get_credentials):
+    database = SnowflakeDatabase(conn_id="fake-conn")
+    with pytest.raises(ValueError) as exc_info:
+        database.create_stage(file=File("s3://some-bucket/some-file.csv"))
+    expected_msg = (
+        "In order to create a stage for S3, one of the following is required"
+    )
+    assert exc_info.match(expected_msg)