From 1772387a710ff92886e3fe7f30732595f092bb6a Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Thu, 29 Aug 2024 10:22:31 +0200 Subject: [PATCH] fix: Fix issues 2972 and 3007 (#3020) Address two issues: #2972 and #3007: - Correctly handle renamed/deleted stage - Handle data type diff suppression better for text and number in the table resource References: #2972 #3007 --- MIGRATION_GUIDE.md | 24 ++++ pkg/acceptance/helpers/stage_client.go | 8 ++ pkg/internal/util/strings.go | 12 ++ pkg/internal/util/strings_test.go | 34 +++++ pkg/resources/helpers.go | 37 +++++- pkg/resources/helpers_test.go | 145 +++++++++++++++++++++ pkg/resources/stage.go | 20 +-- pkg/resources/stage_acceptance_test.go | 51 ++++++++ pkg/resources/table.go | 2 +- pkg/resources/table_acceptance_test.go | 169 +++++++++++++++++++++++++ pkg/sdk/data_types.go | 116 +++++++++++++---- pkg/sdk/data_types_test.go | 76 +++++++++++ 12 files changed, 656 insertions(+), 38 deletions(-) create mode 100644 pkg/internal/util/strings.go create mode 100644 pkg/internal/util/strings_test.go diff --git a/MIGRATION_GUIDE.md b/MIGRATION_GUIDE.md index 10279af5fe..8f9dbb46c3 100644 --- a/MIGRATION_GUIDE.md +++ b/MIGRATION_GUIDE.md @@ -107,6 +107,30 @@ Some of the resources are excluded from this change: #### *(breaking change)* removed `qualified_name` from `snowflake_masking_policy`, `snowflake_network_rule`, `snowflake_password_policy` and `snowflake_table` Because of introducing a new `fully_qualified_name` field for all of the resources, `qualified_name` was removed from `snowflake_masking_policy`, `snowflake_network_rule`, `snowflake_password_policy` and `snowflake_table`. Please adjust your configurations. State is automatically migrated. +### snowflake_stage resource changes + +#### *(bugfix)* Correctly handle renamed/deleted stage + +Correctly handle the situation when stage was rename/deleted externally (earlier it resulted in a permanent loop). No action is required on the user's side. + +Connected issues: [#2972](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2972) + +### snowflake_table resource changes + +#### *(bugfix)* Handle data type diff suppression better for text and number + +Data types are not entirely correctly handled inside the provider (read more e.g. in [#2735](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2735)). It will be still improved with the upcoming function, procedure, and table rework. Currently, diff suppression was fixed for text and number data types in the table resource with the following assumptions/limitations: +- for numbers the default precision is 38 and the default scale is 0 (following the [docs](https://docs.snowflake.com/en/sql-reference/data-types-numeric#number)) +- for number types the following types are treated as synonyms: `NUMBER`, `DECIMAL`, `NUMERIC`, `INT`, `INTEGER`, `BIGINT`, `SMALLINT`, `TINYINT`, `BYTEINT` +- for text the default length is 16777216 (following the [docs](https://docs.snowflake.com/en/sql-reference/data-types-text#varchar)) +- for text types the following types are treated as synonyms: `VARCHAR`, `CHAR`, `CHARACTER`, `STRING`, `TEXT` +- whitespace and casing is ignored +- if the type arguments cannot be parsed the defaults are used and therefore diff may be suppressed unexpectedly (please report such cases) + +No action is required on the user's side. + +Connected issues: [#3007](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/3007) + ### snowflake_user resource changes #### *(breaking change)* user parameters added to snowflake_user resource diff --git a/pkg/acceptance/helpers/stage_client.go b/pkg/acceptance/helpers/stage_client.go index 5505237709..529d2f9552 100644 --- a/pkg/acceptance/helpers/stage_client.go +++ b/pkg/acceptance/helpers/stage_client.go @@ -128,3 +128,11 @@ func (c *StageClient) CopyIntoTableFromFile(t *testing.T, table, stage sdk.Schem MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE`, table.FullyQualifiedName(), stage.FullyQualifiedName(), filename)) require.NoError(t, err) } + +func (c *StageClient) Rename(t *testing.T, id sdk.SchemaObjectIdentifier, newId sdk.SchemaObjectIdentifier) { + t.Helper() + ctx := context.Background() + + err := c.client().Alter(ctx, sdk.NewAlterStageRequest(id).WithRenameTo(&newId)) + require.NoError(t, err) +} diff --git a/pkg/internal/util/strings.go b/pkg/internal/util/strings.go new file mode 100644 index 0000000000..2e8aadf014 --- /dev/null +++ b/pkg/internal/util/strings.go @@ -0,0 +1,12 @@ +package util + +import "strings" + +// TrimAllPrefixes removes all prefixes from the input. Order matters. +func TrimAllPrefixes(text string, prefixes ...string) string { + result := text + for _, prefix := range prefixes { + result = strings.TrimPrefix(result, prefix) + } + return result +} diff --git a/pkg/internal/util/strings_test.go b/pkg/internal/util/strings_test.go new file mode 100644 index 0000000000..ab2560b65d --- /dev/null +++ b/pkg/internal/util/strings_test.go @@ -0,0 +1,34 @@ +package util + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_TrimAllPrefixes(t *testing.T) { + type test struct { + input string + prefixes []string + expected string + } + + tests := []test{ + {input: "VARCHAR(30)", prefixes: []string{"VARCHAR", "TEXT"}, expected: "(30)"}, + {input: "VARCHAR (30) ", prefixes: []string{"VARCHAR", "TEXT"}, expected: " (30) "}, + {input: "VARCHAR(30)", prefixes: []string{"VARCHAR"}, expected: "(30)"}, + {input: "VARCHAR(30)", prefixes: []string{}, expected: "VARCHAR(30)"}, + {input: "VARCHARVARCHAR(30)", prefixes: []string{"VARCHAR"}, expected: "VARCHAR(30)"}, + {input: "VARCHAR(30)", prefixes: []string{"NUMBER"}, expected: "VARCHAR(30)"}, + {input: "VARCHARTEXT(30)", prefixes: []string{"VARCHAR", "TEXT"}, expected: "(30)"}, + {input: "TEXTVARCHAR(30)", prefixes: []string{"VARCHAR", "TEXT"}, expected: "VARCHAR(30)"}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.input, func(t *testing.T) { + output := TrimAllPrefixes(tc.input, tc.prefixes...) + require.Equal(t, tc.expected, output) + }) + } +} diff --git a/pkg/resources/helpers.go b/pkg/resources/helpers.go index 7b91689d48..f88ae30bd2 100644 --- a/pkg/resources/helpers.go +++ b/pkg/resources/helpers.go @@ -5,10 +5,10 @@ import ( "slices" "strings" - "github.com/hashicorp/terraform-plugin-sdk/v2/diag" - + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) @@ -35,6 +35,39 @@ func dataTypeDiffSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { return oldDT == newDT } +// DataTypeIssue3007DiffSuppressFunc is a temporary solution to handle data type suppression problems. +// Currently, it handles only number and text data types. +// It falls back to Snowflake defaults for arguments if no arguments were provided for the data type. +// TODO [SNOW-1348103 or SNOW-1348106]: visit with functions and procedures rework +func DataTypeIssue3007DiffSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { + oldDataType, err := sdk.ToDataType(old) + if err != nil { + return false + } + newDataType, err := sdk.ToDataType(new) + if err != nil { + return false + } + if oldDataType != newDataType { + return false + } + switch v := oldDataType; v { + case sdk.DataTypeNumber: + logging.DebugLogger.Printf("[DEBUG] DataTypeIssue3007DiffSuppressFunc: Handling number data type diff suppression") + oldPrecision, oldScale := sdk.ParseNumberDataTypeRaw(old) + newPrecision, newScale := sdk.ParseNumberDataTypeRaw(new) + return oldPrecision == newPrecision && oldScale == newScale + case sdk.DataTypeVARCHAR: + logging.DebugLogger.Printf("[DEBUG] DataTypeIssue3007DiffSuppressFunc: Handling text data type diff suppression") + oldLength := sdk.ParseVarcharDataTypeRaw(old) + newLength := sdk.ParseVarcharDataTypeRaw(new) + return oldLength == newLength + default: + logging.DebugLogger.Printf("[DEBUG] DataTypeIssue3007DiffSuppressFunc: Diff suppression for %s can't be currently handled", v) + } + return true +} + func ignoreTrimSpaceSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { return strings.TrimSpace(old) == strings.TrimSpace(new) } diff --git a/pkg/resources/helpers_test.go b/pkg/resources/helpers_test.go index 6447fa8e93..c143d40f03 100644 --- a/pkg/resources/helpers_test.go +++ b/pkg/resources/helpers_test.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-testing/terraform" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_GetPropertyAsPointer(t *testing.T) { @@ -258,3 +259,147 @@ func TestListDiff(t *testing.T) { }) } } + +func Test_DataTypeIssue3007DiffSuppressFunc(t *testing.T) { + testCases := []struct { + name string + old string + new string + expected bool + }{ + { + name: "different data type", + old: string(sdk.DataTypeVARCHAR), + new: string(sdk.DataTypeNumber), + expected: false, + }, + { + name: "same number data type without arguments", + old: string(sdk.DataTypeNumber), + new: string(sdk.DataTypeNumber), + expected: true, + }, + { + name: "same number data type different casing", + old: string(sdk.DataTypeNumber), + new: "number", + expected: true, + }, + { + name: "same text data type without arguments", + old: string(sdk.DataTypeVARCHAR), + new: string(sdk.DataTypeVARCHAR), + expected: true, + }, + { + name: "same other data type", + old: string(sdk.DataTypeFloat), + new: string(sdk.DataTypeFloat), + expected: true, + }, + { + name: "synonym number data type without arguments", + old: string(sdk.DataTypeNumber), + new: "DECIMAL", + expected: true, + }, + { + name: "synonym text data type without arguments", + old: string(sdk.DataTypeVARCHAR), + new: "TEXT", + expected: true, + }, + { + name: "synonym other data type without arguments", + old: string(sdk.DataTypeFloat), + new: "DOUBLE", + expected: true, + }, + { + name: "synonym number data type same precision, no scale", + old: "NUMBER(30)", + new: "DECIMAL(30)", + expected: true, + }, + { + name: "synonym number data type precision implicit and same", + old: "NUMBER", + new: fmt.Sprintf("DECIMAL(%d)", sdk.DefaultNumberPrecision), + expected: true, + }, + { + name: "synonym number data type precision implicit and different", + old: "NUMBER", + new: "DECIMAL(30)", + expected: false, + }, + { + name: "number data type different precisions, no scale", + old: "NUMBER(35)", + new: "NUMBER(30)", + expected: false, + }, + { + name: "synonym number data type same precision, different scale", + old: "NUMBER(30, 2)", + new: "DECIMAL(30, 1)", + expected: false, + }, + { + name: "synonym number data type default scale implicit and explicit", + old: "NUMBER(30)", + new: fmt.Sprintf("DECIMAL(30, %d)", sdk.DefaultNumberScale), + expected: true, + }, + { + name: "synonym number data type default scale implicit and different", + old: "NUMBER(30)", + new: "DECIMAL(30, 3)", + expected: false, + }, + { + name: "synonym number data type both precision and scale implicit and explicit", + old: "NUMBER", + new: fmt.Sprintf("DECIMAL(%d, %d)", sdk.DefaultNumberPrecision, sdk.DefaultNumberScale), + expected: true, + }, + { + name: "synonym number data type both precision and scale implicit and scale different", + old: "NUMBER", + new: fmt.Sprintf("DECIMAL(%d, 2)", sdk.DefaultNumberPrecision), + expected: false, + }, + { + name: "synonym text data type same length", + old: "VARCHAR(30)", + new: "TEXT(30)", + expected: true, + }, + { + name: "synonym text data type different length", + old: "VARCHAR(30)", + new: "TEXT(40)", + expected: false, + }, + { + name: "synonym text data type length implicit and same", + old: "VARCHAR", + new: fmt.Sprintf("TEXT(%d)", sdk.DefaultVarcharLength), + expected: true, + }, + { + name: "synonym text data type length implicit and different", + old: "VARCHAR", + new: "TEXT(40)", + expected: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + result := resources.DataTypeIssue3007DiffSuppressFunc("", tc.old, tc.new, nil) + require.Equal(t, tc.expected, result) + }) + } +} diff --git a/pkg/resources/stage.go b/pkg/resources/stage.go index d102bd2a71..47a1747524 100644 --- a/pkg/resources/stage.go +++ b/pkg/resources/stage.go @@ -2,6 +2,7 @@ package resources import ( "context" + "errors" "fmt" "strings" @@ -174,14 +175,17 @@ func ReadStage(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagn properties, err := client.Stages.Describe(ctx, id) if err != nil { - d.SetId("") - return diag.Diagnostics{ - diag.Diagnostic{ - Severity: diag.Error, - Summary: "Failed to describe stage", - Detail: fmt.Sprintf("Id: %s, Err: %s", d.Id(), err), - }, + if errors.Is(err, sdk.ErrObjectNotExistOrAuthorized) { + d.SetId("") + return diag.Diagnostics{ + diag.Diagnostic{ + Severity: diag.Warning, + Summary: "Failed to describe stage. Marking the resource as removed.", + Detail: fmt.Sprintf("Stage: %s, Err: %s", id.FullyQualifiedName(), err), + }, + } } + return diag.FromErr(err) } stage, err := client.Stages.ShowByID(ctx, id) @@ -191,7 +195,7 @@ func ReadStage(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagn diag.Diagnostic{ Severity: diag.Error, Summary: "Failed to show stage by id", - Detail: fmt.Sprintf("Id: %s, Err: %s", d.Id(), err), + Detail: fmt.Sprintf("Stage: %s, Err: %s", id.FullyQualifiedName(), err), }, } } diff --git a/pkg/resources/stage_acceptance_test.go b/pkg/resources/stage_acceptance_test.go index 7275987e56..e5489fdff5 100644 --- a/pkg/resources/stage_acceptance_test.go +++ b/pkg/resources/stage_acceptance_test.go @@ -204,3 +204,54 @@ resource "snowflake_stage" "test" { } `, name, siNameSuffix, url, databaseName, schemaName) } + +func TestAcc_Stage_Issue2972(t *testing.T) { + stageId := acc.TestClient().Ids.RandomSchemaObjectIdentifier() + newId := acc.TestClient().Ids.RandomSchemaObjectIdentifierInSchema(stageId.SchemaId()) + resourceName := "snowflake_stage.test" + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: acc.CheckDestroy(t, resources.Stage), + Steps: []resource.TestStep{ + { + Config: stageIssue2972Config(stageId), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "database", stageId.DatabaseName()), + resource.TestCheckResourceAttr(resourceName, "schema", stageId.SchemaName()), + resource.TestCheckResourceAttr(resourceName, "name", stageId.Name()), + ), + }, + { + PreConfig: func() { + acc.TestClient().Stage.Rename(t, stageId, newId) + }, + ConfigPlanChecks: resource.ConfigPlanChecks{ + PreApply: []plancheck.PlanCheck{ + plancheck.ExpectResourceAction(resourceName, plancheck.ResourceActionCreate), + }, + }, + Config: stageIssue2972Config(stageId), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "database", stageId.DatabaseName()), + resource.TestCheckResourceAttr(resourceName, "schema", stageId.SchemaName()), + resource.TestCheckResourceAttr(resourceName, "name", stageId.Name()), + ), + }, + }, + }) +} + +func stageIssue2972Config(stageId sdk.SchemaObjectIdentifier) string { + return fmt.Sprintf(` +resource "snowflake_stage" "test" { + name = "%[1]s" + database = "%[2]s" + schema = "%[3]s" +} +`, stageId.Name(), stageId.DatabaseName(), stageId.SchemaName()) +} diff --git a/pkg/resources/table.go b/pkg/resources/table.go index 13280eb9fc..ad10836b74 100644 --- a/pkg/resources/table.go +++ b/pkg/resources/table.go @@ -62,7 +62,7 @@ var tableSchema = map[string]*schema.Schema{ Required: true, Description: "Column type, e.g. VARIANT. For a full list of column types, see [Summary of Data Types](https://docs.snowflake.com/en/sql-reference/intro-summary-data-types).", ValidateFunc: dataTypeValidateFunc, - DiffSuppressFunc: dataTypeDiffSuppressFunc, + DiffSuppressFunc: DataTypeIssue3007DiffSuppressFunc, }, "nullable": { Type: schema.TypeBool, diff --git a/pkg/resources/table_acceptance_test.go b/pkg/resources/table_acceptance_test.go index 49d8bc461b..a855fb2a49 100644 --- a/pkg/resources/table_acceptance_test.go +++ b/pkg/resources/table_acceptance_test.go @@ -9,7 +9,9 @@ import ( acc "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance" r "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/resources" + tfjson "github.com/hashicorp/terraform-json" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/planchecks" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" @@ -2090,3 +2092,170 @@ resource "snowflake_table" "test_table" { } `, name, databaseName, schemaName) } + +func TestAcc_Table_issue3007_textColumn(t *testing.T) { + tableId := acc.TestClient().Ids.RandomSchemaObjectIdentifier() + resourceName := "snowflake_table.test_table" + + defaultVarchar := fmt.Sprintf("VARCHAR(%d)", sdk.DefaultVarcharLength) + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + Steps: []resource.TestStep{ + { + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + Config: tableConfigIssue3007(tableId, "VARCHAR(3)"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "column.0.type", "NUMBER(11,2)"), + resource.TestCheckResourceAttr(resourceName, "column.1.type", "VARCHAR(3)"), + ), + }, + { + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + Config: tableConfigIssue3007(tableId, "VARCHAR(256)"), + ConfigPlanChecks: resource.ConfigPlanChecks{ + PreApply: []plancheck.PlanCheck{ + planchecks.ExpectChange(resourceName, "column.1.type", tfjson.ActionUpdate, sdk.String("VARCHAR(3)"), sdk.String("VARCHAR(256)")), + }, + }, + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "column.1.type", "VARCHAR(256)"), + ), + }, + { + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + Config: tableConfigIssue3007(tableId, "VARCHAR"), + ConfigPlanChecks: resource.ConfigPlanChecks{ + PreApply: []plancheck.PlanCheck{ + planchecks.ExpectChange(resourceName, "column.1.type", tfjson.ActionUpdate, sdk.String("VARCHAR(256)"), sdk.String("VARCHAR")), + }, + }, + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "column.1.type", defaultVarchar), + ), + }, + { + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + Config: tableConfigIssue3007(tableId, defaultVarchar), + ConfigPlanChecks: resource.ConfigPlanChecks{ + PreApply: []plancheck.PlanCheck{ + plancheck.ExpectResourceAction(resourceName, plancheck.ResourceActionNoop), + }, + }, + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "column.1.type", defaultVarchar), + ), + }, + { + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + Config: tableConfigIssue3007(tableId, "text"), + ConfigPlanChecks: resource.ConfigPlanChecks{ + PreApply: []plancheck.PlanCheck{ + plancheck.ExpectResourceAction(resourceName, plancheck.ResourceActionNoop), + }, + }, + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "column.1.type", defaultVarchar), + ), + }, + }, + }) +} + +// TODO [SNOW-1348114]: visit with table rework (e.g. changing scale is not supported: err 040052 (22000): SQL compilation error: cannot change column SOME_COLUMN from type NUMBER(38,0) to NUMBER(11,2) because changing the scale of a number is not supported.) +func TestAcc_Table_issue3007_numberColumn(t *testing.T) { + tableId := acc.TestClient().Ids.RandomSchemaObjectIdentifier() + resourceName := "snowflake_table.test_table" + + defaultNumber := fmt.Sprintf("NUMBER(%d,%d)", sdk.DefaultNumberPrecision, sdk.DefaultNumberScale) + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + Steps: []resource.TestStep{ + { + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + Config: tableConfigIssue3007(tableId, "NUMBER"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "column.0.type", "NUMBER(11,2)"), + resource.TestCheckResourceAttr(resourceName, "column.1.type", "NUMBER(38,0)"), + ), + }, + { + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + Config: tableConfigIssue3007(tableId, "NUMBER(11)"), + ConfigPlanChecks: resource.ConfigPlanChecks{ + PreApply: []plancheck.PlanCheck{ + planchecks.ExpectChange(resourceName, "column.1.type", tfjson.ActionUpdate, sdk.String("NUMBER(38,0)"), sdk.String("NUMBER(11)")), + }, + }, + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "column.0.type", "NUMBER(11,2)"), + resource.TestCheckResourceAttr(resourceName, "column.1.type", "NUMBER(11,0)"), + ), + }, + { + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + Config: tableConfigIssue3007(tableId, "NUMBER"), + ConfigPlanChecks: resource.ConfigPlanChecks{ + PreApply: []plancheck.PlanCheck{ + planchecks.ExpectChange(resourceName, "column.1.type", tfjson.ActionUpdate, sdk.String("NUMBER(11,0)"), sdk.String("NUMBER")), + }, + }, + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "column.1.type", defaultNumber), + ), + }, + { + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + Config: tableConfigIssue3007(tableId, defaultNumber), + ConfigPlanChecks: resource.ConfigPlanChecks{ + PreApply: []plancheck.PlanCheck{ + plancheck.ExpectResourceAction(resourceName, plancheck.ResourceActionNoop), + }, + }, + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "column.1.type", defaultNumber), + ), + }, + { + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + Config: tableConfigIssue3007(tableId, "decimal"), + ConfigPlanChecks: resource.ConfigPlanChecks{ + PreApply: []plancheck.PlanCheck{ + plancheck.ExpectResourceAction(resourceName, plancheck.ResourceActionNoop), + }, + }, + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "column.1.type", defaultNumber), + ), + }, + }, + }) +} + +func tableConfigIssue3007(tableId sdk.SchemaObjectIdentifier, dataType string) string { + return fmt.Sprintf(` +resource "snowflake_table" "test_table" { + name = "%[1]s" + database = "%[2]s" + schema = "%[3]s" + comment = "Issue 3007 confirmation" + + column { + name = "ID" + type = "NUMBER(11,2)" + } + + column { + name = "SOME_COLUMN" + type = "%[4]s" + } +} +`, tableId.Name(), tableId.DatabaseName(), tableId.SchemaName(), dataType) +} diff --git a/pkg/sdk/data_types.go b/pkg/sdk/data_types.go index 9a8c46c07f..40ced4cb83 100644 --- a/pkg/sdk/data_types.go +++ b/pkg/sdk/data_types.go @@ -3,7 +3,11 @@ package sdk import ( "fmt" "slices" + "strconv" "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/util" ) // DataType is based on https://docs.snowflake.com/en/sql-reference/intro-summary-data-types. @@ -35,6 +39,25 @@ const ( DataTypeGeometry DataType = "GEOMETRY" ) +var ( + DataTypeNumberSynonyms = []string{"NUMBER", "DECIMAL", "NUMERIC", "INT", "INTEGER", "BIGINT", "SMALLINT", "TINYINT", "BYTEINT"} + DataTypeFloatSynonyms = []string{"FLOAT", "FLOAT4", "FLOAT8", "DOUBLE", "DOUBLE PRECISION", "REAL"} + DataTypeVarcharSynonyms = []string{"VARCHAR", "CHAR", "CHARACTER", "STRING", "TEXT"} + DataTypeBinarySynonyms = []string{"BINARY", "VARBINARY"} + DataTypeBooleanSynonyms = []string{"BOOLEAN", "BOOL"} + DataTypeTimestampLTZSynonyms = []string{"TIMESTAMP_LTZ"} + DataTypeTimestampTZSynonyms = []string{"TIMESTAMP_TZ"} + DataTypeTimestampNTZSynonyms = []string{"DATETIME", "TIMESTAMP", "TIMESTAMP_NTZ"} + DataTypeTimeSynonyms = []string{"TIME"} + DataTypeVectorSynonyms = []string{"VECTOR"} +) + +const ( + DefaultNumberPrecision = 38 + DefaultNumberScale = 0 + DefaultVarcharLength = 16777216 +) + func ToDataType(s string) (DataType, error) { dType := strings.ToUpper(s) @@ -53,53 +76,36 @@ func ToDataType(s string) (DataType, error) { return DataTypeGeometry, nil } - numberSynonyms := []string{"NUMBER", "DECIMAL", "NUMERIC", "INT", "INTEGER", "BIGINT", "SMALLINT", "TINYINT", "BYTEINT"} - if slices.ContainsFunc(numberSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { + if slices.ContainsFunc(DataTypeNumberSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { return DataTypeNumber, nil } - - floatSynonyms := []string{"FLOAT", "FLOAT4", "FLOAT8", "DOUBLE", "DOUBLE PRECISION", "REAL"} - if slices.ContainsFunc(floatSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { + if slices.ContainsFunc(DataTypeFloatSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { return DataTypeFloat, nil } - varcharSynonyms := []string{"VARCHAR", "CHAR", "CHARACTER", "STRING", "TEXT"} - if slices.ContainsFunc(varcharSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { + if slices.ContainsFunc(DataTypeVarcharSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { return DataTypeVARCHAR, nil } - binarySynonyms := []string{"BINARY", "VARBINARY"} - if slices.ContainsFunc(binarySynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { + if slices.ContainsFunc(DataTypeBinarySynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { return DataTypeBinary, nil } - booleanSynonyms := []string{"BOOLEAN", "BOOL"} - if slices.Contains(booleanSynonyms, dType) { + if slices.Contains(DataTypeBooleanSynonyms, dType) { return DataTypeBoolean, nil } - - timestampLTZSynonyms := []string{"TIMESTAMP_LTZ"} - if slices.ContainsFunc(timestampLTZSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { + if slices.ContainsFunc(DataTypeTimestampLTZSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { return DataTypeTimestampLTZ, nil } - - timestampTZSynonyms := []string{"TIMESTAMP_TZ"} - if slices.ContainsFunc(timestampTZSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { + if slices.ContainsFunc(DataTypeTimestampTZSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { return DataTypeTimestampTZ, nil } - - timestampNTZSynonyms := []string{"DATETIME", "TIMESTAMP", "TIMESTAMP_NTZ"} - if slices.ContainsFunc(timestampNTZSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { + if slices.ContainsFunc(DataTypeTimestampNTZSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { return DataTypeTimestampNTZ, nil } - - timeSynonyms := []string{"TIME"} - if slices.ContainsFunc(timeSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { + if slices.ContainsFunc(DataTypeTimeSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { return DataTypeTime, nil } - - vectorSynonyms := []string{"VECTOR"} - if slices.ContainsFunc(vectorSynonyms, func(e string) bool { return strings.HasPrefix(dType, e) }) { + if slices.ContainsFunc(DataTypeVectorSynonyms, func(e string) bool { return strings.HasPrefix(dType, e) }) { return DataType(dType), nil } - return "", fmt.Errorf("invalid data type: %s", s) } @@ -112,3 +118,59 @@ func IsStringType(_type string) bool { strings.HasPrefix(t, "NVARCHAR") || strings.HasPrefix(t, "NCHAR") } + +// ParseNumberDataTypeRaw extracts precision and scale from the raw number data type input. +// It returns defaults if it can't parse arguments, data type is different, or no arguments were provided. +// TODO [SNOW-1348103 or SNOW-1348106]: visit with functions and procedures rework +func ParseNumberDataTypeRaw(rawDataType string) (int, int) { + r := util.TrimAllPrefixes(strings.TrimSpace(strings.ToUpper(rawDataType)), DataTypeNumberSynonyms...) + r = strings.TrimSpace(r) + if strings.HasPrefix(r, "(") && strings.HasSuffix(r, ")") { + parts := strings.Split(r[1:len(r)-1], ",") + switch l := len(parts); l { + case 1: + precision, err := strconv.Atoi(strings.TrimSpace(parts[0])) + if err == nil { + return precision, DefaultNumberScale + } else { + logging.DebugLogger.Printf(`[DEBUG] Could not parse number precision "%s", err: %v`, parts[0], err) + } + case 2: + precision, err1 := strconv.Atoi(strings.TrimSpace(parts[0])) + scale, err2 := strconv.Atoi(strings.TrimSpace(parts[1])) + if err1 == nil && err2 == nil { + return precision, scale + } else { + logging.DebugLogger.Printf(`[DEBUG] Could not parse number precision "%s" or scale "%s", errs: %v, %v`, parts[0], parts[1], err1, err2) + } + default: + logging.DebugLogger.Printf("[DEBUG] Unexpected length of number arguments") + } + } + logging.DebugLogger.Printf("[DEBUG] Returning default number precision and scale") + return DefaultNumberPrecision, DefaultNumberScale +} + +// ParseVarcharDataTypeRaw extracts length from the raw text data type input. +// It returns default if it can't parse arguments, data type is different, or no length argument was provided. +// TODO [SNOW-1348103 or SNOW-1348106]: visit with functions and procedures rework +func ParseVarcharDataTypeRaw(rawDataType string) int { + r := util.TrimAllPrefixes(strings.TrimSpace(strings.ToUpper(rawDataType)), DataTypeVarcharSynonyms...) + r = strings.TrimSpace(r) + if strings.HasPrefix(r, "(") && strings.HasSuffix(r, ")") { + parts := strings.Split(r[1:len(r)-1], ",") + switch l := len(parts); l { + case 1: + length, err := strconv.Atoi(strings.TrimSpace(parts[0])) + if err == nil { + return length + } else { + logging.DebugLogger.Printf(`[DEBUG] Could not parse varchar length "%s", err: %v`, parts[0], err) + } + default: + logging.DebugLogger.Printf("[DEBUG] Unexpected length of varchar arguments") + } + } + logging.DebugLogger.Printf("[DEBUG] Returning default varchar length") + return DefaultVarcharLength +} diff --git a/pkg/sdk/data_types_test.go b/pkg/sdk/data_types_test.go index 17abeca4c4..156e5dc8f2 100644 --- a/pkg/sdk/data_types_test.go +++ b/pkg/sdk/data_types_test.go @@ -1,8 +1,10 @@ package sdk import ( + "fmt" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -140,3 +142,77 @@ func TestIsStringType(t *testing.T) { }) } } + +func Test_ParseNumberDataTypeRaw(t *testing.T) { + type test struct { + input string + expectedPrecision int + expectedScale int + } + defaults := func(input string) test { + return test{input: input, expectedPrecision: DefaultNumberPrecision, expectedScale: DefaultNumberScale} + } + + tests := []test{ + {input: "NUMBER(30)", expectedPrecision: 30, expectedScale: DefaultNumberScale}, + {input: "NUMBER(30, 2)", expectedPrecision: 30, expectedScale: 2}, + {input: "decimal(30, 2)", expectedPrecision: 30, expectedScale: 2}, + {input: "NUMBER( 30 , 2 )", expectedPrecision: 30, expectedScale: 2}, + {input: " NUMBER ( 30 , 2 ) ", expectedPrecision: 30, expectedScale: 2}, + + // returns defaults if it can't parse arguments, data type is different, or no arguments were provided + defaults("VARCHAR(1, 2)"), + defaults("VARCHAR(1)"), + defaults("VARCHAR"), + defaults("NUMBER"), + defaults("NUMBER()"), + defaults("NUMBER(x)"), + defaults(fmt.Sprintf("NUMBER(%d)", DefaultNumberPrecision)), + defaults(fmt.Sprintf("NUMBER(%d, x)", DefaultNumberPrecision)), + defaults(fmt.Sprintf("NUMBER(x, %d)", DefaultNumberScale)), + defaults("NUMBER(1, 2, 3)"), + } + + for _, tc := range tests { + tc := tc + t.Run(tc.input, func(t *testing.T) { + precision, scale := ParseNumberDataTypeRaw(tc.input) + assert.Equal(t, tc.expectedPrecision, precision) + assert.Equal(t, tc.expectedScale, scale) + }) + } +} + +func Test_ParseVarcharDataTypeRaw(t *testing.T) { + type test struct { + input string + expectedLength int + } + defaults := func(input string) test { + return test{input: input, expectedLength: DefaultVarcharLength} + } + + tests := []test{ + {input: "VARCHAR(30)", expectedLength: 30}, + {input: "text(30)", expectedLength: 30}, + {input: "VARCHAR( 30 )", expectedLength: 30}, + {input: " VARCHAR ( 30 ) ", expectedLength: 30}, + + // returns defaults if it can't parse arguments, data type is different, or no arguments were provided + defaults("VARCHAR(1, 2)"), + defaults("VARCHAR(x)"), + defaults("VARCHAR"), + defaults("NUMBER"), + defaults("NUMBER()"), + defaults("NUMBER(x)"), + defaults(fmt.Sprintf("VARCHAR(%d)", DefaultVarcharLength)), + } + + for _, tc := range tests { + tc := tc + t.Run(tc.input, func(t *testing.T) { + length := ParseVarcharDataTypeRaw(tc.input) + assert.Equal(t, tc.expectedLength, length) + }) + } +}