Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ingest): correct external url for account identifier with account name #6715

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@


class SnowflakeQuery:
@staticmethod
def current_account() -> str:
return "select CURRENT_ACCOUNT()"

@staticmethod
def current_region() -> str:
return "select CURRENT_REGION()"

@staticmethod
def current_version() -> str:
return "select CURRENT_VERSION()"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from typing import Optional

from datahub.ingestion.source.sql.sql_generic_profiler import ProfilingSqlReport
from datahub.ingestion.source_report.sql.snowflake import SnowflakeReport
from datahub.ingestion.source_report.usage.snowflake_usage import SnowflakeUsageReport


class SnowflakeV2Report(SnowflakeReport, SnowflakeUsageReport, ProfilingSqlReport):

account_locator: Optional[str] = None
region: Optional[str] = None

schemas_scanned: int = 0
databases_scanned: int = 0

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
from enum import Enum
from functools import lru_cache
from typing import Any, Optional

from snowflake.connector import SnowflakeConnection
Expand All @@ -24,7 +23,14 @@ class SnowflakeCloudProvider(str, Enum):
AZURE = "azure"


SNOWFLAKE_DEFAULT_CLOUD_REGION_ID = "us-west-2"
# See https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#region-ids
# Includes only exceptions to format <provider>_<cloud region with hyphen replaced by _>
SNOWFLAKE_REGION_CLOUD_REGION_MAPPING = {
"aws_us_east_1_gov": (SnowflakeCloudProvider.AWS, "us-east-1"),
"azure_uksouth": (SnowflakeCloudProvider.AZURE, "uk-south"),
"azure_centralindia": (SnowflakeCloudProvider.AZURE, "central-india.azure"),
}

SNOWFLAKE_DEFAULT_CLOUD = SnowflakeCloudProvider.AWS


Expand Down Expand Up @@ -64,49 +70,31 @@ class SnowflakeCommonMixin:
platform = "snowflake"

@staticmethod
@lru_cache(maxsize=128)
def create_snowsight_base_url(account_id: str) -> Optional[str]:
cloud: Optional[str] = None
account_locator: Optional[str] = None
cloud_region_id: Optional[str] = None
privatelink: bool = False

if "." not in account_id: # e.g. xy12345
account_locator = account_id.lower()
cloud_region_id = SNOWFLAKE_DEFAULT_CLOUD_REGION_ID
def create_snowsight_base_url(
account_locator: str,
cloud_region_id: str,
cloud: str,
privatelink: bool = False,
) -> Optional[str]:
if privatelink:
url = f"https://app.{account_locator}.{cloud_region_id}.privatelink.snowflakecomputing.com/"
elif cloud == SNOWFLAKE_DEFAULT_CLOUD:
url = f"https://app.snowflake.com/{cloud_region_id}/{account_locator}/"
else:
url = f"https://app.snowflake.com/{cloud_region_id}.{cloud}/{account_locator}/"
return url

@staticmethod
def get_cloud_region_from_snowflake_region_id(region):
if region in SNOWFLAKE_REGION_CLOUD_REGION_MAPPING.keys():
cloud, cloud_region_id = SNOWFLAKE_REGION_CLOUD_REGION_MAPPING[region]
elif region.startswith(("aws_", "gcp_", "azure_")):
# e.g. aws_us_west_2, gcp_us_central1, azure_northeurope
cloud, cloud_region_id = region.split("_", 1)
cloud_region_id = cloud_region_id.replace("_", "-")
else:
parts = account_id.split(".")
if len(parts) == 2: # e.g. xy12345.us-east-1
account_locator = parts[0].lower()
cloud_region_id = parts[1].lower()
elif len(parts) == 3 and parts[2] in (
SnowflakeCloudProvider.AWS,
SnowflakeCloudProvider.GCP,
SnowflakeCloudProvider.AZURE,
):
# e.g. xy12345.ap-south-1.aws or xy12345.us-central1.gcp or xy12345.west-us-2.azure
# NOT xy12345.us-west-2.privatelink or xy12345.eu-central-1.privatelink
account_locator = parts[0].lower()
cloud_region_id = parts[1].lower()
cloud = parts[2].lower()
elif len(parts) == 3 and parts[2] == "privatelink":
account_locator = parts[0].lower()
cloud_region_id = parts[1].lower()
privatelink = True
else:
logger.warning(
f"Could not create Snowsight base url for account {account_id}."
)
return None

