Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ingest): handle mssql casing issues in lineage #11920

Merged
merged 4 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,19 @@
# See more below:
# https://documentation.sas.com/doc/en/pgmsascdc/9.4_3.5/acreldb/n0ejgx4895bofnn14rlguktfx5r3.htm
"teradata",
# For SQL server, the default collation rules mean that all identifiers (schema, table, column names)
# are case preserving but case insensitive.
"mssql",
}
DIALECTS_WITH_DEFAULT_UPPERCASE_COLS = {
# In some dialects, column identifiers are effectively case insensitive
# because they are automatically converted to uppercase. Most other systems
# automatically lowercase unquoted identifiers.
"snowflake",
}
assert DIALECTS_WITH_DEFAULT_UPPERCASE_COLS.issubset(
DIALECTS_WITH_CASE_INSENSITIVE_COLS
)


class QueryType(enum.Enum):
Expand Down
32 changes: 25 additions & 7 deletions metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import traceback
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union

import pydantic.dataclasses
import sqlglot
Expand Down Expand Up @@ -873,6 +873,27 @@ def _translate_internal_column_lineage(
)


_StrOrNone = TypeVar("_StrOrNone", str, Optional[str])


def _normalize_db_or_schema(
db_or_schema: _StrOrNone,
dialect: sqlglot.Dialect,
) -> _StrOrNone:
if db_or_schema is None:
return None

# In snowflake, table identifiers must be uppercased to match sqlglot's behavior.
if is_dialect_instance(dialect, "snowflake"):
return db_or_schema.upper()

# In mssql, table identifiers must be lowercased.
elif is_dialect_instance(dialect, "mssql"):
return db_or_schema.lower()

return db_or_schema


def _sqlglot_lineage_inner(
sql: sqlglot.exp.ExpOrStr,
schema_resolver: SchemaResolverInterface,
Expand All @@ -885,12 +906,9 @@ def _sqlglot_lineage_inner(
else:
dialect = get_dialect(default_dialect)

if is_dialect_instance(dialect, "snowflake"):
# in snowflake, table identifiers must be uppercased to match sqlglot's behavior.
if default_db:
default_db = default_db.upper()
if default_schema:
default_schema = default_schema.upper()
default_db = _normalize_db_or_schema(default_db, dialect)
default_schema = _normalize_db_or_schema(default_schema, dialect)

if is_dialect_instance(dialect, "redshift") and not default_schema:
# On Redshift, there's no "USE SCHEMA <schema>" command. The default schema
# is public, and "current schema" is the one at the front of the search path.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def is_dialect_instance(
else:
platforms = list(platforms)

dialects = [sqlglot.Dialect.get_or_raise(platform) for platform in platforms]
dialects = [get_dialect(platform) for platform in platforms]

if any(isinstance(dialect, dialect_class.__class__) for dialect_class in dialects):
return True
Expand Down
Loading
Loading