fix(spark): CREATE TABLE ... PARTITIONED BY fixes (#2937)
barakalon authored Feb 9, 2024
1 parent 159da45 commit 76d6634
Showing 10 changed files with 155 additions and 72 deletions.
22 changes: 0 additions & 22 deletions sqlglot/dialects/dialect.py
@@ -654,28 +654,6 @@ def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) ->
return _time_format


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
"""
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
columns are removed from the create statement.
"""
has_schema = isinstance(expression.this, exp.Schema)
is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

if has_schema and is_partitionable:
prop = expression.find(exp.PartitionedByProperty)
if prop and prop.this and not isinstance(prop.this, exp.Schema):
schema = expression.this
columns = {v.name.upper() for v in prop.this.expressions}
partitions = [col for col in schema.expressions if col.name.upper() in columns]
schema.set("expressions", [e for e in schema.expressions if e not in partitions])
prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
expression.set("this", schema)

return self.create_sql(expression)
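
A minimal sketch of the behavior this helper implemented (and which moves to `sqlglot.transforms.move_schema_columns_to_partitioned_by` later in this diff), using the Presto-to-Hive case from `tests/dialects/test_presto.py` below:

```python
import sqlglot

# Partition columns that Presto keeps in the column list are pulled out of
# the schema and rendered as a typed PARTITIONED BY clause in Hive.
sql = "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])"
print(sqlglot.transpile(sql, read="presto", write="hive")[0])
# CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)
```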


def parse_date_delta(
exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
4 changes: 2 additions & 2 deletions sqlglot/dialects/drill.py
@@ -5,14 +5,14 @@
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
create_with_partitions_sql,
datestrtodate_sql,
format_time_lambda,
no_trycast_sql,
rename_func,
str_position_sql,
timestrtotime_sql,
)
from sqlglot.transforms import preprocess, move_schema_columns_to_partitioned_by


def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]:
@@ -125,7 +125,7 @@ class Generator(generator.Generator):
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
exp.ArraySize: rename_func("REPEATED_COUNT"),
exp.Create: create_with_partitions_sql,
exp.Create: preprocess([move_schema_columns_to_partitioned_by]),
exp.DateAdd: _date_add_sql("ADD"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("SUB"),
43 changes: 13 additions & 30 deletions sqlglot/dialects/hive.py
@@ -9,7 +9,6 @@
NormalizationStrategy,
approx_count_distinct_sql,
arg_max_or_min_no_count,
create_with_partitions_sql,
datestrtodate_sql,
format_time_lambda,
if_sql,
@@ -32,6 +31,12 @@
timestrtotime_sql,
var_map_sql,
)
from sqlglot.transforms import (
remove_unique_constraints,
ctas_with_tmp_tables_to_create_tmp_view,
preprocess,
move_schema_columns_to_partitioned_by,
)
from sqlglot.helper import seq_get
from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType
@@ -55,30 +60,6 @@
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")


def _create_sql(self, expression: exp.Create) -> str:
# remove UNIQUE column constraints
for constraint in expression.find_all(exp.UniqueColumnConstraint):
if constraint.parent:
constraint.parent.pop()

properties = expression.args.get("properties")
temporary = any(
isinstance(prop, exp.TemporaryProperty)
for prop in (properties.expressions if properties else [])
)

# CTAS with temp tables map to CREATE TEMPORARY VIEW
kind = expression.args["kind"]
if kind.upper() == "TABLE" and temporary:
if expression.expression:
return f"CREATE TEMPORARY VIEW {self.sql(expression, 'this')} AS {self.sql(expression, 'expression')}"
else:
# CREATE TEMPORARY TABLE may require storage provider
expression = self.temporary_storage_provider(expression)

return create_with_partitions_sql(self, expression)


def _add_date_sql(self: Hive.Generator, expression: DATE_ADD_OR_SUB) -> str:
if isinstance(expression, exp.TsOrDsAdd) and not expression.unit:
return self.func("DATE_ADD", expression.this, expression.expression)
@@ -518,7 +499,13 @@ class Generator(generator.Generator):
"" if e.args.get("allow_null") else "NOT NULL"
),
exp.VarMap: var_map_sql,
exp.Create: _create_sql,
exp.Create: preprocess(
[
remove_unique_constraints,
ctas_with_tmp_tables_to_create_tmp_view,
move_schema_columns_to_partitioned_by,
]
),
exp.Quantile: rename_func("PERCENTILE"),
exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
exp.RegexpExtract: regexp_extract_sql,
@@ -581,10 +568,6 @@ def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:

return super()._jsonpathkey_sql(expression)

def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
# Hive has no temporary storage provider (there are hive settings though)
return expression

def parameter_sql(self, expression: exp.Parameter) -> str:
this = self.sql(expression, "this")
expression_sql = self.sql(expression, "expression")
28 changes: 27 additions & 1 deletion sqlglot/dialects/spark.py
@@ -5,8 +5,14 @@
from sqlglot import exp
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.hive import _parse_ignore_nulls
from sqlglot.dialects.spark2 import Spark2
from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
from sqlglot.helper import seq_get
from sqlglot.transforms import (
ctas_with_tmp_tables_to_create_tmp_view,
remove_unique_constraints,
preprocess,
move_partitioned_by_to_schema_columns,
)


def _parse_datediff(args: t.List) -> exp.Expression:
@@ -35,6 +41,15 @@ def _parse_datediff(args: t.List) -> exp.Expression:
)


def _normalize_partition(e: exp.Expression) -> exp.Expression:
"""Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
if isinstance(e, str):
return exp.to_identifier(e)
if isinstance(e, exp.Literal):
return exp.to_identifier(e.name)
return e
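
A quick sketch of the normalization (names here follow the diff above), assuming it is applied to each member of `PartitionedByProperty.this.expressions` as in the `TRANSFORMS` entry below:

```python
from sqlglot import exp

# Raw strings and literals become identifiers, so Spark renders
# PARTITIONED BY (y) rather than PARTITIONED BY ('y'); other expressions
# (e.g. a partition transform such as MONTHS(y)) pass through unchanged.
assert _normalize_partition("y").sql() == "y"
assert _normalize_partition(exp.Literal.string("y")).sql() == "y"
col = exp.column("y")
assert _normalize_partition(col) is col
```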


class Spark(Spark2):
class Tokenizer(Spark2.Tokenizer):
RAW_STRINGS = [
@@ -72,6 +87,17 @@ class Generator(Spark2.Generator):

TRANSFORMS = {
**Spark2.Generator.TRANSFORMS,
exp.Create: preprocess(
[
remove_unique_constraints,
lambda e: ctas_with_tmp_tables_to_create_tmp_view(
e, temporary_storage_provider
),
move_partitioned_by_to_schema_columns,
]
),
exp.PartitionedByProperty: lambda self,
e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
exp.StartsWith: rename_func("STARTSWITH"),
exp.TimestampAdd: lambda self, e: self.func(
"DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
28 changes: 22 additions & 6 deletions sqlglot/dialects/spark2.py
@@ -13,6 +13,12 @@
)
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
from sqlglot.transforms import (
preprocess,
remove_unique_constraints,
ctas_with_tmp_tables_to_create_tmp_view,
move_schema_columns_to_partitioned_by,
)


def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
@@ -95,6 +101,13 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
return expression


def temporary_storage_provider(expression: exp.Expression) -> exp.Expression:
# spark2, spark, Databricks require a storage provider for temporary tables
provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
expression.args["properties"].append("expressions", provider)
return expression
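
Calling the hook directly, outside the `preprocess` pipeline it is wired into, illustrates the rewrite (a hedged sketch; the table `t` and its column are illustrative):

```python
from sqlglot import exp, parse_one

# Temporary tables get an explicit parquet storage provider appended to
# their property list, since Spark and Databricks require one.
create = parse_one("CREATE TEMPORARY TABLE t (x INT)", read="spark")
create = temporary_storage_provider(create)
assert create.find(exp.FileFormatProperty) is not None
```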


class Spark2(Hive):
class Parser(Hive.Parser):
TRIM_PATTERN_FIRST = True
@@ -193,6 +206,15 @@ class Generator(Hive.Generator):
e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
exp.Create: preprocess(
[
remove_unique_constraints,
lambda e: ctas_with_tmp_tables_to_create_tmp_view(
e, temporary_storage_provider
),
move_schema_columns_to_partitioned_by,
]
),
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
@@ -251,12 +273,6 @@ def struct_sql(self, expression: exp.Struct) -> str:

return self.func("STRUCT", *args)

def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
# spark2, spark, Databricks require a storage provider for temporary tables
provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
expression.args["properties"].append("expressions", provider)
return expression

def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
if is_parse_json(expression.this):
schema = f"'{self.sql(expression, 'to')}'"
5 changes: 5 additions & 0 deletions sqlglot/expressions.py
@@ -1090,6 +1090,11 @@ class Create(DDL):
"clone": False,
}

@property
def kind(self) -> t.Optional[str]:
kind = self.args.get("kind")
return kind and kind.upper()
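
A one-line illustration of the new property, which normalizes the kind for comparisons like the ones in `transforms.py` below:

```python
from sqlglot import parse_one

# kind comes back upper-cased ("TABLE", "VIEW", ...), or None if absent.
assert parse_one("create table t (x int)").kind == "TABLE"
```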


# https://docs.snowflake.com/en/sql-reference/sql/create-clone
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement
81 changes: 81 additions & 0 deletions sqlglot/transforms.py
@@ -476,6 +476,87 @@ def unqualify_columns(expression: exp.Expression) -> exp.Expression:
return expression


def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
assert isinstance(expression, exp.Create)
for constraint in expression.find_all(exp.UniqueColumnConstraint):
if constraint.parent:
constraint.parent.pop()

return expression


def ctas_with_tmp_tables_to_create_tmp_view(
expression: exp.Expression,
tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
assert isinstance(expression, exp.Create)
properties = expression.args.get("properties")
temporary = any(
isinstance(prop, exp.TemporaryProperty)
for prop in (properties.expressions if properties else [])
)

# CTAS with temp tables map to CREATE TEMPORARY VIEW
if expression.kind == "TABLE" and temporary:
if expression.expression:
return exp.Create(
kind="TEMPORARY VIEW",
this=expression.this,
expression=expression.expression,
)
return tmp_storage_provider(expression)

return expression
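
A sketch of the CTAS branch (Spark dialect assumed; the output mirrors the old `_create_sql` behavior this transform replaces):

```python
import sqlglot

# A temporary CTAS round-trips as CREATE TEMPORARY VIEW; a plain temporary
# table instead goes through tmp_storage_provider (parquet, for Spark).
sql = "CREATE TEMPORARY TABLE t AS SELECT * FROM src"
print(sqlglot.transpile(sql, read="spark", write="spark")[0])
# CREATE TEMPORARY VIEW t AS SELECT * FROM src
```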


def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
"""
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
PARTITIONED BY value is an array of column names, they are transformed into a schema.
The corresponding columns are removed from the create statement.
"""
assert isinstance(expression, exp.Create)
has_schema = isinstance(expression.this, exp.Schema)
is_partitionable = expression.kind in {"TABLE", "VIEW"}

if has_schema and is_partitionable:
prop = expression.find(exp.PartitionedByProperty)
if prop and prop.this and not isinstance(prop.this, exp.Schema):
schema = expression.this
columns = {v.name.upper() for v in prop.this.expressions}
partitions = [col for col in schema.expressions if col.name.upper() in columns]
schema.set("expressions", [e for e in schema.expressions if e not in partitions])
prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
expression.set("this", schema)

return expression


def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
"""
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
"""
assert isinstance(expression, exp.Create)
prop = expression.find(exp.PartitionedByProperty)
if (
prop
and prop.this
and isinstance(prop.this, exp.Schema)
and all(isinstance(e, exp.ColumnDef) and e.args.get("kind") for e in prop.this.expressions)
):
prop_this = exp.Tuple(
expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
)
schema = expression.this
for e in prop.this.expressions:
schema.append("expressions", e)
prop.set("this", prop_this)

return expression
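
End to end, this is the behavior change the updated tests below pin down, e.g. the Hive-to-Spark case from `tests/dialects/test_hive.py`:

```python
import sqlglot

# Spark 3 DATASOURCE tables keep partition columns in the column list and
# refer to them by bare name in PARTITIONED BY.
sql = "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)"
print(sqlglot.transpile(sql, read="hive", write="spark")[0])
# CREATE TABLE x (w STRING, y INT, z INT) PARTITIONED BY (y, z)
```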


def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
2 changes: 1 addition & 1 deletion tests/dialects/test_hive.py
@@ -152,7 +152,7 @@ def test_ddl(self):
"duckdb": "CREATE TABLE x (w TEXT)", # Partition columns should exist in table
"presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])",
"hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
"spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
"spark": "CREATE TABLE x (w STRING, y INT, z INT) PARTITIONED BY (y, z)",
},
)
self.validate_all(
2 changes: 1 addition & 1 deletion tests/dialects/test_presto.py
@@ -424,7 +424,7 @@ def test_ddl(self):
"duckdb": "CREATE TABLE x (w TEXT, y INT, z INT)",
"presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])",
"hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
"spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
"spark": "CREATE TABLE x (w STRING, y INT, z INT) PARTITIONED BY (y, z)",
},
)
self.validate_all(
12 changes: 3 additions & 9 deletions tests/dialects/test_spark.py
@@ -93,11 +93,12 @@ def test_ddl(self):
'x'='1'
)""",
"spark": """CREATE TABLE blah (
col_a INT
col_a INT,
date STRING
)
COMMENT 'Test comment: blah'
PARTITIONED BY (
date STRING
date
)
USING ICEBERG
TBLPROPERTIES (
@@ -125,13 +126,6 @@ def test_ddl(self):
"spark": "ALTER TABLE StudentInfo DROP COLUMNS (LastName, DOB)",
},
)
self.validate_all(
"CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
identify=True,
write={
"spark": "CREATE TABLE `x` USING ICEBERG PARTITIONED BY (MONTHS(`y`)) LOCATION 's3://z'",
},
)

def test_to_date(self):
self.validate_all(
Expand Down