if not privatelink and (cloud is None or cloud == SNOWFLAKE_DEFAULT_CLOUD):
return f"https://app.snowflake.com/{cloud_region_id}/{account_locator}/"
elif privatelink:
return f"https://app.{account_locator}.{cloud_region_id}.privatelink.snowflakecomputing.com/"
return f"https://app.snowflake.com/{cloud_region_id}.{cloud}/{account_locator}/"

def get_snowsight_base_url(self: SnowflakeCommonProtocol) -> Optional[str]:
return SnowflakeCommonMixin.create_snowsight_base_url(self.config.get_account())
raise Exception(f"Unknown snowflake region {region}")
return cloud, cloud_region_id

def _is_dataset_pattern_allowed(
self: SnowflakeCommonProtocol,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config):
self.config: SnowflakeV2Config = config
self.report: SnowflakeV2Report = SnowflakeV2Report()
self.logger = logger
self.snowsight_base_url = None
# Create and register the stateful ingestion use-case handlers.
self.stale_entity_removal_handler = StaleEntityRemovalHandler(
source=self,
Expand Down Expand Up @@ -230,6 +231,7 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config):

if self.is_classification_enabled():
self.classifiers = self.get_classifiers()

# Currently caching using instance variables
# TODO - rewrite cache for readability or use out of the box solution
self.db_tables: Dict[str, Optional[Dict[str, List[SnowflakeTable]]]] = {}
Expand Down Expand Up @@ -431,6 +433,8 @@ def get_workunits(self) -> Iterable[WorkUnit]:
conn: SnowflakeConnection = self.config.get_connection()
self.add_config_to_report()
self.inspect_session_metadata(conn)
if self.config.include_external_url:
self.snowsight_base_url = self.get_snowsight_base_url(conn)

self.report.include_technical_schema = self.config.include_technical_schema
databases: List[SnowflakeDatabase] = []
Expand Down Expand Up @@ -1153,21 +1157,55 @@ def get_sample_values_for_table(self, conn, table_name, schema_name, db_name):
def get_external_url_for_table(
self, table_name: str, schema_name: str, db_name: str, domain: str
) -> Optional[str]:
base_url = self.get_snowsight_base_url()
if base_url is not None:
return f"{base_url}#/data/databases/{db_name}/schemas/{schema_name}/{domain}/{table_name}/"
if self.snowsight_base_url is not None:
return f"{self.snowsight_base_url}#/data/databases/{db_name}/schemas/{schema_name}/{domain}/{table_name}/"
return None

def get_external_url_for_schema(
self, schema_name: str, db_name: str
) -> Optional[str]:
base_url = self.get_snowsight_base_url()
if base_url is not None:
return f"{base_url}#/data/databases/{db_name}/schemas/{schema_name}/"
if self.snowsight_base_url is not None:
return f"{self.snowsight_base_url}#/data/databases/{db_name}/schemas/{schema_name}/"
return None

def get_external_url_for_database(self, db_name: str) -> Optional[str]:
base_url = self.get_snowsight_base_url()
if base_url is not None:
return f"{base_url}#/data/databases/{db_name}/"
if self.snowsight_base_url is not None:
return f"{self.snowsight_base_url}#/data/databases/{db_name}/"
return None

def get_snowsight_base_url(self, conn):
try:
# See https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#finding-the-region-and-locator-for-an-account
for db_row in self.query(conn, SnowflakeQuery.current_account()):
account_locator = db_row["CURRENT_ACCOUNT()"]

