diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 6e2d190757..4e9c954a5a 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -654,28 +654,6 @@ def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) ->
     return _time_format
 
 
-def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
-    """
-    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
-    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
-    columns are removed from the create statement.
-    """
-    has_schema = isinstance(expression.this, exp.Schema)
-    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
-
-    if has_schema and is_partitionable:
-        prop = expression.find(exp.PartitionedByProperty)
-        if prop and prop.this and not isinstance(prop.this, exp.Schema):
-            schema = expression.this
-            columns = {v.name.upper() for v in prop.this.expressions}
-            partitions = [col for col in schema.expressions if col.name.upper() in columns]
-            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
-            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
-            expression.set("this", schema)
-
-    return self.create_sql(expression)
-
-
 def parse_date_delta(
     exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
 ) -> t.Callable[[t.List], E]:
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index be23355f9e..409e260c47 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -5,7 +5,6 @@
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
-    create_with_partitions_sql,
     datestrtodate_sql,
     format_time_lambda,
     no_trycast_sql,
@@ -13,6 +12,7 @@
     str_position_sql,
     timestrtotime_sql,
 )
+from sqlglot.transforms import preprocess, move_schema_columns_to_partitioned_by
 
 
 def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]:
@@ -125,7 +125,7 @@ class Generator(generator.Generator):
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
             exp.ArraySize: rename_func("REPEATED_COUNT"),
-            exp.Create: create_with_partitions_sql,
+            exp.Create: preprocess([move_schema_columns_to_partitioned_by]),
             exp.DateAdd: _date_add_sql("ADD"),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateSub: _date_add_sql("SUB"),
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 6337ffd940..e575528600 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -9,7 +9,6 @@
     NormalizationStrategy,
     approx_count_distinct_sql,
     arg_max_or_min_no_count,
-    create_with_partitions_sql,
     datestrtodate_sql,
     format_time_lambda,
     if_sql,
@@ -32,6 +31,12 @@
     timestrtotime_sql,
     var_map_sql,
 )
+from sqlglot.transforms import (
+    remove_unique_constraints,
+    ctas_with_tmp_tables_to_create_tmp_view,
+    preprocess,
+    move_schema_columns_to_partitioned_by,
+)
 from sqlglot.helper import seq_get
 from sqlglot.parser import parse_var_map
 from sqlglot.tokens import TokenType
@@ -55,30 +60,6 @@
 DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
 
 
-def _create_sql(self, expression: exp.Create) -> str:
-    # remove UNIQUE column constraints
-    for constraint in expression.find_all(exp.UniqueColumnConstraint):
-        if constraint.parent:
-            constraint.parent.pop()
-
-    properties = expression.args.get("properties")
-    temporary = any(
-        isinstance(prop, exp.TemporaryProperty)
-        for prop in (properties.expressions if properties else [])
-    )
-
-    # CTAS with temp tables map to CREATE TEMPORARY VIEW
-    kind = expression.args["kind"]
-    if kind.upper() == "TABLE" and temporary:
-        if expression.expression:
-            return f"CREATE TEMPORARY VIEW {self.sql(expression, 'this')} AS {self.sql(expression, 'expression')}"
-        else:
-            # CREATE TEMPORARY TABLE may require storage provider
-            expression = self.temporary_storage_provider(expression)
-
-    return create_with_partitions_sql(self, expression)
-
-
 def _add_date_sql(self: Hive.Generator, expression: DATE_ADD_OR_SUB) -> str:
     if isinstance(expression, exp.TsOrDsAdd) and not expression.unit:
         return self.func("DATE_ADD", expression.this, expression.expression)
@@ -518,7 +499,13 @@ class Generator(generator.Generator):
                 "" if e.args.get("allow_null") else "NOT NULL"
             ),
             exp.VarMap: var_map_sql,
-            exp.Create: _create_sql,
+            exp.Create: preprocess(
+                [
+                    remove_unique_constraints,
+                    ctas_with_tmp_tables_to_create_tmp_view,
+                    move_schema_columns_to_partitioned_by,
+                ]
+            ),
             exp.Quantile: rename_func("PERCENTILE"),
             exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
             exp.RegexpExtract: regexp_extract_sql,
@@ -581,10 +568,6 @@ def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
 
             return super()._jsonpathkey_sql(expression)
 
-        def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
-            # Hive has no temporary storage provider (there are hive settings though)
-            return expression
-
         def parameter_sql(self, expression: exp.Parameter) -> str:
             this = self.sql(expression, "this")
             expression_sql = self.sql(expression, "expression")
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 44bd12d307..c662ab594d 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -5,8 +5,14 @@
 from sqlglot import exp
 from sqlglot.dialects.dialect import rename_func
 from sqlglot.dialects.hive import _parse_ignore_nulls
-from sqlglot.dialects.spark2 import Spark2
+from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
 from sqlglot.helper import seq_get
+from sqlglot.transforms import (
+    ctas_with_tmp_tables_to_create_tmp_view,
+    remove_unique_constraints,
+    preprocess,
+    move_partitioned_by_to_schema_columns,
+)
 
 
 def _parse_datediff(args: t.List) -> exp.Expression:
@@ -35,6 +41,15 @@ def _parse_datediff(args: t.List) -> exp.Expression:
     )
 
 
+def _normalize_partition(e: exp.Expression) -> exp.Expression:
+    """Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
+    if isinstance(e, str):
+        return exp.to_identifier(e)
+    if isinstance(e, exp.Literal):
+        return exp.to_identifier(e.name)
+    return e
+
+
 class Spark(Spark2):
     class Tokenizer(Spark2.Tokenizer):
         RAW_STRINGS = [
@@ -72,6 +87,17 @@ class Generator(Spark2.Generator):
 
         TRANSFORMS = {
             **Spark2.Generator.TRANSFORMS,
+            exp.Create: preprocess(
+                [
+                    remove_unique_constraints,
+                    lambda e: ctas_with_tmp_tables_to_create_tmp_view(
+                        e, temporary_storage_provider
+                    ),
+                    move_partitioned_by_to_schema_columns,
+                ]
+            ),
+            exp.PartitionedByProperty: lambda self,
+            e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
             exp.StartsWith: rename_func("STARTSWITH"),
             exp.TimestampAdd: lambda self, e: self.func(
                 "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 9378d9958f..72f82006ea 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -13,6 +13,12 @@
 )
 from sqlglot.dialects.hive import Hive
 from sqlglot.helper import seq_get
+from sqlglot.transforms import (
+    preprocess,
+    remove_unique_constraints,
+    ctas_with_tmp_tables_to_create_tmp_view,
+    move_schema_columns_to_partitioned_by,
+)
 
 
 def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
@@ -95,6 +101,13 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
     return expression
 
 
+def temporary_storage_provider(expression: exp.Expression) -> exp.Expression:
+    # spark2, spark, Databricks require a storage provider for temporary tables
+    provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
+    expression.args["properties"].append("expressions", provider)
+    return expression
+
+
 class Spark2(Hive):
     class Parser(Hive.Parser):
         TRIM_PATTERN_FIRST = True
@@ -193,6 +206,15 @@ class Generator(Hive.Generator):
             e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
             exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
             exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
+            exp.Create: preprocess(
+                [
+                    remove_unique_constraints,
+                    lambda e: ctas_with_tmp_tables_to_create_tmp_view(
+                        e, temporary_storage_provider
+                    ),
+                    move_schema_columns_to_partitioned_by,
+                ]
+            ),
             exp.DateFromParts: rename_func("MAKE_DATE"),
             exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
             exp.DayOfMonth: rename_func("DAYOFMONTH"),
@@ -251,12 +273,6 @@ def struct_sql(self, expression: exp.Struct) -> str:
 
             return self.func("STRUCT", *args)
 
-        def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
-            # spark2, spark, Databricks require a storage provider for temporary tables
-            provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
-            expression.args["properties"].append("expressions", provider)
-            return expression
-
         def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
             if is_parse_json(expression.this):
                 schema = f"'{self.sql(expression, 'to')}'"
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index a1d960d84c..151e77addd 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1090,6 +1090,11 @@ class Create(DDL):
         "clone": False,
     }
 
+    @property
+    def kind(self) -> t.Optional[str]:
+        kind = self.args.get("kind")
+        return kind and kind.upper()
+
 
 # https://docs.snowflake.com/en/sql-reference/sql/create-clone
 # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 5dc5d6e95c..caaa8acc7d 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -476,6 +476,87 @@ def unqualify_columns(expression: exp.Expression) -> exp.Expression:
     return expression
 
 
+def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
+    assert isinstance(expression, exp.Create)
+    for constraint in expression.find_all(exp.UniqueColumnConstraint):
+        if constraint.parent:
+            constraint.parent.pop()
+
+    return expression
+
+
+def ctas_with_tmp_tables_to_create_tmp_view(
+    expression: exp.Expression,
+    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
+) -> exp.Expression:
+    assert isinstance(expression, exp.Create)
+    properties = expression.args.get("properties")
+    temporary = any(
+        isinstance(prop, exp.TemporaryProperty)
+        for prop in (properties.expressions if properties else [])
+    )
+
+    # CTAS with temp tables map to CREATE TEMPORARY VIEW
+    if expression.kind == "TABLE" and temporary:
+        if expression.expression:
+            return exp.Create(
+                kind="TEMPORARY VIEW",
+                this=expression.this,
+                expression=expression.expression,
+            )
+        return tmp_storage_provider(expression)
+
+    return expression
+
+
+def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
+    """
+    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
+    PARTITIONED BY value is an array of column names, they are transformed into a schema.
+    The corresponding columns are removed from the create statement.
+    """
+    assert isinstance(expression, exp.Create)
+    has_schema = isinstance(expression.this, exp.Schema)
+    is_partitionable = expression.kind in {"TABLE", "VIEW"}
+
+    if has_schema and is_partitionable:
+        prop = expression.find(exp.PartitionedByProperty)
+        if prop and prop.this and not isinstance(prop.this, exp.Schema):
+            schema = expression.this
+            columns = {v.name.upper() for v in prop.this.expressions}
+            partitions = [col for col in schema.expressions if col.name.upper() in columns]
+            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
+            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
+            expression.set("this", schema)
+
+    return expression
+
+
+def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
+    """
+    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
+
+    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
+    """
+    assert isinstance(expression, exp.Create)
+    prop = expression.find(exp.PartitionedByProperty)
+    if (
+        prop
+        and prop.this
+        and isinstance(prop.this, exp.Schema)
+        and all(isinstance(e, exp.ColumnDef) and e.args.get("kind") for e in prop.this.expressions)
+    ):
+        prop_this = exp.Tuple(
+            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
+        )
+        schema = expression.this
+        for e in prop.this.expressions:
+            schema.append("expressions", e)
+        prop.set("this", prop_this)
+
+    return expression
+
+
 def preprocess(
     transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
 ) -> t.Callable[[Generator, exp.Expression], str]:
diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py
index d1b75890a8..ea28f2938e 100644
--- a/tests/dialects/test_hive.py
+++ b/tests/dialects/test_hive.py
@@ -152,7 +152,7 @@ def test_ddl(self):
                 "duckdb": "CREATE TABLE x (w TEXT)",  # Partition columns should exist in table
                 "presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])",
                 "hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
-                "spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
+                "spark": "CREATE TABLE x (w STRING, y INT, z INT) PARTITIONED BY (y, z)",
             },
         )
         self.validate_all(
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index 36006d2361..d3d1a7606b 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -424,7 +424,7 @@ def test_ddl(self):
                 "duckdb": "CREATE TABLE x (w TEXT, y INT, z INT)",
                 "presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])",
                 "hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
-                "spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
+                "spark": "CREATE TABLE x (w STRING, y INT, z INT) PARTITIONED BY (y, z)",
             },
         )
         self.validate_all(
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index 75bb91af4f..4d0362e3be 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -93,11 +93,12 @@ def test_ddl(self):
   'x'='1'
 )""",
                 "spark": """CREATE TABLE blah (
-  col_a INT
+  col_a INT,
+  date STRING
 )
 COMMENT 'Test comment: blah'
 PARTITIONED BY (
-  date STRING
+  date
 )
 USING ICEBERG
 TBLPROPERTIES (
@@ -125,13 +126,6 @@ def test_ddl(self):
                 "spark": "ALTER TABLE StudentInfo DROP COLUMNS (LastName, DOB)",
             },
         )
-        self.validate_all(
-            "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
-            identify=True,
-            write={
-                "spark": "CREATE TABLE `x` USING ICEBERG PARTITIONED BY (MONTHS(`y`)) LOCATION 's3://z'",
-            },
-        )
 
     def test_to_date(self):
         self.validate_all(
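
For reference, a round trip that exercises the two new partition transforms end to end. This is a sketch against this branch; the expected outputs are exactly the strings asserted in the updated test_hive.py and test_presto.py cases above.

import sqlglot

# Presto DDL keeps partition columns in the schema and names them in PARTITIONED_BY.
ddl = "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])"

# Hive: move_schema_columns_to_partitioned_by pulls y and z out of the schema,
# so their types move into PARTITIONED BY.
print(sqlglot.transpile(ddl, read="presto", write="hive")[0])
# CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)

# Spark 3: move_partitioned_by_to_schema_columns leaves y and z in the schema,
# and PARTITIONED BY only references them by name (the DATASOURCE format).
print(sqlglot.transpile(ddl, read="presto", write="spark")[0])
# CREATE TABLE x (w STRING, y INT, z INT) PARTITIONED BY (y, z)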
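Similarly, a sketch of the CTAS handling that moved into ctas_with_tmp_tables_to_create_tmp_view. The second output assumes temporary_storage_provider injects the parquet FileFormatProperty as defined in spark2.py above, which the Spark generator renders as USING PARQUET.

from sqlglot import parse_one

# CTAS on a temporary table becomes CREATE TEMPORARY VIEW.
print(parse_one("CREATE TEMPORARY TABLE t AS SELECT 1 AS c", read="spark").sql("spark"))
# CREATE TEMPORARY VIEW t AS SELECT 1 AS c

# A temporary table without a SELECT keeps its kind but gains the storage provider.
print(parse_one("CREATE TEMPORARY TABLE t (c INT)", read="spark").sql("spark"))
# CREATE TEMPORARY TABLE t (c INT) USING PARQUET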
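The preprocess helper is the extension point this refactor standardizes on: it copies the expression, threads the copy through the given transforms, and then falls through to the generator's stock method for that expression type. A minimal sketch for a hypothetical custom dialect (MyDialect is illustrative, not part of this change), mirroring the drill.py wiring above:

from sqlglot import exp, generator
from sqlglot.dialects.dialect import Dialect
from sqlglot.transforms import preprocess, remove_unique_constraints

class MyDialect(Dialect):
    class Generator(generator.Generator):
        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            # Rewrite the Create AST before the default create_sql runs.
            exp.Create: preprocess([remove_unique_constraints]),
        }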