From 6795ddaae3c3cc0bacbbd71150bb8dd24bf5eb73 Mon Sep 17 00:00:00 2001 From: Mayuri N Date: Tue, 12 Mar 2024 20:01:59 +0530 Subject: [PATCH] feat(ingest): add classification to bigquery, redshift Contains refractoring changes for snowflake classification --- metadata-ingestion/setup.py | 5 +- .../src/datahub/ingestion/api/source.py | 1 + .../glossary/classification_mixin.py | 80 +++++++++- .../ingestion/source/bigquery_v2/bigquery.py | 49 +++++- .../source/bigquery_v2/bigquery_config.py | 10 +- .../bigquery_v2/bigquery_data_reader.py | 59 +++++++ .../source/bigquery_v2/bigquery_report.py | 8 +- .../ingestion/source/bigquery_v2/profiler.py | 2 +- .../ingestion/source/redshift/config.py | 4 + .../ingestion/source/redshift/redshift.py | 38 ++++- .../source/redshift/redshift_data_reader.py | 51 ++++++ .../ingestion/source/redshift/report.py | 8 +- .../source/snowflake/snowflake_data_reader.py | 59 +++++++ .../source/snowflake/snowflake_schema.py | 2 - .../source/snowflake/snowflake_v2.py | 149 ++++-------------- .../datahub/ingestion/source/sql/athena.py | 5 + .../ingestion/source/sql/clickhouse.py | 5 + .../ingestion/source/sql/data_reader.py | 125 ++++++--------- .../src/datahub/ingestion/source/sql/druid.py | 5 + .../src/datahub/ingestion/source/sql/hana.py | 5 + .../src/datahub/ingestion/source/sql/hive.py | 5 + .../src/datahub/ingestion/source/sql/mysql.py | 5 + .../datahub/ingestion/source/sql/oracle.py | 5 + .../datahub/ingestion/source/sql/postgres.py | 5 + .../ingestion/source/sql/presto_on_hive.py | 5 + .../ingestion/source/sql/sql_common.py | 8 +- .../datahub/ingestion/source/sql/teradata.py | 5 + .../src/datahub/ingestion/source/sql/trino.py | 5 + .../datahub/ingestion/source/sql/vertica.py | 5 + .../bigquery_v2/bigquery_mcp_golden.json | 57 ++++++- .../integration/bigquery_v2/test_bigquery.py | 65 ++++++++ .../mysql/mysql_to_file_dbalias.yml | 46 ++++++ .../integration/snowflake/test_snowflake.py | 15 +- .../test_snowflake_classification.py | 3 +- 34 files changed, 674 insertions(+), 230 deletions(-) create mode 100644 metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_data_reader.py create mode 100644 metadata-ingestion/src/datahub/ingestion/source/redshift/redshift_data_reader.py create mode 100644 metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_data_reader.py create mode 100644 metadata-ingestion/tests/integration/mysql/mysql_to_file_dbalias.yml diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index cb6e884d57380e..da0b850be9f083 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -302,7 +302,8 @@ | { *sqlglot_lib, "google-cloud-datacatalog-lineage==0.2.2", - }, + } + | classification_lib, "clickhouse": sql_common | clickhouse_common, "clickhouse-usage": sql_common | usage_common | clickhouse_common, "datahub-lineage-file": set(), @@ -370,6 +371,8 @@ | redshift_common | usage_common | sqlglot_lib + | classification_lib + | {"db-dtypes"} # Pandas extension data types | {"cachetools"}, "s3": {*s3_base, *data_lake_profiling}, "gcs": {*s3_base, *data_lake_profiling}, diff --git a/metadata-ingestion/src/datahub/ingestion/api/source.py b/metadata-ingestion/src/datahub/ingestion/api/source.py index 906a431666e17f..d299f1009d51a3 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/source.py +++ b/metadata-ingestion/src/datahub/ingestion/api/source.py @@ -57,6 +57,7 @@ class SourceCapability(Enum): TAGS = "Extract Tags" SCHEMA_METADATA = "Schema Metadata" CONTAINERS = "Asset Containers" + CLASSIFICATION = "Classification" @dataclass diff --git a/metadata-ingestion/src/datahub/ingestion/glossary/classification_mixin.py b/metadata-ingestion/src/datahub/ingestion/glossary/classification_mixin.py index c6c95e76d196fc..c0de827b21131f 100644 --- a/metadata-ingestion/src/datahub/ingestion/glossary/classification_mixin.py +++ b/metadata-ingestion/src/datahub/ingestion/glossary/classification_mixin.py @@ -1,16 +1,20 @@ import concurrent.futures import logging from dataclasses import dataclass, field +from functools import partial from math import ceil -from typing import Dict, Iterable, List, Optional +from typing import Callable, Dict, Iterable, List, Optional, Union from datahub_classify.helper_classes import ColumnInfo, Metadata from pydantic import Field from datahub.configuration.common import ConfigModel, ConfigurationError from datahub.emitter.mce_builder import get_sys_time, make_term_urn, make_user_urn +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.ingestion.api.workunit import MetadataWorkUnit from datahub.ingestion.glossary.classifier import ClassificationConfig, Classifier from datahub.ingestion.glossary.classifier_registry import classifier_registry +from datahub.ingestion.source.sql.data_reader import DataReader from datahub.metadata.com.linkedin.pegasus2avro.common import ( AuditStamp, GlossaryTermAssociation, @@ -25,9 +29,12 @@ @dataclass class ClassificationReportMixin: + + num_tables_fetch_sample_values_failed: int = 0 + num_tables_classification_attempted: int = 0 num_tables_classification_failed: int = 0 - num_tables_classified: int = 0 + num_tables_classification_found: int = 0 info_types_detected: LossyDict[str, LossyList[str]] = field( default_factory=LossyDict @@ -99,8 +106,22 @@ def classify_schema_fields( self, dataset_name: str, schema_metadata: SchemaMetadata, - sample_data: Dict[str, list], + sample_data: Union[Dict[str, list], Callable[[], Dict[str, list]]], ) -> None: + + if not isinstance(sample_data, Dict): + try: + # TODO: In future, sample_data fetcher can be lazily called if classification + # requires values as prediction factor + sample_data = sample_data() + except Exception as e: + self.report.num_tables_fetch_sample_values_failed += 1 + logger.warning( + f"Failed to get sample values for dataset. Make sure you have granted SELECT permissions on dataset. {dataset_name}", + ) + sample_data = dict() + logger.debug("Error", exc_info=e) + column_infos = self.get_columns_to_classify( dataset_name, schema_metadata, sample_data ) @@ -137,7 +158,7 @@ def classify_schema_fields( ) if field_terms: - self.report.num_tables_classified += 1 + self.report.num_tables_classification_found += 1 self.populate_terms_in_schema_metadata(schema_metadata, field_terms) def update_field_terms( @@ -234,8 +255,11 @@ def get_columns_to_classify( ) continue - # TODO: Let's auto-skip passing sample_data for complex(array/struct) columns - # for initial rollout + # As a result of custom field path specification e.g. [version=2.0].[type=struct].[type=struct].service' + # Sample values for a nested field (an array , union or struct) are not read / passed in classifier correctly. + # TODO: Fix this behavior for nested fields. This would probably involve: + # 1. Preprocessing field path spec v2 back to native field representation. (without [*] constructs) + # 2. Preprocessing retrieved structured sample data to pass in sample values correctly for nested fields. column_infos.append( ColumnInfo( @@ -256,3 +280,47 @@ def get_columns_to_classify( ) return column_infos + + +def classification_workunit_processor( + table_wu_generator: Iterable[MetadataWorkUnit], + classification_handler: ClassificationHandler, + data_reader: Optional[DataReader], + table_id: List[str], + data_reader_kwargs: dict = {}, +) -> Iterable[MetadataWorkUnit]: + table_name = ".".join(table_id) + if not classification_handler.is_classification_enabled_for_table(table_name): + yield from table_wu_generator + for wu in table_wu_generator: + maybe_schema_metadata = wu.get_aspect_of_type(SchemaMetadata) + if maybe_schema_metadata: + try: + classification_handler.classify_schema_fields( + table_name, + maybe_schema_metadata, + ( + partial( + data_reader.get_sample_data_for_table, + table_id, + classification_handler.config.classification.sample_size + * 1.2, + **data_reader_kwargs, + ) + if data_reader + else dict() + ), + ) + yield MetadataChangeProposalWrapper( + aspect=maybe_schema_metadata, entityUrn=wu.get_urn() + ).as_workunit( + is_primary_source=wu.is_primary_source, + ) + except Exception as e: + logger.debug( + f"Failed to classify table columns for {table_name} due to error -> {e}", + exc_info=e, + ) + yield wu + else: + yield wu diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py index bcc0aa50ed22e6..8452399bddf5da 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py @@ -35,11 +35,16 @@ TestConnectionReport, ) from datahub.ingestion.api.workunit import MetadataWorkUnit +from datahub.ingestion.glossary.classification_mixin import ( + ClassificationHandler, + classification_workunit_processor, +) from datahub.ingestion.source.bigquery_v2.bigquery_audit import ( BigqueryTableIdentifier, BigQueryTableRef, ) from datahub.ingestion.source.bigquery_v2.bigquery_config import BigQueryV2Config +from datahub.ingestion.source.bigquery_v2.bigquery_data_reader import BigQueryDataReader from datahub.ingestion.source.bigquery_v2.bigquery_helper import ( unquote_and_decode_unicode_escape_seq, ) @@ -167,6 +172,11 @@ def cleanup(config: BigQueryV2Config) -> None: "Optionally enabled via `stateful_ingestion.remove_stale_metadata`", supported=True, ) +@capability( + SourceCapability.CLASSIFICATION, + "Optionally enabled via `classification.enabled`", + supported=True, +) class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource): # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types BIGQUERY_FIELD_TYPE_MAPPINGS: Dict[ @@ -214,6 +224,7 @@ def __init__(self, ctx: PipelineContext, config: BigQueryV2Config): super(BigqueryV2Source, self).__init__(config, ctx) self.config: BigQueryV2Config = config self.report: BigQueryV2Report = BigQueryV2Report() + self.classification_handler = ClassificationHandler(self.config, self.report) self.platform: str = "bigquery" BigqueryTableIdentifier._BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX = ( @@ -227,6 +238,12 @@ def __init__(self, ctx: PipelineContext, config: BigQueryV2Config): ) self.sql_parser_schema_resolver = self._init_schema_resolver() + self.data_reader: Optional[BigQueryDataReader] = None + if self.classification_handler.is_classification_enabled(): + self.data_reader = BigQueryDataReader.create( + self.config.get_bigquery_client() + ) + redundant_lineage_run_skip_handler: Optional[ RedundantLineageRunSkipHandler ] = None @@ -713,6 +730,7 @@ def _process_schema( ) columns = None + if ( self.config.include_tables or self.config.include_views @@ -732,12 +750,27 @@ def _process_schema( for table in db_tables[dataset_name]: table_columns = columns.get(table.name, []) if columns else [] - yield from self._process_table( + table_wu_generator = self._process_table( table=table, columns=table_columns, project_id=project_id, dataset_name=dataset_name, ) + yield from classification_workunit_processor( + table_wu_generator, + self.classification_handler, + self.data_reader, + [project_id, dataset_name, table.name], + data_reader_kwargs=dict( + sample_size_percent=( + self.config.classification.sample_size + * 1.2 + / table.rows_count + if table.rows_count + else None + ) + ), + ) elif self.store_table_refs: # Need table_refs to calculate lineage and usage for table_item in self.bigquery_data_dictionary.list_tables( @@ -1071,14 +1104,16 @@ def gen_dataset_workunits( ) yield self.gen_schema_metadata( - dataset_urn, table, columns, str(datahub_dataset_name) + dataset_urn, table, columns, datahub_dataset_name ) dataset_properties = DatasetProperties( name=datahub_dataset_name.get_table_display_name(), - description=unquote_and_decode_unicode_escape_seq(table.comment) - if table.comment - else "", + description=( + unquote_and_decode_unicode_escape_seq(table.comment) + if table.comment + else "" + ), qualifiedName=str(datahub_dataset_name), created=( TimeStamp(time=int(table.created.timestamp() * 1000)) @@ -1238,10 +1273,10 @@ def gen_schema_metadata( dataset_urn: str, table: Union[BigqueryTable, BigqueryView, BigqueryTableSnapshot], columns: List[BigqueryColumn], - dataset_name: str, + dataset_name: BigqueryTableIdentifier, ) -> MetadataWorkUnit: schema_metadata = SchemaMetadata( - schemaName=dataset_name, + schemaName=str(dataset_name), platform=make_data_platform_urn(self.platform), version=0, hash="", diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_config.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_config.py index 2f4978d49e6870..28f0be2c38033b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_config.py @@ -10,6 +10,9 @@ from datahub.configuration.common import AllowDenyPattern, ConfigModel from datahub.configuration.validate_field_removal import pydantic_removed_field +from datahub.ingestion.glossary.classification_mixin import ( + ClassificationSourceConfigMixin, +) from datahub.ingestion.source.sql.sql_config import SQLCommonConfig from datahub.ingestion.source.state.stateful_ingestion_base import ( StatefulLineageConfigMixin, @@ -64,9 +67,9 @@ def __init__(self, **data: Any): ) os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self._credentials_path - def get_bigquery_client(config) -> bigquery.Client: - client_options = config.extra_client_options - return bigquery.Client(config.project_on_behalf, **client_options) + def get_bigquery_client(self) -> bigquery.Client: + client_options = self.extra_client_options + return bigquery.Client(self.project_on_behalf, **client_options) def make_gcp_logging_client( self, project_id: Optional[str] = None @@ -96,6 +99,7 @@ class BigQueryV2Config( StatefulUsageConfigMixin, StatefulLineageConfigMixin, StatefulProfilingConfigMixin, + ClassificationSourceConfigMixin, ): project_id_pattern: AllowDenyPattern = Field( default=AllowDenyPattern.allow_all(), diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_data_reader.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_data_reader.py new file mode 100644 index 00000000000000..387dd5f687a752 --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_data_reader.py @@ -0,0 +1,59 @@ +from collections import defaultdict +from typing import Dict, List, Optional + +from google.cloud import bigquery + +from datahub.ingestion.source.sql.data_reader import DataReader + + +class BigQueryDataReader(DataReader): + @staticmethod + def create( + client: bigquery.Client, + ) -> "BigQueryDataReader": + return BigQueryDataReader(client) + + def __init__( + self, + client: bigquery.Client, + ) -> None: + self.client = client + + def get_sample_data_for_table( + self, + table_id: List[str], + sample_size: int, + *, + sample_size_percent: Optional[float] = None, + filter: Optional[str] = None, + ) -> Dict[str, list]: + """ + table_id should be in the form [project, dataset, schema] + """ + column_values: Dict[str, list] = defaultdict(list) + + project = table_id[0] + dataset = table_id[1] + table_name = table_id[2] + + if sample_size_percent is None: + return column_values + # Ideally we always know the actual row count. + # The alternative to perform limit query scans entire BQ table + # and is never a recommended option due to cost factor, unless + # additional filter clause (e.g. where condition on partition) is available. + + sample_pc = sample_size_percent * 100 + # TODO: handle for sharded+compulsory partitioned tables + + sql = ( + f"SELECT * FROM `{project}.{dataset}.{table_name}` " + + f"TABLESAMPLE SYSTEM ({sample_pc:.8f} percent)" + ) + # Ref: https://cloud.google.com/bigquery/docs/samples/bigquery-query-results-dataframe + df = self.client.query_and_wait(sql).to_dataframe() + + return df.to_dict(orient="list") + + def close(self) -> None: + self.client.close() diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_report.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_report.py index ad7b86219e7c13..54eca61dfe1c9a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_report.py @@ -7,6 +7,7 @@ import pydantic from datahub.ingestion.api.report import Report +from datahub.ingestion.glossary.classification_mixin import ClassificationReportMixin from datahub.ingestion.source.sql.sql_generic_profiler import ProfilingSqlReport from datahub.ingestion.source_report.ingestion_stage import IngestionStageReport from datahub.ingestion.source_report.time_window import BaseTimeWindowReport @@ -42,7 +43,12 @@ class BigQueryProcessingPerfReport(Report): @dataclass -class BigQueryV2Report(ProfilingSqlReport, IngestionStageReport, BaseTimeWindowReport): +class BigQueryV2Report( + ProfilingSqlReport, + IngestionStageReport, + BaseTimeWindowReport, + ClassificationReportMixin, +): num_total_lineage_entries: TopKDict[str, int] = field(default_factory=TopKDict) num_skipped_lineage_entries_missing_data: TopKDict[str, int] = field( default_factory=int_top_k_dict diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/profiler.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/profiler.py index 4083eb6db77c15..dbaf28fabc9d45 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/profiler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/profiler.py @@ -91,7 +91,7 @@ def generate_partition_profiler_query( ) else: logger.warning( - f"Partitioned table {table.name} without partiton column" + f"Partitioned table {table.name} without partition column" ) self.report.profiling_skipped_invalid_partition_ids[ f"{project}.{schema}.{table.name}" diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py index 89fa5dde0e11cc..214f6279283458 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py @@ -9,6 +9,9 @@ from datahub.configuration.common import AllowDenyPattern from datahub.configuration.source_common import DatasetLineageProviderConfigBase from datahub.configuration.validate_field_removal import pydantic_removed_field +from datahub.ingestion.glossary.classification_mixin import ( + ClassificationSourceConfigMixin, +) from datahub.ingestion.source.data_lake_common.path_spec import PathSpec from datahub.ingestion.source.sql.sql_config import BasicSQLAlchemyConfig from datahub.ingestion.source.state.stateful_ingestion_base import ( @@ -70,6 +73,7 @@ class RedshiftConfig( RedshiftUsageConfig, StatefulLineageConfigMixin, StatefulProfilingConfigMixin, + ClassificationSourceConfigMixin, ): database: str = Field(default="dev", description="database") diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py index b890df3b1f7761..26a7ecb2de034a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py @@ -35,6 +35,10 @@ ) from datahub.ingestion.api.source_helpers import create_dataset_props_patch_builder from datahub.ingestion.api.workunit import MetadataWorkUnit +from datahub.ingestion.glossary.classification_mixin import ( + ClassificationHandler, + classification_workunit_processor, +) from datahub.ingestion.source.common.subtypes import ( DatasetContainerSubTypes, DatasetSubTypes, @@ -43,6 +47,7 @@ from datahub.ingestion.source.redshift.lineage import RedshiftLineageExtractor from datahub.ingestion.source.redshift.lineage_v2 import RedshiftSqlLineageV2 from datahub.ingestion.source.redshift.profile import RedshiftProfiler +from datahub.ingestion.source.redshift.redshift_data_reader import RedshiftDataReader from datahub.ingestion.source.redshift.redshift_schema import ( RedshiftColumn, RedshiftDataDictionary, @@ -52,6 +57,7 @@ ) from datahub.ingestion.source.redshift.report import RedshiftReport from datahub.ingestion.source.redshift.usage import RedshiftUsageExtractor +from datahub.ingestion.source.sql.data_reader import DataReader from datahub.ingestion.source.sql.sql_common import SqlWorkUnit from datahub.ingestion.source.sql.sql_types import resolve_postgres_modified_type from datahub.ingestion.source.sql.sql_utils import ( @@ -127,6 +133,11 @@ "Enabled by default, can be disabled via configuration `include_usage_statistics`", ) @capability(SourceCapability.DELETION_DETECTION, "Enabled via stateful ingestion") +@capability( + SourceCapability.CLASSIFICATION, + "Optionally enabled via `classification.enabled`", + supported=True, +) class RedshiftSource(StatefulIngestionSourceBase, TestableSource): """ This plugin extracts the following: @@ -313,6 +324,8 @@ def __init__(self, config: RedshiftConfig, ctx: PipelineContext): self.catalog_metadata: Dict = {} self.config: RedshiftConfig = config self.report: RedshiftReport = RedshiftReport() + self.classification_handler = ClassificationHandler(self.config, self.report) + # TODO: support classification for Redshift self.platform = "redshift" self.domain_registry = None if self.config.domain: @@ -486,6 +499,20 @@ def process_schemas(self, connection, database): self.db_schemas[database][schema.name] = schema yield from self.process_schema(connection, database, schema) + def make_data_reader( + self, + connection: redshift_connector.Connection, + ) -> Optional[DataReader]: + """ + Subclasses can override this with source-specific data reader + if source provides clause to pick random sample instead of current + limit-based sample + """ + if self.classification_handler.is_classification_enabled(): + return RedshiftDataReader.create(connection) + + return None + def process_schema( self, connection: redshift_connector.Connection, @@ -525,6 +552,7 @@ def process_schema( ) if self.config.include_tables: + data_reader = self.make_data_reader(connection) logger.info(f"Process tables in schema {database}.{schema.name}") if ( self.db_tables[schema.database] @@ -532,7 +560,15 @@ def process_schema( ): for table in self.db_tables[schema.database][schema.name]: table.columns = schema_columns[schema.name].get(table.name, []) - yield from self._process_table(table, database=database) + table_wu_generator = self._process_table( + table, database=database + ) + yield from classification_workunit_processor( + table_wu_generator, + self.classification_handler, + data_reader, + [schema.database, schema.name, table.name], + ) self.report.table_processed[report_key] = ( self.report.table_processed.get( f"{database}.{schema.name}", 0 diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift_data_reader.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift_data_reader.py new file mode 100644 index 00000000000000..725b1b1adf7429 --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift_data_reader.py @@ -0,0 +1,51 @@ +import logging +from typing import Any, Dict, List + +import redshift_connector + +from datahub.ingestion.source.sql.data_reader import DataReader +from datahub.utilities.perf_timer import PerfTimer + +logger = logging.Logger(__name__) + + +class RedshiftDataReader(DataReader): + @staticmethod + def create(conn: redshift_connector.Connection) -> "RedshiftDataReader": + return RedshiftDataReader(conn) + + def __init__(self, conn: redshift_connector.Connection) -> None: + # The lifecycle of this connection is managed externally + self.conn = conn + + def get_sample_data_for_table( + self, table_id: List[str], sample_size: int, **kwargs: Any + ) -> Dict[str, list]: + """ + For redshift, table_id should be in form (db_name, schema_name, table_name) + """ + + assert len(table_id) == 3 + + db_name = table_id[0] + schema_name = table_id[1] + table_name = table_id[2] + logger.debug( + f"Collecting sample values for table {db_name}.{schema_name}.{table_name}" + ) + + with PerfTimer() as timer, self.conn.cursor() as cursor: + + sql = f"select * from {db_name}.{schema_name}.{table_name} limit {sample_size};" + cursor.execute(sql) + df = cursor.fetch_dataframe() + # Fetch the result set from the cursor and deliver it as the Pandas DataFrame. + time_taken = timer.elapsed_seconds() + logger.debug( + f"Finished collecting sample values for table {db_name}.{schema_name}.{table_name};" + f"{df.shape[0]} rows; took {time_taken:.3f} seconds" + ) + return df.to_dict(orient="list") + + def close(self) -> None: + pass diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/report.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/report.py index 6c2a12498f2c0d..e2a035091d0ad9 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/report.py @@ -2,6 +2,7 @@ from datetime import datetime from typing import Dict, Optional +from datahub.ingestion.glossary.classification_mixin import ClassificationReportMixin from datahub.ingestion.source.sql.sql_generic_profiler import ProfilingSqlReport from datahub.ingestion.source_report.ingestion_stage import IngestionStageReport from datahub.ingestion.source_report.time_window import BaseTimeWindowReport @@ -11,7 +12,12 @@ @dataclass -class RedshiftReport(ProfilingSqlReport, IngestionStageReport, BaseTimeWindowReport): +class RedshiftReport( + ProfilingSqlReport, + IngestionStageReport, + BaseTimeWindowReport, + ClassificationReportMixin, +): num_usage_workunits_emitted: Optional[int] = None num_operational_stats_workunits_emitted: Optional[int] = None upstream_lineage: LossyDict = field(default_factory=LossyDict) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_data_reader.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_data_reader.py new file mode 100644 index 00000000000000..18fd02e253a69f --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_data_reader.py @@ -0,0 +1,59 @@ +import logging +from typing import Any, Callable, Dict, List + +import pandas as pd +from snowflake.connector import SnowflakeConnection + +from datahub.ingestion.source.sql.data_reader import DataReader +from datahub.utilities.perf_timer import PerfTimer + +logger = logging.Logger(__name__) + + +class SnowflakeDataReader(DataReader): + @staticmethod + def create( + conn: SnowflakeConnection, col_name_preprocessor: Callable[[str], str] + ) -> "SnowflakeDataReader": + return SnowflakeDataReader(conn, col_name_preprocessor) + + def __init__( + self, conn: SnowflakeConnection, col_name_preprocessor: Callable[[str], str] + ) -> None: + # The lifecycle of this connection is managed externally + self.conn = conn + self.col_name_preprocessor = col_name_preprocessor + + def get_sample_data_for_table( + self, table_id: List[str], sample_size: int, **kwargs: Any + ) -> Dict[str, list]: + """ + For snowflake, table_id should be in form (db_name, schema_name, table_name) + """ + + assert len(table_id) == 3 + + db_name = table_id[0] + schema_name = table_id[1] + table_name = table_id[2] + logger.debug( + f"Collecting sample values for table {db_name}.{schema_name}.{table_name}" + ) + + with PerfTimer() as timer, self.conn.cursor() as cursor: + + sql = f'select * from "{db_name}"."{schema_name}"."{table_name}" sample ({sample_size} rows);' + cursor.execute(sql) + dat = cursor.fetchall() + # Fetch the result set from the cursor and deliver it as the Pandas DataFrame. + df = pd.DataFrame(dat, columns=[col.name for col in cursor.description]) + time_taken = timer.elapsed_seconds() + logger.debug( + f"Finished collecting sample values for table {db_name}.{schema_name}.{table_name};" + f"{df.shape[0]} rows; took {time_taken:.3f} seconds" + ) + df.columns = [self.col_name_preprocessor(col) for col in df.columns] + return df.to_dict(orient="list") + + def close(self) -> None: + pass diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py index 9526bdec4b05dc..292c57494632c5 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py @@ -5,7 +5,6 @@ from functools import lru_cache from typing import Dict, List, Optional -import pandas as pd from snowflake.connector import SnowflakeConnection from datahub.ingestion.source.snowflake.constants import SnowflakeObjectDomain @@ -84,7 +83,6 @@ class SnowflakeTable(BaseTable): foreign_keys: List[SnowflakeFK] = field(default_factory=list) tags: Optional[List[SnowflakeTag]] = None column_tags: Dict[str, List[SnowflakeTag]] = field(default_factory=dict) - sample_data: Optional[pd.DataFrame] = None @dataclass diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index 591bdffed58190..9443b2009ce740 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -7,7 +7,6 @@ from functools import partial from typing import Callable, Dict, Iterable, List, Optional, Union -import pandas as pd from snowflake.connector import SnowflakeConnection from datahub.configuration.pattern_utils import is_schema_allowed @@ -37,7 +36,10 @@ TestConnectionReport, ) from datahub.ingestion.api.workunit import MetadataWorkUnit -from datahub.ingestion.glossary.classification_mixin import ClassificationHandler +from datahub.ingestion.glossary.classification_mixin import ( + ClassificationHandler, + classification_workunit_processor, +) from datahub.ingestion.source.common.subtypes import ( DatasetContainerSubTypes, DatasetSubTypes, @@ -52,6 +54,7 @@ SnowflakeV2Config, TagOption, ) +from datahub.ingestion.source.snowflake.snowflake_data_reader import SnowflakeDataReader from datahub.ingestion.source.snowflake.snowflake_lineage_v2 import ( SnowflakeLineageExtractor, ) @@ -134,7 +137,6 @@ ) from datahub.metadata.com.linkedin.pegasus2avro.tag import TagProperties from datahub.sql_parsing.sql_parsing_aggregator import SqlParsingAggregator -from datahub.utilities.perf_timer import PerfTimer from datahub.utilities.registries.domain_registry import DomainRegistry logger: logging.Logger = logging.getLogger(__name__) @@ -212,6 +214,11 @@ "Optionally enabled via `extract_tags`", supported=True, ) +@capability( + SourceCapability.CLASSIFICATION, + "Optionally enabled via `classification.enabled`", + supported=True, +) class SnowflakeV2Source( SnowflakeQueryMixin, SnowflakeConnectionMixin, @@ -305,10 +312,7 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): config, self.report, self.profiling_state_handler ) - if self.config.classification.enabled: - self.classification_handler = ClassificationHandler( - self.config, self.report - ) + self.classification_handler = ClassificationHandler(self.config, self.report) # Caches tables for a single database. Consider moving to disk or S3 when possible. self.db_tables: Dict[str, List[SnowflakeTable]] = {} @@ -775,8 +779,17 @@ def _process_schema( self.db_tables[schema_name] = tables if self.config.include_technical_schema: + data_reader = self.make_data_reader() for table in tables: - yield from self._process_table(table, schema_name, db_name) + table_wu_generator = self._process_table( + table, schema_name, db_name + ) + yield from classification_workunit_processor( + table_wu_generator, + self.classification_handler, + data_reader, + [db_name, schema_name, table.name], + ) if self.config.include_views: views = self.fetch_views_for_schema(snowflake_schema, db_name, schema_name) @@ -876,6 +889,14 @@ def fetch_tables_for_schema( ) return [] + def make_data_reader(self) -> Optional[SnowflakeDataReader]: + if self.classification_handler.is_classification_enabled() and self.connection: + return SnowflakeDataReader.create( + self.connection, self.snowflake_identifier + ) + + return None + def _process_table( self, table: SnowflakeTable, @@ -890,12 +911,6 @@ def _process_table( self.fetch_foreign_keys_for_table(table, schema_name, db_name, table_identifier) - dataset_name = self.get_dataset_identifier(table.name, schema_name, db_name) - - self.fetch_sample_data_for_classification( - table, schema_name, db_name, dataset_name - ) - if self.config.extract_tags != TagOption.skip: table.tags = self.tag_extractor.get_tags_on_object( table_name=table.name, @@ -914,36 +929,6 @@ def _process_table( yield from self.gen_dataset_workunits(table, schema_name, db_name) - def fetch_sample_data_for_classification( - self, table: SnowflakeTable, schema_name: str, db_name: str, dataset_name: str - ) -> None: - if ( - table.columns - and self.config.classification.enabled - and self.classification_handler.is_classification_enabled_for_table( - dataset_name - ) - ): - try: - table.sample_data = self.get_sample_values_for_table( - table.name, schema_name, db_name - ) - except Exception as e: - logger.debug( - f"Failed to get sample values for dataset {dataset_name} due to error {e}", - exc_info=e, - ) - if isinstance(e, SnowflakePermissionError): - self.report_warning( - "Failed to get sample values for dataset. Please grant SELECT permissions on dataset.", - dataset_name, - ) - else: - self.report_warning( - "Failed to get sample values for dataset", - dataset_name, - ) - def fetch_foreign_keys_for_table( self, table: SnowflakeTable, @@ -1073,9 +1058,7 @@ def gen_dataset_workunits( ).as_workunit() schema_metadata = self.gen_schema_metadata(table, schema_name, db_name) - # TODO: classification is only run for snowflake tables. - # Should we run classification for snowflake views as well? - self.classify_snowflake_table(table, dataset_name, schema_metadata) + yield MetadataChangeProposalWrapper( entityUrn=dataset_urn, aspect=schema_metadata ).as_workunit() @@ -1296,47 +1279,6 @@ def build_foreign_keys( ) return foreign_keys - def classify_snowflake_table( - self, - table: Union[SnowflakeTable, SnowflakeView], - dataset_name: str, - schema_metadata: SchemaMetadata, - ) -> None: - if ( - isinstance(table, SnowflakeTable) - and self.config.classification.enabled - and self.classification_handler.is_classification_enabled_for_table( - dataset_name - ) - ): - if table.sample_data is not None: - table.sample_data.columns = [ - self.snowflake_identifier(col) for col in table.sample_data.columns - ] - - try: - self.classification_handler.classify_schema_fields( - dataset_name, - schema_metadata, - ( - table.sample_data.to_dict(orient="list") - if table.sample_data is not None - else {} - ), - ) - except Exception as e: - logger.debug( - f"Failed to classify table columns for {dataset_name} due to error -> {e}", - exc_info=e, - ) - self.report_warning( - "Failed to classify table columns", - dataset_name, - ) - finally: - # Cleaning up sample_data fetched for classification - table.sample_data = None - def get_report(self) -> SourceReport: return self.report @@ -1551,37 +1493,6 @@ def inspect_session_metadata(self) -> None: except Exception: self.report.edition = None - # Ideally we do not want null values in sample data for a column. - # However that would require separate query per column and - # that would be expensive, hence not done. To compensale for possibility - # of some null values in collected sample, we fetch extra (20% more) - # rows than configured sample_size. - def get_sample_values_for_table( - self, table_name: str, schema_name: str, db_name: str - ) -> pd.DataFrame: - # Create a cursor object. - logger.debug( - f"Collecting sample values for table {db_name}.{schema_name}.{table_name}" - ) - - actual_sample_size = self.config.classification.sample_size * 1.2 - with PerfTimer() as timer: - cur = self.get_connection().cursor() - # Execute a statement that will generate a result set. - sql = f'select * from "{db_name}"."{schema_name}"."{table_name}" sample ({actual_sample_size} rows);' - - cur.execute(sql) - # Fetch the result set from the cursor and deliver it as the Pandas DataFrame. - - dat = cur.fetchall() - df = pd.DataFrame(dat, columns=[col.name for col in cur.description]) - time_taken = timer.elapsed_seconds() - logger.debug( - f"Finished collecting sample values for table {db_name}.{schema_name}.{table_name};{df.shape[0]} rows; took {time_taken:.3f} seconds" - ) - - return df - # domain is either "view" or "table" def get_external_url_for_table( self, table_name: str, schema_name: str, db_name: str, domain: str diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py index c3759875b2769a..eed5b1cb6c9eb8 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py @@ -291,6 +291,11 @@ def get_sql_alchemy_url(self): ) @capability(SourceCapability.LINEAGE_COARSE, "Supported for S3 tables") @capability(SourceCapability.DESCRIPTIONS, "Enabled by default") +@capability( + SourceCapability.CLASSIFICATION, + "Optionally enabled via `classification.enabled`", + supported=True, +) class AthenaSource(SQLAlchemySource): """ This plugin supports extracting the following metadata from Athena diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/clickhouse.py b/metadata-ingestion/src/datahub/ingestion/source/sql/clickhouse.py index 84c1d3844a7b48..7d32b5a20df11e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/clickhouse.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/clickhouse.py @@ -380,6 +380,11 @@ def get_columns(self, connection, table_name, schema=None, **kw): @support_status(SupportStatus.CERTIFIED) @capability(SourceCapability.DELETION_DETECTION, "Enabled via stateful ingestion") @capability(SourceCapability.DATA_PROFILING, "Optionally enabled via configuration") +@capability( + SourceCapability.CLASSIFICATION, + "Optionally enabled via `classification.enabled`", + supported=True, +) class ClickHouseSource(TwoTierSQLAlchemySource): """ This plugin extracts the following: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/data_reader.py b/metadata-ingestion/src/datahub/ingestion/source/sql/data_reader.py index 73730a9ea0ef73..7c5c1df375ac51 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/data_reader.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/data_reader.py @@ -1,12 +1,11 @@ import logging from abc import abstractmethod from collections import defaultdict -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import sqlalchemy as sa from sqlalchemy.engine import Connection, Engine from sqlalchemy.engine.reflection import Inspector -from sqlalchemy.engine.row import LegacyRow from datahub.ingestion.api.closeable import Closeable @@ -14,16 +13,49 @@ class DataReader(Closeable): - @abstractmethod def get_sample_data_for_column( - self, table_id: List[str], column_name: str, sample_size: int = 100 + self, table_id: List[str], column_name: str, sample_size: int ) -> list: - pass + raise NotImplementedError() @abstractmethod def get_sample_data_for_table( - self, table_id: List[str], sample_size: int = 100 + self, + table_id: List[str], + sample_size: int, + *, + sample_size_percent: Optional[float] = None, + filter: Optional[str] = None, ) -> Dict[str, list]: + """ + Fetches table values , approx sample_size rows + + Args: + table_id (List[str]): Table name identifier. One of + - [, , ] or + - [, ] or + - [] + sample_size (int): sample size + + Keyword Args: + sample_size_percent(float, between 0 and 1): For bigquery-like data platforms that provide only + percentage based sampling methods. If present, actual sample_size + may be ignored. + + filter (string): For bigquery-like data platforms that need mandatory filter on partition + column for some cases + + + Returns: + Dict[str, list]: dictionary of (column name -> list of column values) + """ + + # Ideally we do not want null values in sample data for a column. + # However that would require separate query per column and + # that would be expensive, hence not done. To compensate for possibility + # of some null values in collected sample, its usually recommended to + # fetch extra (20% more) rows than configured sample_size. + pass @@ -36,8 +68,7 @@ def __init__( self, conn: Union[Engine, Connection], ) -> None: - # TODO: How can this use a connection pool instead ? - self.engine = conn.engine.connect() + self.connection = conn.engine.connect() def _table(self, table_id: List[str]) -> sa.Table: return sa.Table( @@ -46,91 +77,37 @@ def _table(self, table_id: List[str]) -> sa.Table: schema=table_id[-2] if len(table_id) > 1 else None, ) - def get_sample_data_for_column( - self, table_id: List[str], column_name: str, sample_size: int = 100 - ) -> list: - """ - Fetches non-null column values, upto count - Args: - table_id: Table name identifier. One of - - [, , ] or - - [, ] or - - [] - column: Column name - Returns: - list of column values - """ - - table = self._table(table_id) - query: Any - ignore_null_condition = sa.column(column_name).is_(None) - # limit doesn't compile properly for oracle so we will append rownum to query string later - if self.engine.dialect.name.lower() == "oracle": - raw_query = ( - sa.select([sa.column(column_name)]) - .select_from(table) - .where(sa.not_(ignore_null_condition)) - ) - - query = str( - raw_query.compile(self.engine, compile_kwargs={"literal_binds": True}) - ) - query += "\nAND ROWNUM <= %d" % sample_size - else: - query = ( - sa.select([sa.column(column_name)]) - .select_from(table) - .where(sa.not_(ignore_null_condition)) - .limit(sample_size) - ) - query_results = self.engine.execute(query) - - return [x[column_name] for x in query_results.fetchall()] - def get_sample_data_for_table( - self, table_id: List[str], sample_size: int = 100 + self, table_id: List[str], sample_size: int, **kwargs: Any ) -> Dict[str, list]: - """ - Fetches table values, upto *1.2 count - Args: - table_id: Table name identifier. One of - - [, , ] or - - [, ] or - - [] - Returns: - dictionary of (column name -> list of column values) - """ + + logger.debug(f"Collecting sample values for table {'.'.join(table_id)}") + column_values: Dict[str, list] = defaultdict(list) table = self._table(table_id) - # Ideally we do not want null values in sample data for a column. - # However that would require separate query per column and - # that would be expensiv. To compensate for possibility - # of some null values in collected sample, we fetch extra (20% more) - # rows than configured sample_size. - sample_size = int(sample_size * 1.2) - query: Any # limit doesn't compile properly for oracle so we will append rownum to query string later - if self.engine.dialect.name.lower() == "oracle": + if self.connection.dialect.name.lower() == "oracle": raw_query = sa.select([sa.text("*")]).select_from(table) query = str( - raw_query.compile(self.engine, compile_kwargs={"literal_binds": True}) + raw_query.compile( + self.connection, compile_kwargs={"literal_binds": True} + ) ) query += "\nAND ROWNUM <= %d" % sample_size else: query = sa.select([sa.text("*")]).select_from(table).limit(sample_size) - query_results = self.engine.execute(query) + query_results = self.connection.execute(query) # Not ideal - creates a parallel structure in column_values. Can we use pandas here ? for row in query_results.fetchall(): - if isinstance(row, LegacyRow): - for col, col_value in row.items(): - column_values[col].append(col_value) + for col, col_value in row._mapping.items(): + column_values[col].append(col_value) return column_values def close(self) -> None: - self.engine.close() + self.connection.close() diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/druid.py b/metadata-ingestion/src/datahub/ingestion/source/sql/druid.py index 3f20e0a0f18b65..fdec869baa5830 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/druid.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/druid.py @@ -61,6 +61,11 @@ def get_identifier(self, schema: str, table: str) -> str: @config_class(DruidConfig) @support_status(SupportStatus.INCUBATING) @capability(SourceCapability.PLATFORM_INSTANCE, "Enabled by default") +@capability( + SourceCapability.CLASSIFICATION, + "Optionally enabled via `classification.enabled`", + supported=True, +) class DruidSource(SQLAlchemySource): """ This plugin extracts the following: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/hana.py b/metadata-ingestion/src/datahub/ingestion/source/sql/hana.py index 5c9c8f063a1a9e..40875809120de3 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/hana.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/hana.py @@ -28,6 +28,11 @@ class HanaConfig(BasicSQLAlchemyConfig): @capability(SourceCapability.DOMAINS, "Supported via the `domain` config field") @capability(SourceCapability.DATA_PROFILING, "Optionally enabled via configuration") @capability(SourceCapability.DELETION_DETECTION, "Enabled via stateful ingestion") +@capability( + SourceCapability.CLASSIFICATION, + "Optionally enabled via `classification.enabled`", + supported=True, +) class HanaSource(SQLAlchemySource): def __init__(self, config: HanaConfig, ctx: PipelineContext): super().__init__(config, ctx, "hana") diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/hive.py b/metadata-ingestion/src/datahub/ingestion/source/sql/hive.py index 003732236ba80c..2975bfe820d1b6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/hive.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/hive.py @@ -134,6 +134,11 @@ def clean_host_port(cls, v): @support_status(SupportStatus.CERTIFIED) @capability(SourceCapability.PLATFORM_INSTANCE, "Enabled by default") @capability(SourceCapability.DOMAINS, "Supported via the `domain` config field") +@capability( + SourceCapability.CLASSIFICATION, + "Optionally enabled via `classification.enabled`", + supported=True, +) class HiveSource(TwoTierSQLAlchemySource): """ This plugin extracts the following: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/mysql.py b/metadata-ingestion/src/datahub/ingestion/source/sql/mysql.py index 9b482beba924f9..f3e2cccb9e8d04 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/mysql.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/mysql.py @@ -66,6 +66,11 @@ def get_identifier(self, *, schema: str, table: str) -> str: @capability(SourceCapability.DOMAINS, "Supported via the `domain` config field") @capability(SourceCapability.DATA_PROFILING, "Optionally enabled via configuration") @capability(SourceCapability.DELETION_DETECTION, "Enabled via stateful ingestion") +@capability( + SourceCapability.CLASSIFICATION, + "Optionally enabled via `classification.enabled`", + supported=True, +) class MySQLSource(TwoTierSQLAlchemySource): """ This plugin extracts the following: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/oracle.py b/metadata-ingestion/src/datahub/ingestion/source/sql/oracle.py index bcf0f26008ae30..cf7bdc982ee808 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/oracle.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/oracle.py @@ -560,6 +560,11 @@ def __getattr__(self, item: str) -> Any: @config_class(OracleConfig) @support_status(SupportStatus.INCUBATING) @capability(SourceCapability.DOMAINS, "Enabled by default") +@capability( + SourceCapability.CLASSIFICATION, + "Optionally enabled via `classification.enabled`", + supported=True, +) class OracleSource(SQLAlchemySource): """ This plugin extracts the following: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py b/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py index 5d1e37fbb68a37..20976c91f78789 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py @@ -132,6 +132,11 @@ class PostgresConfig(BasePostgresConfig): @capability(SourceCapability.PLATFORM_INSTANCE, "Enabled by default") @capability(SourceCapability.DATA_PROFILING, "Optionally enabled via configuration") @capability(SourceCapability.LINEAGE_COARSE, "Optionally enabled via configuration") +@capability( + SourceCapability.CLASSIFICATION, + "Optionally enabled via `classification.enabled`", + supported=True, +) class PostgresSource(SQLAlchemySource): """ This plugin extracts the following: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/presto_on_hive.py b/metadata-ingestion/src/datahub/ingestion/source/sql/presto_on_hive.py index 9657fdab9e2e31..98e2f2ecfbd5a7 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/presto_on_hive.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/presto_on_hive.py @@ -160,6 +160,11 @@ def get_sql_alchemy_url( @support_status(SupportStatus.CERTIFIED) @capability(SourceCapability.DELETION_DETECTION, "Enabled via stateful ingestion") @capability(SourceCapability.DATA_PROFILING, "Optionally enabled via configuration") +@capability( + SourceCapability.CLASSIFICATION, + "Optionally enabled via `classification.enabled`", + supported=True, +) class PrestoOnHiveSource(SQLAlchemySource): """ This plugin extracts the following: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py index 9ec30d57b8f762..91736b24727c8d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py @@ -820,13 +820,15 @@ def _classify( dataset_name ) and data_reader + and schema_metadata.fields ): self.classification_handler.classify_schema_fields( dataset_name, schema_metadata, - data_reader.get_sample_data_for_table( - table_id=[schema, table], - sample_size=self.config.classification.sample_size, + partial( + data_reader.get_sample_data_for_table, + [schema, table], + int(self.config.classification.sample_size * 1.2), ), ) except Exception as e: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/teradata.py b/metadata-ingestion/src/datahub/ingestion/source/sql/teradata.py index 4f2fc799ecc304..af61a28c9a6188 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/teradata.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/teradata.py @@ -444,6 +444,11 @@ class TeradataConfig(BaseTeradataConfig, BaseTimeWindowConfig): @capability(SourceCapability.LINEAGE_COARSE, "Optionally enabled via configuration") @capability(SourceCapability.LINEAGE_FINE, "Optionally enabled via configuration") @capability(SourceCapability.USAGE_STATS, "Optionally enabled via configuration") +@capability( + SourceCapability.CLASSIFICATION, + "Optionally enabled via `classification.enabled`", + supported=True, +) class TeradataSource(TwoTierSQLAlchemySource): """ This plugin extracts the following: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/trino.py b/metadata-ingestion/src/datahub/ingestion/source/sql/trino.py index 7668cb01f84bc8..1828c5101d4f3c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/trino.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/trino.py @@ -226,6 +226,11 @@ def get_identifier(self: BasicSQLAlchemyConfig, schema: str, table: str) -> str: @support_status(SupportStatus.CERTIFIED) @capability(SourceCapability.DOMAINS, "Supported via the `domain` config field") @capability(SourceCapability.DATA_PROFILING, "Optionally enabled via configuration") +@capability( + SourceCapability.CLASSIFICATION, + "Optionally enabled via `classification.enabled`", + supported=True, +) class TrinoSource(SQLAlchemySource): """ diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/vertica.py b/metadata-ingestion/src/datahub/ingestion/source/sql/vertica.py index 32f1ba5b8d5635..9800660a9ad545 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/vertica.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/vertica.py @@ -120,6 +120,11 @@ def clean_host_port(cls, v): "Optionally enabled via `stateful_ingestion.remove_stale_metadata`", supported=True, ) +@capability( + SourceCapability.CLASSIFICATION, + "Optionally enabled via `classification.enabled`", + supported=True, +) class VerticaSource(SQLAlchemySource): def __init__(self, config: VerticaConfig, ctx: PipelineContext): # self.platform = platform diff --git a/metadata-ingestion/tests/integration/bigquery_v2/bigquery_mcp_golden.json b/metadata-ingestion/tests/integration/bigquery_v2/bigquery_mcp_golden.json index da9589d2195ac6..f8763d48d35ef9 100644 --- a/metadata-ingestion/tests/integration/bigquery_v2/bigquery_mcp_golden.json +++ b/metadata-ingestion/tests/integration/bigquery_v2/bigquery_mcp_golden.json @@ -236,7 +236,62 @@ "tableSchema": "" } }, - "fields": [] + "fields": [ + { + "fieldPath": "age", + "nullable": false, + "description": "comment", + "type": { + "type": { + "com.linkedin.schema.NumberType": {} + } + }, + "nativeDataType": "INT", + "recursive": false, + "globalTags": { + "tags": [] + }, + "glossaryTerms": { + "terms": [ + { + "urn": "urn:li:glossaryTerm:Age" + } + ], + "auditStamp": { + "time": 1643871600000, + "actor": "urn:li:corpuser:datahub" + } + }, + "isPartOfKey": false + }, + { + "fieldPath": "email", + "nullable": false, + "description": "comment", + "type": { + "type": { + "com.linkedin.schema.StringType": {} + } + }, + "nativeDataType": "STRING", + "recursive": false, + "globalTags": { + "tags": [] + }, + "glossaryTerms": { + "terms": [ + { + "urn": "urn:li:glossaryTerm:Email_Address" + } + ], + "auditStamp": { + "time": 1643871600000, + "actor": "urn:li:corpuser:datahub" + } + }, + "isPartOfKey": false + } + ] } }, "systemMetadata": { diff --git a/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery.py b/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery.py index 602401134dcd30..e79bbbe995aae0 100644 --- a/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery.py +++ b/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery.py @@ -1,11 +1,20 @@ +import random +import string from typing import Any, Dict from unittest.mock import patch from freezegun import freeze_time from google.cloud.bigquery.table import TableListItem +from datahub.ingestion.glossary.classifier import ( + ClassificationConfig, + DynamicTypedClassifierConfig, +) +from datahub.ingestion.glossary.datahub_classifier import DataHubClassifierConfig from datahub.ingestion.source.bigquery_v2.bigquery import BigqueryV2Source +from datahub.ingestion.source.bigquery_v2.bigquery_data_reader import BigQueryDataReader from datahub.ingestion.source.bigquery_v2.bigquery_schema import ( + BigqueryColumn, BigqueryDataset, BigQuerySchemaApi, BigqueryTable, @@ -16,13 +25,29 @@ FROZEN_TIME = "2022-02-03 07:00:00" +def random_email(): + return ( + "".join( + [ + random.choice(string.ascii_lowercase) + for i in range(random.randint(10, 15)) + ] + ) + + "@xyz.com" + ) + + @freeze_time(FROZEN_TIME) @patch.object(BigQuerySchemaApi, "get_tables_for_dataset") @patch.object(BigqueryV2Source, "get_core_table_details") @patch.object(BigQuerySchemaApi, "get_datasets_for_project_id") +@patch.object(BigQuerySchemaApi, "get_columns_for_dataset") +@patch.object(BigQueryDataReader, "get_sample_data_for_table") @patch("google.cloud.bigquery.Client") def test_bigquery_v2_ingest( client, + get_sample_data_for_table, + get_columns_for_dataset, get_datasets_for_project_id, get_core_table_details, get_tables_for_dataset, @@ -42,6 +67,34 @@ def test_bigquery_v2_ingest( ) table_name = "table-1" get_core_table_details.return_value = {table_name: table_list_item} + get_columns_for_dataset.return_value = { + table_name: [ + BigqueryColumn( + name="age", + ordinal_position=1, + is_nullable=False, + field_path="col_1", + data_type="INT", + comment="comment", + is_partition_column=False, + cluster_column_position=None, + ), + BigqueryColumn( + name="email", + ordinal_position=1, + is_nullable=False, + field_path="col_2", + data_type="STRING", + comment="comment", + is_partition_column=False, + cluster_column_position=None, + ), + ] + } + get_sample_data_for_table.return_value = { + "age": [random.randint(1, 80) for i in range(20)], + "email": [random_email() for i in range(20)], + } bigquery_table = BigqueryTable( name=table_name, @@ -58,6 +111,18 @@ def test_bigquery_v2_ingest( "include_usage_statistics": False, "include_table_lineage": False, "include_data_platform_instance": True, + "classification": ClassificationConfig( + enabled=True, + classifiers=[ + DynamicTypedClassifierConfig( + type="datahub", + config=DataHubClassifierConfig( + minimum_values_threshold=1, + ), + ) + ], + max_workers=1, + ).dict(), } pipeline_config_dict: Dict[str, Any] = { diff --git a/metadata-ingestion/tests/integration/mysql/mysql_to_file_dbalias.yml b/metadata-ingestion/tests/integration/mysql/mysql_to_file_dbalias.yml new file mode 100644 index 00000000000000..3c8e41650bd0f5 --- /dev/null +++ b/metadata-ingestion/tests/integration/mysql/mysql_to_file_dbalias.yml @@ -0,0 +1,46 @@ +run_id: mysql-test + +source: + type: mysql + config: + username: root + password: example + database: metagalaxy + host_port: localhost:53307 + schema_pattern: + allow: + - "^metagalaxy" + - "^northwind" + - "^datacharmer" + - "^test_cases" + profile_pattern: + allow: + - "^northwind" + - "^datacharmer" + - "^test_cases" + profiling: + enabled: True + include_field_null_count: true + include_field_min_value: true + include_field_max_value: true + include_field_mean_value: true + include_field_median_value: true + include_field_stddev_value: true + include_field_quantiles: true + include_field_distinct_value_frequencies: true + include_field_histogram: true + include_field_sample_values: true + domain: + "urn:li:domain:sales": + allow: + - "^metagalaxy" + classification: + enabled: True + classifiers: + - type: datahub + config: + minimum_values_threshold: 1 +sink: + type: file + config: + filename: "./mysql_mces_dbalias.json" diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake.py index 88354ba74c417d..81487d38eda7d0 100644 --- a/metadata-ingestion/tests/integration/snowflake/test_snowflake.py +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake.py @@ -4,7 +4,6 @@ from typing import cast from unittest import mock -import pandas as pd import pytest from freezegun import freeze_time @@ -65,7 +64,7 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph): golden_file = test_resources_dir / "snowflake_golden.json" with mock.patch("snowflake.connector.connect") as mock_connect, mock.patch( - "datahub.ingestion.source.snowflake.snowflake_v2.SnowflakeV2Source.get_sample_values_for_table" + "datahub.ingestion.source.snowflake.snowflake_data_reader.SnowflakeDataReader.get_sample_data_for_table" ) as mock_sample_values: sf_connection = mock.MagicMock() sf_cursor = mock.MagicMock() @@ -74,13 +73,11 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph): sf_cursor.execute.side_effect = default_query_results - mock_sample_values.return_value = pd.DataFrame( - data={ - "col_1": [random.randint(1, 80) for i in range(20)], - "col_2": [random_email() for i in range(20)], - "col_3": [random_cloud_region() for i in range(20)], - } - ) + mock_sample_values.return_value = { + "col_1": [random.randint(1, 80) for i in range(20)], + "col_2": [random_email() for i in range(20)], + "col_3": [random_cloud_region() for i in range(20)], + } datahub_classifier_config = DataHubClassifierConfig( minimum_values_threshold=10, diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake_classification.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake_classification.py index 427b6e562ebd16..75a9df4f280512 100644 --- a/metadata-ingestion/tests/integration/snowflake/test_snowflake_classification.py +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake_classification.py @@ -91,7 +91,8 @@ def test_snowflake_classification_perf(num_workers, num_cols_per_table, num_tabl source_report = pipeline.source.get_report() assert isinstance(source_report, SnowflakeV2Report) assert ( - cast(SnowflakeV2Report, source_report).num_tables_classified == num_tables + cast(SnowflakeV2Report, source_report).num_tables_classification_found + == num_tables ) assert ( len(