for db_row in self.query(conn, SnowflakeQuery.current_region()):
region = db_row["CURRENT_REGION()"]

self.report.account_locator = account_locator
self.report.region = region

# Returned region may be in the form <region_group>.<region>, see https://docs.snowflake.com/en/sql-reference/functions/current_region.html
region = region.split(".")[-1].lower()
account_locator = account_locator.lower()

cloud, cloud_region_id = self.get_cloud_region_from_snowflake_region_id(
region
)

# For privatelink, account identifier ends with .privatelink
# See https://docs.snowflake.com/en/user-guide/organizations-connect.html#private-connectivity-urls
return self.create_snowsight_base_url(
account_locator,
cloud_region_id,
cloud,
self.config.account_id.endswith(".privatelink"), # type:ignore
)

except Exception as e:
self.warn(
self.logger,
"snowsight url",
f"unable to get snowsight base url due to an error -> {e}",
)
return None
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@


def default_query_results(query):
if query == SnowflakeQuery.current_account():
return [{"CURRENT_ACCOUNT()": "ABC12345"}]
if query == SnowflakeQuery.current_region():
return [{"CURRENT_REGION()": "AWS_AP_SOUTH_1"}]
if query == SnowflakeQuery.current_role():
return [{"CURRENT_ROLE()": "TEST_ROLE"}]
elif query == SnowflakeQuery.current_version():
Expand Down
59 changes: 59 additions & 0 deletions metadata-ingestion/tests/unit/test_snowflake_beta_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from datahub.configuration.common import ConfigurationError, OauthConfiguration
from datahub.ingestion.api.source import SourceCapability
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config
from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeCloudProvider
from datahub.ingestion.source.snowflake.snowflake_v2 import SnowflakeV2Source


Expand Down Expand Up @@ -447,3 +448,61 @@ def query_results(query):
assert report.capability_report[SourceCapability.DATA_PROFILING].capable
assert report.capability_report[SourceCapability.DESCRIPTIONS].capable
assert report.capability_report[SourceCapability.LINEAGE_COARSE].capable


def test_aws_cloud_region_from_snowflake_region_id():
(
cloud,
cloud_region_id,
) = SnowflakeV2Source.get_cloud_region_from_snowflake_region_id("aws_ca_central_1")

assert cloud == SnowflakeCloudProvider.AWS
assert cloud_region_id == "ca-central-1"

(
cloud,
cloud_region_id,
) = SnowflakeV2Source.get_cloud_region_from_snowflake_region_id("aws_us_east_1_gov")

assert cloud == SnowflakeCloudProvider.AWS
assert cloud_region_id == "us-east-1"


def test_google_cloud_region_from_snowflake_region_id():
(
cloud,
cloud_region_id,
) = SnowflakeV2Source.get_cloud_region_from_snowflake_region_id("gcp_europe_west2")

assert cloud == SnowflakeCloudProvider.GCP
assert cloud_region_id == "europe-west2"


def test_azure_cloud_region_from_snowflake_region_id():
(
cloud,
cloud_region_id,
) = SnowflakeV2Source.get_cloud_region_from_snowflake_region_id(
"azure_switzerlandnorth"
)

assert cloud == SnowflakeCloudProvider.AZURE
assert cloud_region_id == "switzerlandnorth"

(
cloud,
cloud_region_id,
) = SnowflakeV2Source.get_cloud_region_from_snowflake_region_id(
"azure_centralindia"
)

assert cloud == SnowflakeCloudProvider.AZURE
assert cloud_region_id == "central-india.azure"


def test_unknown_cloud_region_from_snowflake_region_id():
with pytest.raises(Exception) as e:
SnowflakeV2Source.get_cloud_region_from_snowflake_region_id(
"somecloud_someregion"
)
assert "Unknown snowflake region" in str(e)