diff --git a/pipelines/setup.py b/pipelines/setup.py index ae9866d9..1b4e09e9 100644 --- a/pipelines/setup.py +++ b/pipelines/setup.py @@ -8,7 +8,7 @@ setuptools.setup( name="whale-pipelines", - version="1.1.3", + version="1.1.4", author="Robert Yi", author_email="robert@ryi.me", description="A pared-down metadata scraper + SQL runner.", diff --git a/pipelines/tests/unit/extractor/test_glue_extractor.py b/pipelines/tests/unit/extractor/test_glue_extractor.py new file mode 100644 index 00000000..eb24335f --- /dev/null +++ b/pipelines/tests/unit/extractor/test_glue_extractor.py @@ -0,0 +1,248 @@ +import logging +import unittest + +from mock import patch +from pyhocon import ConfigFactory + +from whale.extractor.glue_extractor import GlueExtractor +from whale.models.table_metadata import TableMetadata, ColumnMetadata + + +@patch("whale.extractor.glue_extractor.boto3.client", lambda x: None) +class TestGlueExtractor(unittest.TestCase): + def setUp(self) -> None: + logging.basicConfig(level=logging.INFO) + + self.conf = ConfigFactory.from_dict({}) + + def test_extraction_with_empty_query_result(self) -> None: + """ + Test Extraction with empty result from query + """ + with patch.object(GlueExtractor, "_search_tables"): + extractor = GlueExtractor() + extractor.init(self.conf) + + results = extractor.extract() + self.assertEqual(results, None) + + def test_extraction_with_single_result(self) -> None: + with patch.object(GlueExtractor, "_search_tables") as mock_search: + mock_search.return_value = [ + { + "Name": "test_catalog_test_schema_test_table", + "DatabaseName": "test_database", + "Description": "a table for testing", + "StorageDescriptor": { + "Columns": [ + { + "Name": "col_id1", + "Type": "bigint", + "Comment": "description of id1", + }, + { + "Name": "col_id2", + "Type": "bigint", + "Comment": "description of id2", + }, + {"Name": "is_active", "Type": "boolean"}, + { + "Name": "source", + "Type": "varchar", + "Comment": "description of source", + }, + { + "Name": "etl_created_at", + "Type": "timestamp", + "Comment": "description of etl_created_at", + }, + {"Name": "ds", "Type": "varchar"}, + ], + "Location": "test_catalog.test_schema.test_table", + }, + "PartitionKeys": [ + { + "Name": "partition_key1", + "Type": "string", + "Comment": "description of partition_key1", + }, + ], + "TableType": "EXTERNAL_TABLE", + } + ] + + extractor = GlueExtractor() + extractor.init(self.conf) + actual = extractor.extract() + expected = TableMetadata( + "test_database", + "test_catalog", + "test_schema", + "test_table", + "a table for testing", + [ + ColumnMetadata("col_id1", "description of id1", "bigint", 0), + ColumnMetadata("col_id2", "description of id2", "bigint", 1), + ColumnMetadata("is_active", None, "boolean", 2), + ColumnMetadata("source", "description of source", "varchar", 3), + ColumnMetadata( + "etl_created_at", + "description of etl_created_at", + "timestamp", + 4, + ), + ColumnMetadata("ds", None, "varchar", 5), + ColumnMetadata( + "partition_key1", "description of partition_key1", "string", 6 + ), + ], + False, + ) + self.assertEqual(expected.__repr__(), actual.__repr__()) + self.assertIsNone(extractor.extract()) + + def test_extraction_with_multiple_result(self) -> None: + with patch.object(GlueExtractor, "_search_tables") as mock_search: + mock_search.return_value = [ + { + "Name": "test_catalog_test_schema_test_table", + "DatabaseName": "test_database", + "Description": "test table", + "StorageDescriptor": { + "Columns": [ + { + "Name": "col_id1", + "Type": "bigint", + "Comment": "description of col_id1", + }, + { + "Name": "col_id2", + "Type": "bigint", + "Comment": "description of col_id2", + }, + {"Name": "is_active", "Type": "boolean"}, + { + "Name": "source", + "Type": "varchar", + "Comment": "description of source", + }, + { + "Name": "etl_created_at", + "Type": "timestamp", + "Comment": "description of etl_created_at", + }, + {"Name": "ds", "Type": "varchar"}, + ], + "Location": "test_catalog.test_schema.test_table", + }, + "PartitionKeys": [ + { + "Name": "partition_key1", + "Type": "string", + "Comment": "description of partition_key1", + }, + ], + "TableType": "EXTERNAL_TABLE", + }, + { + "Name": "test_catalog1_test_schema1_test_table1", + "DatabaseName": "test_database", + "Description": "test table 1", + "StorageDescriptor": { + "Columns": [ + { + "Name": "col_name", + "Type": "varchar", + "Comment": "description of col_name", + }, + ], + "Location": "test_catalog1.test_schema1.test_table1", + }, + "Parameters": { + "comment": "description of test table 3 from comment" + }, + "TableType": "EXTERNAL_TABLE", + }, + { + "Name": "test_catalog_test_schema_test_view", + "DatabaseName": "test_database", + "Description": "test view 1", + "StorageDescriptor": { + "Columns": [ + { + "Name": "col_id3", + "Type": "varchar", + "Comment": "description of col_id3", + }, + { + "Name": "col_name3", + "Type": "varchar", + "Comment": "description of col_name3", + }, + ], + "Location": "test_catalog.test_schema.test_view", + }, + "TableType": "VIRTUAL_VIEW", + }, + ] + + extractor = GlueExtractor() + extractor.init(self.conf) + + expected = TableMetadata( + "test_database", + "test_catalog", + "test_schema", + "test_table", + "test table", + [ + ColumnMetadata("col_id1", "description of col_id1", "bigint", 0), + ColumnMetadata("col_id2", "description of col_id2", "bigint", 1), + ColumnMetadata("is_active", None, "boolean", 2), + ColumnMetadata("source", "description of source", "varchar", 3), + ColumnMetadata( + "etl_created_at", + "description of etl_created_at", + "timestamp", + 4, + ), + ColumnMetadata("ds", None, "varchar", 5), + ColumnMetadata( + "partition_key1", "description of partition_key1", "string", 6 + ), + ], + False, + ) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata( + "test_database", + "test_catalog1", + "test_schema1", + "test_table1", + "test table 1", + [ + ColumnMetadata("col_name", "description of col_name", "varchar", 0), + ], + False, + ) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + expected = TableMetadata( + "test_database", + "test_catalog", + "test_schema", + "test_view", + "test view 1", + [ + ColumnMetadata("col_id3", "description of col_id3", "varchar", 0), + ColumnMetadata( + "col_name3", "description of col_name3", "varchar", 1 + ), + ], + True, + ) + self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) + + self.assertIsNone(extractor.extract()) + self.assertIsNone(extractor.extract()) diff --git a/pipelines/whale/extractor/glue_extractor.py b/pipelines/whale/extractor/glue_extractor.py new file mode 100644 index 00000000..7fdd0cef --- /dev/null +++ b/pipelines/whale/extractor/glue_extractor.py @@ -0,0 +1,116 @@ +import boto3 +from databuilder.extractor.base_extractor import Extractor +from pyhocon import ConfigFactory, ConfigTree +from typing import Any, Dict, Iterator, List, Union +from whale.models.table_metadata import TableMetadata, ColumnMetadata + + +class GlueExtractor(Extractor): + """ + Extracts metadata from AWS glue. Adapted from Amundsen's glue extractor. + """ + + CONNECTION_NAME_KEY = "connection_name" + FILTER_KEY = "filters" + DEFAULT_CONFIG = ConfigFactory.from_dict({FILTER_KEY: None}) + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(GlueExtractor.DEFAULT_CONFIG) + self._filters = conf.get(GlueExtractor.FILTER_KEY) + self._glue = boto3.client("glue") + self._extract_iter: Union[None, Iterator] = None + self._connection_name = conf.get(GlueExtractor.CONNECTION_NAME_KEY, None) or "" + + def extract(self) -> Union[TableMetadata, None]: + if not self._extract_iter: + self._extract_iter = self._get_extract_iter() + try: + return next(self._extract_iter) + except StopIteration: + return None + + def get_scope(self) -> str: + return "extractor.glue" + + def _get_extract_iter(self) -> Iterator[TableMetadata]: + for row in self._get_raw_extract_iter(): + columns, i = [], 0 + + for column in row["StorageDescriptor"]["Columns"] + row.get( + "PartitionKeys", [] + ): + columns.append( + ColumnMetadata( + column["Name"], + column["Comment"] if "Comment" in column else None, + column["Type"], + i, + ) + ) + i += 1 + + catalog, schema, table = self._parse_location( + location=row["StorageDescriptor"]["Location"], name=row["Name"] + ) + + if self._connection_name: + database = self._connection_name + "/" + row["DatabaseName"] + else: + database = row["DatabaseName"] + + yield TableMetadata( + database, + catalog, + schema, + table, + row.get("Description") or row.get("Parameters", {}).get("comment"), + columns, + row.get("TableType") == "VIRTUAL_VIEW", + ) + + def _parse_location(self, location, name): + + """ + Location is formatted in glue as `catalog.schema.table`, while name + is formatted as `catalog_schema_table`. To determine what the catalog, + schema, and table are, then, (particularly in the case where catalogs, + schemas, and tables can have underscores and/or periods), we need to + find points where location has a `.`, while name has a `_`.""" + + start_index = 0 + splits = [] + for end_index, (location_character, name_character) in enumerate( + zip(location, name) + ): + if location_character == "." and name_character == "_": + splits.append(location[start_index:end_index]) + start_index = end_index + 1 + elif end_index == len(location) - 1: + splits.append(location[start_index:]) + + table = splits[-1] + schema = splits[-2] + if len(splits) == 3: + catalog = splits[-3] + else: + catalog = None + + return catalog, schema, table + + def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]: + tables = self._search_tables() + return iter(tables) + + def _search_tables(self) -> List[Dict[str, Any]]: + tables = [] + kwargs = {} + if self._filters is not None: + kwargs["Filters"] = self._filters + data = self._glue.search_tables(**kwargs) + tables += data["TableList"] + while "NextToken" in data: + token = data["NextToken"] + kwargs["NextToken"] = token + data = self._glue.search_tables(**kwargs) + tables += data["TableList"] + return tables diff --git a/pipelines/whale/loader/whale_loader.py b/pipelines/whale/loader/whale_loader.py index a7334c93..07d8d957 100644 --- a/pipelines/whale/loader/whale_loader.py +++ b/pipelines/whale/loader/whale_loader.py @@ -76,7 +76,12 @@ def load(self, record) -> None: schema = record.schema cluster = record.cluster - database = self.database_name or record.database + if ( + "/" in record.database + ): # TODO: In general, we should always use self.database_name, unless we override the amundsen extractor and add subdirectories + database = record.database + else: # ... so we have to do this. + database = self.database_name or record.database if cluster == "None": # edge case for Hive Metastore cluster = None diff --git a/pipelines/whale/utils/extractor_wrappers.py b/pipelines/whale/utils/extractor_wrappers.py index 39eb7e0e..e50ad04c 100644 --- a/pipelines/whale/utils/extractor_wrappers.py +++ b/pipelines/whale/utils/extractor_wrappers.py @@ -10,11 +10,11 @@ from whale.extractor.bigquery_metadata_extractor import BigQueryMetadataExtractor from whale.extractor.spanner_metadata_extractor import SpannerMetadataExtractor from whale.extractor.bigquery_watermark_extractor import BigQueryWatermarkExtractor +from whale.extractor.glue_extractor import GlueExtractor from whale.extractor.snowflake_metadata_extractor import SnowflakeMetadataExtractor from whale.extractor.metric_runner import MetricRunner from whale.engine.sql_alchemy_engine import SQLAlchemyEngine from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor -from databuilder.extractor.glue_extractor import GlueExtractor from databuilder.extractor.hive_table_metadata_extractor import ( HiveTableMetadataExtractor, ) @@ -105,7 +105,7 @@ def configure_glue_extractors(connection: ConnectionConfigSchema): conf = ConfigFactory.from_dict( { - f"{scope}.{Extractor.CLUSTER_KEY}": connection.cluster, + f"{scope}.{Extractor.CONNECTION_NAME_KEY}": connection.name, f"{scope}.{Extractor.FILTER_KEY}": connection.filter_key, } )