From 25b2d801f62f70201ca4e86cd3e2629ba9b90244 Mon Sep 17 00:00:00 2001 From: Jesse Date: Wed, 2 Aug 2023 20:23:51 -0400 Subject: [PATCH] [PECO-921] (Reimplemented) Fix: allow DESCRIBE TABLE EXTENDED to handle more than 2048 characters (#405) --------- Signed-off-by: Jesse Whitehouse --- CHANGELOG.md | 4 +++ dbt/adapters/databricks/impl.py | 20 ++++++++++- tests/unit/test_adapter.py | 64 +++++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eb90832b..a2554fe3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ## dbt-databricks 1.6.x (Release TBD) +### Features + +- Follow up: re-implement fix for issue where the show tables extended command is limited to 2048 characters. ([#326](https://github.com/databricks/dbt-databricks/pull/326)). Set `DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS` to `true` to enable this behaviour. + ## dbt-databricks 1.6.1 (August 2, 2023) ### Fixes diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index 95e4cf5d..38d8488f 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -2,6 +2,7 @@ from contextlib import contextmanager from itertools import chain from dataclasses import dataclass +import os import re from typing import ( Any, @@ -79,6 +80,22 @@ def check_not_found_error(errmsg: str) -> bool: return new_error or old_error is not None +def get_identifier_list_string(table_names: Set[str]) -> str: + """Returns `"|".join(table_names)` by default. + + Returns `"*"` if `DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS` == `"true"` + and the joined string exceeds 2048 characters + + This is for AWS Glue Catalog users. See issue #325. 
+ """ + + _identifier = "|".join(table_names) + bypass_2048_char_limit = os.environ.get("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "false") + if bypass_2048_char_limit == "true": + _identifier = _identifier if len(_identifier) < 2048 else "*" + return _identifier + + @undefined_proof class DatabricksAdapter(SparkAdapter): Relation = DatabricksRelation @@ -448,11 +465,12 @@ def _get_one_catalog( table_names.add(relation.identifier) columns: List[Dict[str, Any]] = [] + if len(table_names) > 0: schema_relation = self.Relation.create( database=database, schema=schema, - identifier="|".join(table_names), + identifier=get_identifier_list_string(table_names), quote_policy=self.config.quoting, ) for relation, information in self._list_relations_with_information(schema_relation): diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 3cb49a91..7772b659 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -8,6 +8,7 @@ from dbt.adapters.databricks import __version__ from dbt.adapters.databricks import DatabricksAdapter, DatabricksRelation from dbt.adapters.databricks.impl import check_not_found_error +from dbt.adapters.databricks.impl import get_identifier_list_string from dbt.adapters.databricks.connections import ( CATALOG_KEY_IN_SESSION_PROPERTIES, DBT_DATABRICKS_INVOCATION_ENV, @@ -947,6 +948,69 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel }, ) + def test_describe_table_extended_2048_char_limit(self): + """GIVEN a list of table_names whose total character length exceeds 2048 characters + WHEN the environment variable DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS is "true" + THEN the identifier list is replaced with "*" + """ + + table_names: set(str) = set([f"customers_{i}" for i in range(200)]) + + # By default, don't limit the number of characters + self.assertEqual(get_identifier_list_string(table_names), "|".join(table_names)) + + # If environment variable is set, then limit the number of characters + 
with mock.patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}): + + # Long list of table names is capped + self.assertEqual(get_identifier_list_string(table_names), "*") + + # Short list of table names is not capped + self.assertEqual( + get_identifier_list_string(list(table_names)[:5]), "|".join(list(table_names)[:5]) + ) + + def test_describe_table_extended_should_not_limit(self): + """GIVEN a list of table_names whose total character length exceeds 2048 characters + WHEN the environment variable DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS is not set + THEN the identifier list is not truncated + """ + + table_names: set(str) = set([f"customers_{i}" for i in range(200)]) + + # By default, don't limit the number of characters + self.assertEqual(get_identifier_list_string(table_names), "|".join(table_names)) + + def test_describe_table_extended_should_limit(self): + """GIVEN a list of table_names whose total character length exceeds 2048 characters + WHEN the environment variable DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS is "true" + THEN the identifier list is replaced with "*" + """ + + table_names: set(str) = set([f"customers_{i}" for i in range(200)]) + + # If environment variable is set, then limit the number of characters + with mock.patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}): + + # Long list of table names is capped + self.assertEqual(get_identifier_list_string(table_names), "*") + + def test_describe_table_extended_may_limit(self): + """GIVEN a list of table_names whose total character length does not exceed 2048 characters + WHEN the environment variable DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS is "true" + THEN the identifier list is not truncated + """ + + table_names: set(str) = set([f"customers_{i}" for i in range(200)]) + + # If environment variable is set, then we may limit the number of characters + with mock.patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}): + + # But a short list of table names is 
not capped + self.assertEqual( + get_identifier_list_string(list(table_names)[:5]), "|".join(list(table_names)[:5]) + ) + class TestCheckNotFound(unittest.TestCase): def test_prefix(self):