Add Snowflake stage methods (#523)
- Create
- Check existence
- Drop

Relates to: #492
Co-authored-by: Ankit Chaurasia <[email protected]>
tatiana authored and utkarsharma2 committed Jul 12, 2022
1 parent 857641c commit 695eff2
Showing 6 changed files with 370 additions and 19 deletions.
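
Before the diff, a minimal usage sketch of the three new stage operations introduced by this commit; the file path and connection ID are illustrative, and the snippet assumes a working Snowflake Airflow connection plus AWS credentials on the file's connection:

from astro.databases.snowflake import SnowflakeDatabase
from astro.files import File

database = SnowflakeDatabase(conn_id="snowflake_default")  # illustrative connection ID
file = File(path="s3://my-bucket/data/sample.csv")  # illustrative S3 path

# Create an external stage pointing at the folder containing the file
stage = database.create_stage(file=file)

# The stage should now exist ...
assert database.stage_exists(stage)

# ... and can be dropped again
database.drop_stage(stage)
assert not database.stage_exists(stage)
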
259 changes: 251 additions & 8 deletions src/astro/databases/snowflake.py
@@ -1,17 +1,135 @@
"""Snowflake database implementation."""
from typing import Dict, List, Tuple
import logging
import random
import string
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import pandas as pd
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from pandas.io.sql import SQLDatabase
from snowflake.connector import pandas_tools

from astro.constants import DEFAULT_CHUNK_SIZE, LoadExistStrategy, MergeConflictStrategy
from snowflake.connector.errors import ProgrammingError

from astro.constants import (
DEFAULT_CHUNK_SIZE,
FileLocation,
FileType,
LoadExistStrategy,
MergeConflictStrategy,
)
from astro.databases.base import BaseDatabase
from astro.files import File
from astro.sql.table import Metadata, Table

DEFAULT_CONN_ID = SnowflakeHook.default_conn_name

ASTRO_SDK_TO_SNOWFLAKE_FILE_FORMAT_MAP = {
FileType.CSV: "CSV",
FileType.NDJSON: "JSON",
FileType.PARQUET: "PARQUET",
}

COPY_OPTIONS = {
FileType.CSV: "ON_ERROR=CONTINUE",
FileType.NDJSON: "MATCH_BY_COLUMN_NAME=CASE_INSENSITIVE",
FileType.PARQUET: "MATCH_BY_COLUMN_NAME=CASE_INSENSITIVE",
}


@dataclass
class SnowflakeStage:
"""
Dataclass which abstracts properties of a Snowflake Stage.
Snowflake Stages are used for loading data from files into tables and unloading data from tables into files.
Example:
.. code-block:: python
snowflake_stage = SnowflakeStage(
name="stage_name",
url="gcs://bucket/prefix",
metadata=Metadata(database="SNOWFLAKE_DATABASE", schema="SNOWFLAKE_SCHEMA"),
)
.. seealso::
`Snowflake official documentation on stage creation
<https://docs.snowflake.com/en/sql-reference/sql/create-stage.html>`_
"""

name: str = ""
_name: str = field(init=False, repr=False, default="")
url: str = ""
metadata: Metadata = field(default_factory=Metadata)

@staticmethod
def _create_unique_name() -> str:
"""
Generate a valid Snowflake stage name.
:return: unique stage name
"""
return (
"stage_"
+ random.choice(string.ascii_lowercase)
+ "".join(
random.choice(string.ascii_lowercase + string.digits) for _ in range(7)
)
)

def set_url_from_file(self, file: File) -> None:
"""
Given a file to be loaded into or unloaded from Snowflake, identify its folder and
set it as self.url.
It is also responsible for adjusting any path-specific requirements for Snowflake.
:param file: File to be loaded into or unloaded from Snowflake
"""
# the stage URL needs to be the folder where the files are
# https://docs.snowflake.com/en/sql-reference/sql/create-stage.html#external-stage-parameters-externalstageparams
url = file.path[: file.path.rfind("/") + 1]
self.url = url.replace("gs://", "gcs://")

@property # type: ignore
def name(self) -> str:
"""
Return either the user-defined name or an auto-generated one.
:return: stage name
:sphinx-autoapi-skip:
"""
if not self._name:
self._name = self._create_unique_name()
return self._name

@name.setter
def name(self, value: str) -> None:
"""
Set the stage name.
:param value: Stage name.
"""
if not isinstance(value, property) and value != self._name:
self._name = value

@property
def qualified_name(self) -> str:
"""
Return the stage qualified name. In Snowflake, it is composed of the database, schema and stage name.
:return: Snowflake stage qualified name (e.g. database.schema.stage_name)
"""
qualified_name_lists = [
self.metadata.database,
self.metadata.schema,
self.name,
]
qualified_name = ".".join(name for name in qualified_name_lists if name)
return qualified_name


class SnowflakeDatabase(BaseDatabase):
"""
@@ -20,6 +138,7 @@ class SnowflakeDatabase(BaseDatabase):
"""

def __init__(self, conn_id: str = DEFAULT_CONN_ID):
self.storage_integration: Optional[str] = None
super().__init__(conn_id)

@property
@@ -57,6 +176,121 @@ def get_table_qualified_name(table: Table) -> str: # skipcq: PYL-R0201
qualified_name = ".".join(name for name in qualified_name_lists if name)
return qualified_name

# ---------------------------------------------------------
# Snowflake stage methods
# ---------------------------------------------------------

@staticmethod
def _create_stage_auth_sub_statement(
file: File, storage_integration: Optional[str] = None
) -> str:
"""
Create the authentication clause for the Snowflake CREATE STAGE statement.
Raise an exception if no supported authentication method is available.
:param file: File to be copied from/to using stage
:param storage_integration: Previously created Snowflake storage integration
:return: String containing line to be used for authentication on the remote storage
"""

if storage_integration is not None:
auth = f"storage_integration = {storage_integration};"
else:
if file.location.location_type == FileLocation.GS:
raise ValueError(
"In order to create an stage for GCS, `storage_integration` is required."
)
elif file.location.location_type == FileLocation.S3:
aws = file.location.hook.get_credentials()
if aws.access_key and aws.secret_key:
auth = f"credentials=(aws_key_id='{aws.access_key}' aws_secret_key='{aws.secret_key}');"
else:
raise ValueError(
"In order to create an stage for S3, one of the following is required: "
"* `storage_integration`"
"* AWS_KEY_ID and SECRET_KEY_ID"
)
return auth

def create_stage(
self,
file: File,
storage_integration: Optional[str] = None,
metadata: Optional[Metadata] = None,
) -> SnowflakeStage:
"""
Creates a new named external stage to use for loading data from files into Snowflake
tables and unloading data from tables into files.
At the moment, the following ways of authenticating to the backend are supported:
* Google Cloud Storage (GCS): using storage_integration, previously created
* Amazon (S3): one of the following:
(i) using storage_integration or
(ii) retrieving the AWS_KEY_ID and AWS_SECRET_KEY from the Airflow file connection
:param file: File to be copied from/to using stage
:param storage_integration: Previously created Snowflake storage integration
:param metadata: Contains Snowflake database and schema information
:return: Stage created
.. seealso::
`Snowflake official documentation on stage creation
<https://docs.snowflake.com/en/sql-reference/sql/create-stage.html>`_
"""
auth = self._create_stage_auth_sub_statement(
file=file, storage_integration=storage_integration
)

metadata = metadata or self.default_metadata
stage = SnowflakeStage(metadata=metadata)
stage.set_url_from_file(file)

fileformat = ASTRO_SDK_TO_SNOWFLAKE_FILE_FORMAT_MAP[file.type.name]
copy_options = COPY_OPTIONS[file.type.name]

sql_statement = "".join(
[
f"CREATE OR REPLACE STAGE {stage.qualified_name} URL='{stage.url}' ",
f"FILE_FORMAT=(TYPE={fileformat}, TRIM_SPACE=TRUE) ",
f"COPY_OPTIONS=({copy_options}) ",
auth,
]
)

self.run_sql(sql_statement)

return stage

def stage_exists(self, stage: SnowflakeStage) -> bool:
"""
Checks if a Snowflake stage exists.
:param stage: SnowflakeStage instance
:return: True if the stage exists, False otherwise
"""
sql_statement = f"DESCRIBE STAGE {stage.qualified_name}"
try:
self.hook.run(sql_statement)
except ProgrammingError:
logging.error(
"Stage '%s' does not exist or not authorized.", stage.qualified_name
)
return False
return True

def drop_stage(self, stage: SnowflakeStage) -> None:
"""
Runs the Snowflake query to drop the stage if it exists.
:param stage: Stage to be dropped
"""
sql_statement = f"DROP STAGE IF EXISTS {stage.qualified_name};"
self.hook.run(sql_statement, autocommit=True)

# ---------------------------------------------------------
# Table load methods
# ---------------------------------------------------------

def load_pandas_dataframe_to_table(
self,
source_dataframe: pd.DataFrame,
@@ -125,12 +359,15 @@ def get_sqlalchemy_template_table_identifier_and_parameter(
Since the table value is templated, there is a safety concern (e.g. SQL injection).
We recommend looking into the documentation of the database and seeing what the best practices are.
This is the Snowflake documentation:
https://docs.snowflake.com/en/sql-reference/identifier-literal.html
:param table: The table object we want to generate a safe table identifier for
:param jinja_table_identifier: The name used within the Jinja template to represent this table
:return: value to replace the table identifier in the query and the value that should be used to replace it
.. seealso::
`Snowflake official documentation on literals
<https://docs.snowflake.com/en/sql-reference/identifier-literal.html>`_
"""
return (
f"IDENTIFIER(:{jinja_table_identifier})",
@@ -273,9 +510,15 @@ def wrap_identifier(inp: str) -> str:

def is_valid_snow_identifier(name: str) -> bool:
"""
Because Snowflake does not allow using `Identifier` for inserts or updates, we need to make reasonable attempts to
ensure that no one can perform a SQL injection using this method. The following method ensures that a string
follows the expected identifier syntax https://docs.snowflake.com/en/sql-reference/identifiers-syntax.html
Because Snowflake does not allow using `Identifier` for inserts or updates,
we need to make reasonable attempts to ensure that no one can perform a SQL
injection using this method.
The following method ensures that a string follows the expected identifier syntax.
.. seealso::
`Snowflake official documentation on identifier syntax
<https://docs.snowflake.com/en/sql-reference/identifiers-syntax.html>`_
"""
if not 1 <= len(name) <= 255:
return False
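To make the stage abstraction concrete, a small sketch of how SnowflakeStage derives its URL and qualified name, and of the statement create_stage() would assemble for such a file; the bucket, database/schema names, storage integration and generated suffix are illustrative:

from astro.databases.snowflake import SnowflakeStage
from astro.files import File
from astro.sql.table import Metadata

file = File(path="gs://my-bucket/raw/events.ndjson")  # illustrative GCS file
stage = SnowflakeStage(metadata=Metadata(database="MY_DB", schema="MY_SCHEMA"))
stage.set_url_from_file(file)

print(stage.url)             # gcs://my-bucket/raw/  (gs:// is rewritten to gcs://)
print(stage.qualified_name)  # e.g. MY_DB.MY_SCHEMA.stage_ab12cd3e (auto-generated name)

# With storage_integration="my_gcs_integration", create_stage() would build roughly:
# CREATE OR REPLACE STAGE MY_DB.MY_SCHEMA.stage_ab12cd3e URL='gcs://my-bucket/raw/'
# FILE_FORMAT=(TYPE=JSON, TRIM_SPACE=TRUE) COPY_OPTIONS=(MATCH_BY_COLUMN_NAME=CASE_INSENSITIVE)
# storage_integration = my_gcs_integration;
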
10 changes: 6 additions & 4 deletions src/astro/files/locations/amazon/s3.py
@@ -13,6 +13,10 @@ class S3Location(BaseFileLocation):

location_type = FileLocation.S3

@property
def hook(self) -> S3Hook:
return S3Hook(aws_conn_id=self.conn_id) if self.conn_id else S3Hook()

@staticmethod
def _parse_s3_env_var() -> Tuple[str, str]:
"""Return S3 ID/KEY pair from environment vars"""
@@ -23,8 +27,7 @@ def transport_params(self) -> Dict:
"""Structure s3fs credentials from Airflow connection.
s3fs enables pandas to write to s3
"""
hook = S3Hook(aws_conn_id=self.conn_id) if self.conn_id else S3Hook()
session = hook.get_session()
session = self.hook.get_session()
return {"client": session.client("s3")}

@property
@@ -33,8 +36,7 @@ def paths(self) -> List[str]:
url = urlparse(self.path)
bucket_name = url.netloc
prefix = url.path[1:]
hook = S3Hook(aws_conn_id=self.conn_id) if self.conn_id else S3Hook()
prefixes = hook.list_keys(bucket_name=bucket_name, prefix=prefix)
prefixes = self.hook.list_keys(bucket_name=bucket_name, prefix=prefix)
paths = [
urlunparse((url.scheme, url.netloc, keys, "", "", "")) for keys in prefixes
]
4 changes: 4 additions & 0 deletions src/astro/files/locations/base.py
@@ -25,6 +25,10 @@ def __init__(self, path: str, conn_id: Optional[str] = None):
self.path = path
self.conn_id = conn_id

@property
def hook(self):
raise NotImplementedError

@property
@abstractmethod
def location_type(self):
10 changes: 6 additions & 4 deletions src/astro/files/locations/google/gcs.py
@@ -12,11 +12,14 @@ class GCSLocation(BaseFileLocation):

location_type = FileLocation.GS

@property
def hook(self) -> GCSHook:
return GCSHook(gcp_conn_id=self.conn_id) if self.conn_id else GCSHook()

@property
def transport_params(self) -> Dict:
"""get GCS credentials for storage"""
hook = GCSHook(gcp_conn_id=self.conn_id) if self.conn_id else GCSHook()
client = hook.get_conn()
client = self.hook.get_conn()
return {"client": client}

@property
@@ -25,8 +28,7 @@ def paths(self) -> List[str]:
url = urlparse(self.path)
bucket_name = url.netloc
prefix = url.path[1:]
hook = GCSHook(gcp_conn_id=self.conn_id) if self.conn_id else GCSHook()
prefixes = hook.list(bucket_name=bucket_name, prefix=prefix)
prefixes = self.hook.list(bucket_name=bucket_name, prefix=prefix)
paths = [
urlunparse((url.scheme, url.netloc, keys, "", "", "")) for keys in prefixes
]
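The hook properties added to the file locations are what the stage code relies on when it needs provider credentials; a brief sketch of that interaction, with illustrative paths and connection IDs:

from astro.files import File

# Each concrete location now exposes its Airflow hook directly.
s3_file = File(path="s3://my-bucket/data/sample.csv", conn_id="aws_default")
aws = s3_file.location.hook.get_credentials()  # access_key/secret_key consumed by _create_stage_auth_sub_statement

gcs_file = File(path="gs://my-bucket/data/sample.csv", conn_id="google_cloud_default")
client = gcs_file.location.hook.get_conn()  # google.cloud.storage.Client, as used by transport_params
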
13 changes: 13 additions & 0 deletions src/astro/settings.py
@@ -2,3 +2,16 @@

DEFAULT_SCHEMA = "tmp_astro"
SCHEMA = conf.get("astro_sdk", "sql_schema", fallback=DEFAULT_SCHEMA)

# We are deliberately not defining a fallback value. S3 Snowflake stages can also
# be created without a storage integration, by using the Airflow AWS connection
# properties.
SNOWFLAKE_STORAGE_INTEGRATION_AMAZON = conf.get(
section="astro_sdk", key="snowflake_storage_integration_amazon", fallback=None
)

SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE = conf.get(
section="astro_sdk",
key="snowflake_storage_integration_google",
fallback="gcs_int_python_sdk",
)
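
As a usage note, these options follow Airflow's usual configuration rules, so they can also be supplied via environment variables; a sketch with illustrative integration names:

import os

# Airflow exposes [astro_sdk] options as AIRFLOW__ASTRO_SDK__<OPTION> environment variables.
os.environ["AIRFLOW__ASTRO_SDK__SNOWFLAKE_STORAGE_INTEGRATION_AMAZON"] = "my_s3_integration"
os.environ["AIRFLOW__ASTRO_SDK__SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE"] = "my_gcs_integration"

from astro import settings  # import after the variables are set so conf picks them up

print(settings.SNOWFLAKE_STORAGE_INTEGRATION_AMAZON)  # my_s3_integration
print(settings.SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE)  # my_gcs_integration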