From 74239c231d65d5290ab9bd95e354c3276f4c5041 Mon Sep 17 00:00:00 2001 From: Taher Lakdawala <78196491+taherkl@users.noreply.github.com> Date: Wed, 27 Nov 2024 18:09:04 +0530 Subject: [PATCH] Check constraint backend (#9) Backend Support for Check Constraint --- internal/convert.go | 1 + internal/helpers.go | 5 + internal/mapping.go | 13 ++ internal/reports/report_helpers.go | 7 + schema/schema.go | 26 ++- sources/common/infoschema.go | 23 +-- sources/common/toddl.go | 32 +++- sources/common/toddl_test.go | 29 +++ sources/dynamodb/schema.go | 6 +- sources/dynamodb/schema_test.go | 4 +- sources/mysql/infoschema.go | 116 ++++++++---- sources/mysql/infoschema_test.go | 116 ++++++++++++ sources/oracle/infoschema.go | 6 +- sources/postgres/infoschema.go | 6 +- sources/spanner/infoschema.go | 8 +- sources/sqlserver/infoschema.go | 6 +- spanner/ddl/ast.go | 55 ++++-- spanner/ddl/ast_test.go | 93 +++++++++- webv2/api/schema.go | 85 +++++++++ webv2/api/schema_test.go | 274 +++++++++++++++++++++++++++++ webv2/routes.go | 2 + webv2/table/review_table_schema.go | 9 +- webv2/table/update_table_schema.go | 10 ++ 23 files changed, 831 insertions(+), 101 deletions(-) diff --git a/internal/convert.go b/internal/convert.go index 413691260..1f54f8611 100644 --- a/internal/convert.go +++ b/internal/convert.go @@ -132,6 +132,7 @@ const ( ForeignKeyActionNotSupported NumericPKNotSupported DefaultValueError + TypeMismatch ) const ( diff --git a/internal/helpers.go b/internal/helpers.go index 621b36365..3a7968646 100644 --- a/internal/helpers.go +++ b/internal/helpers.go @@ -65,6 +65,11 @@ func GenerateForeignkeyId() string { func GenerateIndexesId() string { return GenerateId("i") } + +func GenerateCheckConstrainstId() string { + return GenerateId("ck") +} + func GenerateRuleId() string { return GenerateId("r") } diff --git a/internal/mapping.go b/internal/mapping.go index 0eb9bd055..93f9f7f37 100644 --- a/internal/mapping.go +++ b/internal/mapping.go @@ -243,6 +243,19 @@ func ToSpannerIndexName(conv *Conv, srcIndexName string) string { return getSpannerValidName(conv, srcIndexName) } +// Note that the check constraints names in spanner have to be globally unique +// (across the database). But in some source databases, such as MySQL, +// they only have to be unique for a table. Hence we must map each source +// constraint name to a unique spanner constraint name. +func ToSpannerCheckConstraintName(conv *Conv, srcCheckConstraintName string) string { + return getSpannerValidName(conv, srcCheckConstraintName) +} + +func GetSpannerValidExpression(cks []ddl.CheckConstraint) []ddl.CheckConstraint { + // TODO validate the check constraints data with batch verification then send back + return cks +} + // conv.UsedNames tracks Spanner names that have been used for table names, foreign key constraints // and indexes. We use this to ensure we generate unique names when // we map from source dbs to Spanner since Spanner requires all these names to be diff --git a/internal/reports/report_helpers.go b/internal/reports/report_helpers.go index 38df589af..c4f3084bd 100644 --- a/internal/reports/report_helpers.go +++ b/internal/reports/report_helpers.go @@ -409,6 +409,13 @@ func buildTableReportBody(conv *internal.Conv, tableId string, issues map[string Description: fmt.Sprintf("%s for table '%s' column '%s'", IssueDB[i].Brief, conv.SpSchema[tableId].Name, spColName), } l = append(l, toAppend) + case internal.TypeMismatch: + toAppend := Issue{ + Category: IssueDB[i].Category, + Description: fmt.Sprintf("Table '%s': Type mismatch in '%s'column affecting check constraints. Verify data type compatibility with constraint logic", conv.SpSchema[tableId].Name, conv.SpSchema[tableId].ColDefs[colId].Name), + } + l = append(l, toAppend) + default: toAppend := Issue{ Category: IssueDB[i].Category, diff --git a/schema/schema.go b/schema/schema.go index 7d73cd799..eab021cf0 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -35,15 +35,16 @@ import ( // Table represents a database table. type Table struct { - Name string - Schema string - ColIds []string // List of column Ids (for predictable iteration order e.g. printing). - ColDefs map[string]Column // Details of columns. - ColNameIdMap map[string]string `json:"-"` // Computed every time just after conv is generated or after any column renaming - PrimaryKeys []Key - ForeignKeys []ForeignKey - Indexes []Index - Id string + Name string + Schema string + ColIds []string // List of column Ids (for predictable iteration order e.g. printing). + ColDefs map[string]Column // Details of columns. + ColNameIdMap map[string]string `json:"-"` // Computed every time just after conv is generated or after any column renaming + PrimaryKeys []Key + ForeignKeys []ForeignKey + CheckConstraints []CheckConstraint + Indexes []Index + Id string } // Column represents a database column. @@ -77,6 +78,13 @@ type ForeignKey struct { Id string } +// CheckConstraints represents a check constraint defined in the schema. +type CheckConstraint struct { + Name string + Expr string + Id string +} + // Key respresents a primary key or index key. type Key struct { ColId string diff --git a/sources/common/infoschema.go b/sources/common/infoschema.go index b4f9c9e7c..ae43ea4a0 100644 --- a/sources/common/infoschema.go +++ b/sources/common/infoschema.go @@ -38,7 +38,7 @@ type InfoSchema interface { GetColumns(conv *internal.Conv, table SchemaAndName, constraints map[string][]string, primaryKeys []string) (map[string]schema.Column, []string, error) GetRowsFromTable(conv *internal.Conv, srcTable string) (interface{}, error) GetRowCount(table SchemaAndName) (int64, error) - GetConstraints(conv *internal.Conv, table SchemaAndName) ([]string, map[string][]string, error) + GetConstraints(conv *internal.Conv, table SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) GetForeignKeys(conv *internal.Conv, table SchemaAndName) (foreignKeys []schema.ForeignKey, err error) GetIndexes(conv *internal.Conv, table SchemaAndName, colNameIdMp map[string]string) ([]schema.Index, error) ProcessData(conv *internal.Conv, tableId string, srcSchema schema.Table, spCols []string, spSchema ddl.CreateTable, additionalAttributes internal.AdditionalDataAttributes) error @@ -187,7 +187,7 @@ func (is *InfoSchemaImpl) processTable(conv *internal.Conv, table SchemaAndName, var t schema.Table fmt.Println("processing schema for table", table) tblId := internal.GenerateTableId() - primaryKeys, constraints, err := infoSchema.GetConstraints(conv, table) + primaryKeys, checkConstraints, constraints, err := infoSchema.GetConstraints(conv, table) if err != nil { return t, fmt.Errorf("couldn't get constraints for table %s.%s: %s", table.Schema, table.Name, err) } @@ -217,15 +217,16 @@ func (is *InfoSchemaImpl) processTable(conv *internal.Conv, table SchemaAndName, schemaPKeys = append(schemaPKeys, schema.Key{ColId: colNameIdMap[k]}) } t = schema.Table{ - Id: tblId, - Name: name, - Schema: table.Schema, - ColIds: colIds, - ColNameIdMap: colNameIdMap, - ColDefs: colDefs, - PrimaryKeys: schemaPKeys, - Indexes: indexes, - ForeignKeys: foreignKeys} + Id: tblId, + Name: name, + Schema: table.Schema, + ColIds: colIds, + ColNameIdMap: colNameIdMap, + ColDefs: colDefs, + PrimaryKeys: schemaPKeys, + CheckConstraints: checkConstraints, + Indexes: indexes, + ForeignKeys: foreignKeys} return t, nil } diff --git a/sources/common/toddl.go b/sources/common/toddl.go index 1b706274e..a361a80f9 100644 --- a/sources/common/toddl.go +++ b/sources/common/toddl.go @@ -167,14 +167,15 @@ func (ss *SchemaToSpannerImpl) SchemaToSpannerDDLHelper(conv *internal.Conv, tod } comment := "Spanner schema for source table " + quoteIfNeeded(srcTable.Name) conv.SpSchema[srcTable.Id] = ddl.CreateTable{ - Name: spTableName, - ColIds: spColIds, - ColDefs: spColDef, - PrimaryKeys: cvtPrimaryKeys(srcTable.PrimaryKeys), - ForeignKeys: cvtForeignKeys(conv, spTableName, srcTable.Id, srcTable.ForeignKeys, isRestore), - Indexes: cvtIndexes(conv, srcTable.Id, srcTable.Indexes, spColIds, spColDef), - Comment: comment, - Id: srcTable.Id} + Name: spTableName, + ColIds: spColIds, + ColDefs: spColDef, + PrimaryKeys: cvtPrimaryKeys(srcTable.PrimaryKeys), + ForeignKeys: cvtForeignKeys(conv, spTableName, srcTable.Id, srcTable.ForeignKeys, isRestore), + CheckConstraints: cvtCheckConstraint(conv, srcTable.CheckConstraints), + Indexes: cvtIndexes(conv, srcTable.Id, srcTable.Indexes, spColIds, spColDef), + Comment: comment, + Id: srcTable.Id} return nil } @@ -234,6 +235,21 @@ func cvtForeignKeys(conv *internal.Conv, spTableName string, srcTableId string, return spKeys } +func cvtCheckConstraint(conv *internal.Conv, srcKeys []schema.CheckConstraint) []ddl.CheckConstraint { + var spcks []ddl.CheckConstraint + + for _, cks := range srcKeys { + spcks = append(spcks, ddl.CheckConstraint{ + Id: cks.Id, + Name: internal.ToSpannerCheckConstraintName(conv, cks.Name), + Expr: cks.Expr, + }) + + } + + return internal.GetSpannerValidExpression(spcks) +} + func CvtForeignKeysHelper(conv *internal.Conv, spTableName string, srcTableId string, srcKey schema.ForeignKey, isRestore bool) (ddl.Foreignkey, error) { if len(srcKey.ColIds) != len(srcKey.ReferColumnIds) { conv.Unexpected(fmt.Sprintf("ConvertForeignKeys: ColIds and referColumns don't have the same lengths: len(columns)=%d, len(referColumns)=%d for source tableId: %s, referenced table: %s", len(srcKey.ColIds), len(srcKey.ReferColumnIds), srcTableId, srcKey.ReferTableId)) diff --git a/sources/common/toddl_test.go b/sources/common/toddl_test.go index dcf5b3651..3886be6b7 100644 --- a/sources/common/toddl_test.go +++ b/sources/common/toddl_test.go @@ -536,4 +536,33 @@ func TestSpannerSchemaApplyExpressions(t *testing.T) { assert.Equal(t, tc.expectedConv, tc.conv) }) } +func Test_cvtCheckContraint(t *testing.T) { + + conv := internal.MakeConv() + srcSchema := []schema.CheckConstraint{ + { + Id: "ck1", + Name: "check_1", + Expr: "age > 0", + }, + { + Id: "ck1", + Name: "check_2", + Expr: "age < 99", + }, + } + spSchema := []ddl.CheckConstraint{ + { + Id: "ck1", + Name: "check_1", + Expr: "age > 0", + }, + { + Id: "ck1", + Name: "check_2", + Expr: "age < 99", + }, + } + result := cvtCheckConstraint(conv, srcSchema) + assert.Equal(t, spSchema, result) } diff --git a/sources/dynamodb/schema.go b/sources/dynamodb/schema.go index 4bc38f0ea..4af603988 100644 --- a/sources/dynamodb/schema.go +++ b/sources/dynamodb/schema.go @@ -129,20 +129,20 @@ func (isi InfoSchemaImpl) GetRowCount(table common.SchemaAndName) (int64, error) return *result.Table.ItemCount, err } -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) (primaryKeys []string, constraints map[string][]string, err error) { +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) (primaryKeys []string, checkConstraints []schema.CheckConstraint, constraints map[string][]string, err error) { input := &dynamodb.DescribeTableInput{ TableName: aws.String(table.Name), } result, err := isi.DynamoClient.DescribeTable(input) if err != nil { - return primaryKeys, constraints, fmt.Errorf("failed to make a DescribeTable API call for table %v: %v", table.Name, err) + return primaryKeys, checkConstraints, constraints, fmt.Errorf("failed to make a DescribeTable API call for table %v: %v", table.Name, err) } // Primary keys. for _, i := range result.Table.KeySchema { primaryKeys = append(primaryKeys, *i.AttributeName) } - return primaryKeys, constraints, nil + return primaryKeys, checkConstraints, constraints, nil } func (isi InfoSchemaImpl) GetForeignKeys(conv *internal.Conv, table common.SchemaAndName) (foreignKeys []schema.ForeignKey, err error) { diff --git a/sources/dynamodb/schema_test.go b/sources/dynamodb/schema_test.go index b9d10ab3a..9919d3d5a 100644 --- a/sources/dynamodb/schema_test.go +++ b/sources/dynamodb/schema_test.go @@ -633,7 +633,7 @@ func TestInfoSchemaImpl_GetConstraints(t *testing.T) { dySchema := common.SchemaAndName{Name: "test"} conv := internal.MakeConv() isi := InfoSchemaImpl{client, nil, 10} - primaryKeys, constraints, err := isi.GetConstraints(conv, dySchema) + primaryKeys, _, constraints, err := isi.GetConstraints(conv, dySchema) assert.Nil(t, err) pKeys := []string{"a", "b"} @@ -705,7 +705,7 @@ func TestInfoSchemaImpl_GetColumns(t *testing.T) { client := &mockDynamoClient{ scanOutputs: scanOutputs, } - dySchema := common.SchemaAndName{Name: "test", Id: "t1"} + dySchema := common.SchemaAndName{Name: "test", Id: "t1"} isi := InfoSchemaImpl{client, nil, 10} diff --git a/sources/mysql/infoschema.go b/sources/mysql/infoschema.go index 0cb4ad7c8..682fa7bb3 100644 --- a/sources/mysql/infoschema.go +++ b/sources/mysql/infoschema.go @@ -194,16 +194,6 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd continue } ignored := schema.Ignored{} - for _, c := range constraints[colName] { - // c can be UNIQUE, PRIMARY KEY, FOREIGN KEY or CHECK - // We've already filtered out PRIMARY KEY. - switch c { - case "CHECK": - ignored.Check = true - case "FOREIGN KEY", "PRIMARY KEY", "UNIQUE": - // Nothing to do here -- these are all handled elsewhere. - } - } ignored.Default = colDefault.Valid colId := internal.GenerateColumnId() if colExtra.String == "auto_increment" { @@ -250,38 +240,98 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // other constraints. Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, map[string][]string, error) { - q := `SELECT k.COLUMN_NAME, t.CONSTRAINT_TYPE - FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t - INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k - ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA AND t.TABLE_NAME=k.TABLE_NAME - WHERE k.TABLE_SCHEMA = ? AND k.TABLE_NAME = ? ORDER BY k.ordinal_position;` - rows, err := isi.Db.Query(q, table.Schema, table.Name) +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) { + tableExists, err := isi.isCheckConstraintsTablePresent() if err != nil { - return nil, nil, err + return nil, nil, nil, err + } + + finalQuery := isi.getQuery(tableExists) + rows, err := isi.Db.Query(finalQuery, table.Schema, table.Name) + if err != nil { + return nil, nil, nil, err } defer rows.Close() + var primaryKeys []string - var col, constraint string + var checkKeys []schema.CheckConstraint m := make(map[string][]string) + for rows.Next() { - err := rows.Scan(&col, &constraint) - if err != nil { - conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) - continue - } - if col == "" || constraint == "" { - conv.Unexpected(fmt.Sprintf("Got empty col or constraint")) + if err := isi.processRow(rows, tableExists, conv, &primaryKeys, &checkKeys, m); err != nil { continue } - switch constraint { - case "PRIMARY KEY": - primaryKeys = append(primaryKeys, col) - default: - m[col] = append(m[col], constraint) - } } - return primaryKeys, m, nil + + return primaryKeys, checkKeys, m, nil +} + +// checkCheckConstraintsTableExists checks if the CHECK_CONSTRAINTS table exists. +func (isi InfoSchemaImpl) isCheckConstraintsTablePresent() (bool, error) { + var tableExistsCount int + checkQuery := `SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';` + err := isi.Db.QueryRow(checkQuery).Scan(&tableExistsCount) + if err != nil { + return false, err + } + return tableExistsCount > 0, nil +} + +// getQuery returns the appropriate SQL query based on the existence of CHECK_CONSTRAINTS. +func (isi InfoSchemaImpl) getQuery(tableExists bool) string { + if tableExists { + return `SELECT k.COLUMN_NAME, t.CONSTRAINT_TYPE, COALESCE(c.CHECK_CLAUSE, '') AS CHECK_CLAUSE + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t + LEFT JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k + ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME + AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA + AND t.TABLE_NAME = k.TABLE_NAME + LEFT JOIN INFORMATION_SCHEMA.CHECK_CONSTRAINTS AS c + ON t.CONSTRAINT_NAME = c.CONSTRAINT_NAME + WHERE t.TABLE_SCHEMA = ? + AND t.TABLE_NAME = ? + ORDER BY k.ORDINAL_POSITION;` + } + return `SELECT k.COLUMN_NAME, t.CONSTRAINT_TYPE + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t + INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k + ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME + AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA + AND t.TABLE_NAME = k.TABLE_NAME + WHERE t.TABLE_SCHEMA = ? + AND t.TABLE_NAME = ? + ORDER BY k.ORDINAL_POSITION;` +} + +// processRow handles scanning and processing of a database row for GetConstraints. +func (isi InfoSchemaImpl) processRow( + rows *sql.Rows, tableExists bool, conv *internal.Conv, primaryKeys *[]string, + checkKeys *[]schema.CheckConstraint, m map[string][]string) error { + + var col, constraintType, checkClause string + err := rows.Scan(&col, &constraintType, &checkClause) + if err != nil { + conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) + return err + } + + if col == "" || constraintType == "" { + conv.Unexpected("Got empty column or constraint type") + return nil + } + + switch constraintType { + case "PRIMARY KEY": + *primaryKeys = append(*primaryKeys, col) + case "CHECK": + checkClause = strings.ReplaceAll(checkClause, "_utf8mb4\\", "") + checkClause = strings.ReplaceAll(checkClause, "\\", "") + constraintName := fmt.Sprintf("%s_check", col) + *checkKeys = append(*checkKeys, schema.CheckConstraint{Name: constraintName, Expr: checkClause, Id: internal.GenerateCheckConstrainstId()}) + default: + m[col] = append(m[col], constraintType) + } + return nil } // GetForeignKeys return list all the foreign keys constraints. diff --git a/sources/mysql/infoschema_test.go b/sources/mysql/infoschema_test.go index 37836f382..c5d4b72e5 100644 --- a/sources/mysql/infoschema_test.go +++ b/sources/mysql/infoschema_test.go @@ -17,6 +17,7 @@ package mysql import ( "database/sql" "database/sql/driver" + "fmt" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -530,3 +531,118 @@ func mkMockDB(t *testing.T, ms []mockSpec) *sql.DB { } return db } + +func TestGetConstraints(t *testing.T) { + + case1 := []mockSpec{ + { + query: `SELECT COUNT\(\*\) + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' + AND TABLE_NAME = 'CHECK_CONSTRAINTS'; + `, + cols: []string{"COUNT"}, + rows: [][]driver.Value{{1}}, + }, + { + query: `(?i)SELECT\s+COALESCE\(k.COLUMN_NAME,\s*''\)\s+AS\s+COLUMN_NAME,\s+t\.CONSTRAINT_NAME,\s+t\.CONSTRAINT_TYPE,\s+COALESCE\(c.CHECK_CLAUSE,\s*''\)\s+AS\s+CHECK_CLAUSE\s+FROM\s+INFORMATION_SCHEMA\.TABLE_CONSTRAINTS\s+AS\s+t\s+LEFT\s+JOIN\s+INFORMATION_SCHEMA\.KEY_COLUMN_USAGE\s+AS\s+k\s+ON\s+t\.CONSTRAINT_NAME\s*=\s*k\.CONSTRAINT_NAME\s+AND\s+t\.CONSTRAINT_SCHEMA\s*=\s*k\.CONSTRAINT_SCHEMA\s+AND\s+t\.TABLE_NAME\s*=\s*k\.TABLE_NAME\s+LEFT\s+JOIN\s+INFORMATION_SCHEMA\.CHECK_CONSTRAINTS\s+AS\s+c\s+ON\s+t\.CONSTRAINT_NAME\s*=\s*c\.CONSTRAINT_NAME\s+WHERE\s+t\.TABLE_SCHEMA\s*=\s*\?\s+AND\s+t\.TABLE_NAME\s*=\s*\?\s*ORDER\s+BY\s+k\.ORDINAL_POSITION`, + args: []driver.Value{"test_schema", "test_table"}, + cols: []string{"COLUMN_NAME", "CONSTRAINT_NAME", "CONSTRAINT_TYPE", "CHECK_CLAUSE"}, + rows: [][]driver.Value{{"id", "PRIMARY", "PRIMARY KEY", ""}, {"", "chk_test", "CHECK", "amount > 0"}}, + }, + } + + case2 := []mockSpec{ + { + query: `SELECT COUNT\(\*\) + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' + AND TABLE_NAME = 'CHECK_CONSTRAINTS'; + `, + cols: []string{"COUNT"}, + rows: [][]driver.Value{{0}}, + }, + { + query: `(?i)SELECT\s+k\.COLUMN_NAME,\s+t\.CONSTRAINT_TYPE\s+FROM\s+INFORMATION_SCHEMA\.TABLE_CONSTRAINTS\s+AS\s+t\s+INNER\s+JOIN\s+INFORMATION_SCHEMA\.KEY_COLUMN_USAGE\s+AS\s+k\s+ON\s+t\.CONSTRAINT_NAME\s*=\s*k\.CONSTRAINT_NAME\s+AND\s+t\.CONSTRAINT_SCHEMA\s*=\s*k\.CONSTRAINT_SCHEMA\s+AND\s+t\.TABLE_NAME\s*=\s*k\.TABLE_NAME\s+WHERE\s+k\.TABLE_SCHEMA\s*=\s*\?\s+AND\s+k\.TABLE_NAME\s*=\s*\?\s*ORDER\s+BY\s+k\.ORDINAL_POSITION;`, + args: []driver.Value{"test_schema", "test_table"}, + cols: []string{"COLUMN_NAME", "CONSTRAINT_NAME", "CONSTRAINT_TYPE", "CHECK_CLAUSE"}, + rows: [][]driver.Value{{"id", "PRIMARY", "PRIMARY KEY", ""}}, + }, + } + + cases := []struct { + db []mockSpec + tableExists bool + }{ + { + db: case1, + tableExists: true, + }, + { + db: case2, + tableExists: false, + }, + } + + for _, tc := range cases { + if tc.tableExists { + db := mkMockDB(t, tc.db) + + defer db.Close() + + isi := InfoSchemaImpl{Db: db} + + table := common.SchemaAndName{ + Schema: "test_schema", + Name: "test_table", + } + + conv := new(internal.Conv) + + primaryKeys, checkKeys, constraints, err := isi.GetConstraints(conv, table) + if err != nil { + t.Fatalf("expected no error, but got %v", err) + } + + expectedPrimaryKeys := []string{"id"} + if fmt.Sprintf("%v", primaryKeys) != fmt.Sprintf("%v", expectedPrimaryKeys) { + t.Errorf("expected %v, got %v for primary keys", expectedPrimaryKeys, primaryKeys) + } + + expectedCheckKeys := []schema.CheckConstraint{ + {Name: "chk_test", Expr: "amount > 0", Id: "ck1"}, + } + + assert.Equal(t, expectedCheckKeys, checkKeys) + assert.Equal(t, expectedPrimaryKeys, primaryKeys) + assert.Empty(t, constraints) + } else { + db := mkMockDB(t, tc.db) + + defer db.Close() + + isi := InfoSchemaImpl{Db: db} + + table := common.SchemaAndName{ + Schema: "test_schema", + Name: "test_table", + } + + conv := new(internal.Conv) + + primaryKeys, checkKeys, constraints, err := isi.GetConstraints(conv, table) + if err != nil { + t.Fatalf("expected no error, but got %v", err) + } + + expectedPrimaryKeys := []string{"id"} + if fmt.Sprintf("%v", primaryKeys) != fmt.Sprintf("%v", expectedPrimaryKeys) { + t.Errorf("expected %v, got %v for primary keys", expectedPrimaryKeys, primaryKeys) + } + + assert.Equal(t, expectedPrimaryKeys, primaryKeys) + assert.Empty(t, checkKeys) + assert.Empty(t, constraints) + } + } +} diff --git a/sources/oracle/infoschema.go b/sources/oracle/infoschema.go index f62bfb270..770ef93cc 100644 --- a/sources/oracle/infoschema.go +++ b/sources/oracle/infoschema.go @@ -247,7 +247,7 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // other constraints. Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, map[string][]string, error) { +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) { q := fmt.Sprintf(` SELECT k.column_name, @@ -260,7 +260,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem `, table.Schema, table.Name) rows, err := isi.Db.Query(q) if err != nil { - return nil, nil, err + return nil, nil, nil, err } defer rows.Close() var primaryKeys []string @@ -292,7 +292,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem m[col] = append(m[col], constraint) } } - return primaryKeys, m, nil + return primaryKeys, nil, m, nil } // GetForeignKeys return list all the foreign keys constraints. diff --git a/sources/postgres/infoschema.go b/sources/postgres/infoschema.go index ccf7b8dcf..e4d278b0b 100644 --- a/sources/postgres/infoschema.go +++ b/sources/postgres/infoschema.go @@ -337,7 +337,7 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // other constraints. Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, map[string][]string, error) { +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) { q := `SELECT k.COLUMN_NAME, t.CONSTRAINT_TYPE FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k @@ -345,7 +345,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem WHERE k.TABLE_SCHEMA = $1 AND k.TABLE_NAME = $2 ORDER BY k.ordinal_position;` rows, err := isi.Db.Query(q, table.Schema, table.Name) if err != nil { - return nil, nil, err + return nil, nil, nil, err } defer rows.Close() var primaryKeys []string @@ -368,7 +368,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem m[col] = append(m[col], constraint) } } - return primaryKeys, m, nil + return primaryKeys, nil, m, nil } // GetForeignKeys returns a list of all the foreign key constraints. diff --git a/sources/spanner/infoschema.go b/sources/spanner/infoschema.go index ca08985bc..c1bf36b4d 100644 --- a/sources/spanner/infoschema.go +++ b/sources/spanner/infoschema.go @@ -190,7 +190,7 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // other constraints. Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, map[string][]string, error) { +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) { q := `SELECT k.column_name, t.constraint_type FROM information_schema.table_constraints AS t INNER JOIN information_schema.KEY_COLUMN_USAGE AS k @@ -221,11 +221,11 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem break } if err != nil { - return nil, nil, fmt.Errorf("couldn't get row while reading constraints: %w", err) + return nil, nil, nil, fmt.Errorf("couldn't get row while reading constraints: %w", err) } err = row.Columns(&col, &constraint) if err != nil { - return nil, nil, err + return nil, nil, nil, err } if col == "" || constraint == "" { conv.Unexpected("Got empty col or constraint") @@ -238,7 +238,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem m[col] = append(m[col], constraint) } } - return primaryKeys, m, nil + return primaryKeys, nil, m, nil } // GetForeignKeys returns a list of all the foreign key constraints. diff --git a/sources/sqlserver/infoschema.go b/sources/sqlserver/infoschema.go index 50f46d157..b9c98c544 100644 --- a/sources/sqlserver/infoschema.go +++ b/sources/sqlserver/infoschema.go @@ -280,7 +280,7 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // other constraints. Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, map[string][]string, error) { +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) { q := ` SELECT k.COLUMN_NAME, @@ -292,7 +292,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem ` rows, err := isi.Db.Query(q, table.Schema, table.Name) if err != nil { - return nil, nil, err + return nil, nil, nil, err } defer rows.Close() var primaryKeys []string @@ -315,7 +315,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem m[col] = append(m[col], constraint) } } - return primaryKeys, m, nil + return primaryKeys, nil, m, nil } // GetForeignKeys returns a list of all the foreign key constraints. diff --git a/spanner/ddl/ast.go b/spanner/ddl/ast.go index 91646281a..aad0c90c8 100644 --- a/spanner/ddl/ast.go +++ b/spanner/ddl/ast.go @@ -265,6 +265,12 @@ type IndexKey struct { Order int } +type CheckConstraint struct { + Id string + Name string + Expr string +} + // PrintPkOrIndexKey unparses the primary or index keys. func (idx IndexKey) PrintPkOrIndexKey(ct CreateTable, c Config) string { col := c.quote(ct.ColDefs[idx.ColId].Name) @@ -319,16 +325,17 @@ func (k Foreignkey) PrintForeignKey(c Config) string { // // create_table: CREATE TABLE table_name ([column_def, ...] ) primary_key [, cluster] type CreateTable struct { - Name string - ColIds []string // Provides names and order of columns - ShardIdColumn string - ColDefs map[string]ColumnDef // Provides definition of columns (a map for simpler/faster lookup during type processing) - PrimaryKeys []IndexKey - ForeignKeys []Foreignkey - Indexes []CreateIndex - ParentTable InterleavedParent //if not empty, this table will be interleaved - Comment string - Id string + Name string + ColIds []string // Provides names and order of columns + ShardIdColumn string + ColDefs map[string]ColumnDef // Provides definition of columns (a map for simpler/faster lookup during type processing) + PrimaryKeys []IndexKey + ForeignKeys []Foreignkey + Indexes []CreateIndex + ParentTable InterleavedParent //if not empty, this table will be interleaved + CheckConstraints []CheckConstraint + Comment string + Id string } // PrintCreateTable unparses a CREATE TABLE statement. @@ -382,13 +389,20 @@ func (ct CreateTable) PrintCreateTable(spSchema Schema, config Config) string { } } + var checkString string + if len(ct.CheckConstraints) > 0 { + checkString = PrintCheckConstraintTable(ct.CheckConstraints) + } else { + checkString = "" + } + if len(keys) == 0 { - return fmt.Sprintf("%sCREATE TABLE %s (\n%s) %s", tableComment, config.quote(ct.Name), cols, interleave) + return fmt.Sprintf("%sCREATE TABLE %s (\n%s%s) %s", tableComment, config.quote(ct.Name), cols, checkString, interleave) } if config.SpDialect == constants.DIALECT_POSTGRESQL { return fmt.Sprintf("%sCREATE TABLE %s (\n%s\tPRIMARY KEY (%s)\n)%s", tableComment, config.quote(ct.Name), cols, strings.Join(keys, ", "), interleave) } - return fmt.Sprintf("%sCREATE TABLE %s (\n%s) PRIMARY KEY (%s)%s", tableComment, config.quote(ct.Name), cols, strings.Join(keys, ", "), interleave) + return fmt.Sprintf("%sCREATE TABLE %s (\n%s%s) PRIMARY KEY (%s)%s", tableComment, config.quote(ct.Name), cols, checkString, strings.Join(keys, ", "), interleave) } // CreateIndex encodes the following DDL definition: @@ -534,6 +548,23 @@ func (k Foreignkey) PrintForeignKeyAlterTable(spannerSchema Schema, c Config, ta return s } +// PrintCheckConstraintTable formats the check constraints in SQL syntax. +func PrintCheckConstraintTable(cks []CheckConstraint) string { + var builder strings.Builder + + for _, col := range cks { + builder.WriteString(fmt.Sprintf("\tCONSTRAINT %s CHECK %s,\n", col.Name, col.Expr)) + } + + if builder.Len() > 0 { + // Trim the trailing comma and newline + result := builder.String() + return result[:len(result)-2] + "\n" + } + + return "" +} + // Schema stores a map of table names and Tables. type Schema map[string]CreateTable diff --git a/spanner/ddl/ast_test.go b/spanner/ddl/ast_test.go index b9e6a510e..329af1362 100644 --- a/spanner/ddl/ast_test.go +++ b/spanner/ddl/ast_test.go @@ -143,6 +143,10 @@ func TestPrintCreateTable(t *testing.T) { }, PrimaryKeys: []IndexKey{{ColId: "col1", Desc: true}}, ForeignKeys: nil, + CheckConstraints: []CheckConstraint{ + {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, + {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, + }, Indexes: nil, ParentTable: InterleavedParent{}, Comment: "", @@ -156,12 +160,13 @@ func TestPrintCreateTable(t *testing.T) { "col4": {Name: "col4", T: Type{Name: Int64}, NotNull: true}, "col5": {Name: "col5", T: Type{Name: String, Len: MaxLength}, NotNull: false}, }, - PrimaryKeys: []IndexKey{{ColId: "col4", Desc: true}}, - ForeignKeys: nil, - Indexes: nil, - ParentTable: InterleavedParent{Id: "t1", OnDelete: constants.FK_CASCADE}, - Comment: "", - Id: "t2", + PrimaryKeys: []IndexKey{{ColId: "col4", Desc: true}}, + ForeignKeys: nil, + Indexes: nil, + CheckConstraints: nil, + ParentTable: InterleavedParent{Id: "t1", OnDelete: constants.FK_CASCADE}, + Comment: "", + Id: "t2", }, "t3": CreateTable{ Name: "table3", @@ -170,12 +175,33 @@ func TestPrintCreateTable(t *testing.T) { ColDefs: map[string]ColumnDef{ "col6": {Name: "col6", T: Type{Name: Int64}, NotNull: true}, }, - PrimaryKeys: []IndexKey{{ColId: "col6", Desc: true}}, + PrimaryKeys: []IndexKey{{ColId: "col6", Desc: true}}, + ForeignKeys: nil, + Indexes: nil, + CheckConstraints: nil, + ParentTable: InterleavedParent{Id: "t1", OnDelete: ""}, + Comment: "", + Id: "t3", + }, + "t4": CreateTable{ + Name: "table1", + ColIds: []string{"col1", "col2", "col3"}, + ShardIdColumn: "", + ColDefs: map[string]ColumnDef{ + "col1": {Name: "col1", T: Type{Name: Int64}, NotNull: true}, + "col2": {Name: "col2", T: Type{Name: String, Len: MaxLength}, NotNull: false}, + "col3": {Name: "col3", T: Type{Name: Bytes, Len: int64(42)}, NotNull: false}, + }, + PrimaryKeys: nil, ForeignKeys: nil, + CheckConstraints: []CheckConstraint{ + {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, + {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, + }, Indexes: nil, - ParentTable: InterleavedParent{Id: "t1", OnDelete: ""}, + ParentTable: InterleavedParent{}, Comment: "", - Id: "t3", + Id: "t1", }, } tests := []struct { @@ -192,6 +218,7 @@ func TestPrintCreateTable(t *testing.T) { " col1 INT64 NOT NULL ,\n" + " col2 STRING(MAX),\n" + " col3 BYTES(42),\n" + + "\tCONSTRAINT check_1 CHECK (age > 18),\n\tCONSTRAINT check_2 CHECK (age < 99)\n" + ") PRIMARY KEY (col1 DESC)", }, { @@ -202,6 +229,7 @@ func TestPrintCreateTable(t *testing.T) { " `col1` INT64 NOT NULL ,\n" + " `col2` STRING(MAX),\n" + " `col3` BYTES(42),\n" + + "\tCONSTRAINT check_1 CHECK (age > 18),\n\tCONSTRAINT check_2 CHECK (age < 99)\n" + ") PRIMARY KEY (`col1` DESC)", }, { @@ -223,6 +251,17 @@ func TestPrintCreateTable(t *testing.T) { ") PRIMARY KEY (col6 DESC),\n" + "INTERLEAVE IN PARENT table1", }, + { + "no quote", + false, + s["t4"], + "CREATE TABLE table1 (\n" + + " col1 INT64 NOT NULL ,\n" + + " col2 STRING(MAX),\n" + + " col3 BYTES(42),\n" + + "\tCONSTRAINT check_1 CHECK (age > 18),\n\tCONSTRAINT check_2 CHECK (age < 99)\n" + + ") ", + }, } for _, tc := range tests { assert.Equal(t, tc.expected, tc.ct.PrintCreateTable(s, Config{ProtectIds: tc.protectIds})) @@ -1040,3 +1079,39 @@ func TestGetSortedTableIdsBySpName(t *testing.T) { }) } } + +func TestPrintCheckConstraintTable(t *testing.T) { + tests := []struct { + description string + cks []CheckConstraint + expected string + }{ + { + description: "Empty constraints list", + cks: []CheckConstraint{}, + expected: "", + }, + { + description: "Single constraint", + cks: []CheckConstraint{ + {Name: "ck1", Expr: "(id > 0)"}, + }, + expected: "\tCONSTRAINT ck1 CHECK (id > 0)\n", + }, + { + description: "Multiple constraints", + cks: []CheckConstraint{ + {Name: "ck1", Expr: "(id > 0)"}, + {Name: "ck2", Expr: "(name IS NOT NULL)"}, + }, + expected: "\tCONSTRAINT ck1 CHECK (id > 0),\n\tCONSTRAINT ck2 CHECK (name IS NOT NULL)\n", + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + actual := PrintCheckConstraintTable(tc.cks) + assert.Equal(t, tc.expected, actual) + }) + } +} diff --git a/webv2/api/schema.go b/webv2/api/schema.go index 18cd3d4ff..49975f5d8 100644 --- a/webv2/api/schema.go +++ b/webv2/api/schema.go @@ -488,6 +488,91 @@ func RestoreSecondaryIndex(w http.ResponseWriter, r *http.Request) { } +// UpdateCheckConstraint processes the request to update spanner table check constraints, ensuring session and schema validity, and responds with the updated conversion metadata. +func UpdateCheckConstraint(w http.ResponseWriter, r *http.Request) { + tableId := r.FormValue("table") + reqBody, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) + } + sessionState := session.GetSessionState() + if sessionState.Conv == nil || sessionState.Driver == "" { + http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner."), http.StatusNotFound) + return + } + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + + newCKs := []ddl.CheckConstraint{} + if err = json.Unmarshal(reqBody, &newCKs); err != nil { + http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) + return + } + + sp := sessionState.Conv.SpSchema[tableId] + sp.CheckConstraints = newCKs + sessionState.Conv.SpSchema[tableId] = sp + session.UpdateSessionFile() + + convm := session.ConvWithMetadata{ + SessionMetadata: sessionState.SessionMetadata, + Conv: *sessionState.Conv, + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(convm) +} + +func doesNameExist(spcks []ddl.CheckConstraint, targetName string) bool { + for _, spck := range spcks { + if strings.Contains(spck.Expr, targetName) { + return true + } + } + return false +} + +// ValidateCheckConstraint verifies if the type of a database column has been altered and add an error if a change is detected. +func ValidateCheckConstraint(w http.ResponseWriter, r *http.Request) { + sessionState := session.GetSessionState() + if sessionState.Conv == nil || sessionState.Driver == "" { + http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner."), http.StatusNotFound) + return + } + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + + sp := sessionState.Conv.SpSchema + srcschema := sessionState.Conv.SrcSchema + flag := true + var schemaIssue []internal.SchemaIssue + + for _, src := range srcschema { + for _, col := range sp[src.Id].ColDefs { + if len(sp[src.Id].CheckConstraints) > 0 { + spType := col.T.Name + srcType := srcschema[src.Id].ColDefs[col.Id].Type + actualType := mysqlDefaultTypeMap[srcType.Name] + if actualType.Name != spType { + columnName := sp[src.Id].ColDefs[col.Id].Name + spcks := sp[src.Id].CheckConstraints + if doesNameExist(spcks, columnName) { + flag = false + schemaIssue = sessionState.Conv.SchemaIssues[src.Id].ColumnLevelIssues[col.Id] + if !utilities.IsSchemaIssuePresent(schemaIssue, internal.TypeMismatch) { + schemaIssue = append(schemaIssue, internal.TypeMismatch) + } + sessionState.Conv.SchemaIssues[src.Id].ColumnLevelIssues[col.Id] = schemaIssue + break + } + } + } + } + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(flag) +} + // renameForeignKeys checks the new names for spanner name validity, ensures the new names are already not used by existing tables // secondary indexes or foreign key constraints. If above checks passed then foreignKey renaming reflected in the schema else appropriate // error thrown. diff --git a/webv2/api/schema_test.go b/webv2/api/schema_test.go index 9e7894ea8..8623be071 100644 --- a/webv2/api/schema_test.go +++ b/webv2/api/schema_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "strings" @@ -2541,3 +2542,276 @@ func TestGetAutoGenMapMySQL(t *testing.T) { } } +func TestUpdateCheckConstraint(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Driver = constants.MYSQL + sessionState.Conv = internal.MakeConv() + + tableID := "table1" + + expectedCheckConstraint := []ddl.CheckConstraint{ + {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, + {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, + } + + checkConstraints := []schema.CheckConstraint{ + {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, + {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, + } + + body, err := json.Marshal(checkConstraints) + assert.NoError(t, err) + + req, err := http.NewRequest("POST", "update/cks", bytes.NewBuffer(body)) + assert.NoError(t, err) + + q := req.URL.Query() + q.Add("table", tableID) + req.URL.RawQuery = q.Encode() + + rr := httptest.NewRecorder() + + handler := http.HandlerFunc(api.UpdateCheckConstraint) + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + + updatedSp := sessionState.Conv.SpSchema[tableID] + + assert.Equal(t, expectedCheckConstraint, updatedSp.CheckConstraints) +} + +func TestUpdateCheckConstraint_ParseError(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Driver = constants.MYSQL + sessionState.Conv = internal.MakeConv() + + invalidJSON := "invalid json body" + + rr := httptest.NewRecorder() + req, err := http.NewRequest("POST", "update/cks", io.NopCloser(strings.NewReader(invalidJSON))) + assert.NoError(t, err) + + handler := http.HandlerFunc(api.UpdateCheckConstraint) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusBadRequest, rr.Code) + + expectedErrorMessage := "Request Body parse error" + assert.Contains(t, rr.Body.String(), expectedErrorMessage) +} + +type errReader struct{} + +func (errReader) Read(p []byte) (n int, err error) { + return 0, fmt.Errorf("simulated read error") +} + +func TestUpdateCheckConstraint_ImproperSession(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Conv = nil // Simulate no conversion + + rr := httptest.NewRecorder() + req, err := http.NewRequest("POST", "update/cks", io.NopCloser(errReader{})) + assert.NoError(t, err) + + handler := http.HandlerFunc(api.UpdateCheckConstraint) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusInternalServerError, rr.Code) + assert.Contains(t, rr.Body.String(), "Schema is not converted or Driver is not configured properly") + +} + +func TestValidateCheckConstraint_ImproperSession(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Conv = nil // Simulate no conversion + + rr := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/validateCheckConstraint", nil) + + handler := http.HandlerFunc(api.ValidateCheckConstraint) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusNotFound, rr.Code) + assert.Contains(t, rr.Body.String(), "Schema is not converted or Driver is not configured properly") + +} + +func TestValidateCheckConstraint_NoTypeMismatch(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Driver = constants.MYSQL + sessionState.Conv = internal.MakeConv() + + buildConvMySQL_NoTypeMatch(sessionState.Conv) + rr1 := httptest.NewRecorder() + req1 := httptest.NewRequest("GET", "/spannerDefaultTypeMap", nil) + + handler1 := http.HandlerFunc(api.SpannerDefaultTypeMap) + handler1.ServeHTTP(rr1, req1) + + rr := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/validateCheckConstraint", nil) + + handler := http.HandlerFunc(api.ValidateCheckConstraint) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + var responseFlag bool + json.NewDecoder(rr.Body).Decode(&responseFlag) + assert.True(t, responseFlag) +} + +func TestValidateCheckConstraint_TypeMismatch(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Driver = constants.MYSQL + sessionState.Conv = internal.MakeConv() + + buildConvMySQL_TypeMatch(sessionState.Conv) + + rr1 := httptest.NewRecorder() + req1 := httptest.NewRequest("GET", "/spannerDefaultTypeMap", nil) + + handler1 := http.HandlerFunc(api.SpannerDefaultTypeMap) + handler1.ServeHTTP(rr1, req1) + + rr := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/validateCheckConstraint", nil) + + handler := http.HandlerFunc(api.ValidateCheckConstraint) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + var responseFlag bool + json.NewDecoder(rr.Body).Decode(&responseFlag) + assert.False(t, responseFlag) + issues := sessionState.Conv.SchemaIssues["t1"].ColumnLevelIssues["c2"] + assert.Contains(t, issues, internal.TypeMismatch) +} + +func buildConvMySQL_NoTypeMatch(conv *internal.Conv) { + conv.SrcSchema = map[string]schema.Table{ + "t1": { + Name: "table1", + Id: "t1", + ColIds: []string{"c1", "c2", "c3"}, + CheckConstraints: []schema.CheckConstraint{ + { + Id: "ck1", + Name: "check_1", + Expr: "age > 0", + }, + { + Id: "ck1", + Name: "check_2", + Expr: "age < 99", + }, + }, + ColDefs: map[string]schema.Column{ + "c1": {Name: "a", Id: "c1", Type: schema.Type{Name: "json"}}, + "c2": {Name: "b", Id: "c2", Type: schema.Type{Name: "decimal"}}, + "c3": {Name: "c", Id: "c3", Type: schema.Type{Name: "datetime"}}, + }, + PrimaryKeys: []schema.Key{{ColId: "c1"}}}, + } + conv.SpSchema = map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + Id: "t1", + ColIds: []string{"c1", "c2", "c3"}, + CheckConstraints: []ddl.CheckConstraint{ + { + Id: "ck1", + Name: "check_1", + Expr: "age > 0", + }, + { + Id: "ck1", + Name: "check_2", + Expr: "age < 99", + }, + }, + ColDefs: map[string]ddl.ColumnDef{ + "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.JSON}}, + "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.Numeric}}, + "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Timestamp}}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, + }, + } + + conv.SchemaIssues = map[string]internal.TableIssues{ + "t1": { + ColumnLevelIssues: map[string][]internal.SchemaIssue{ + "c1": {internal.Widened}, + "c2": {internal.Time}, + }, + }, + } + conv.SyntheticPKeys["t2"] = internal.SyntheticPKey{"c20", 0} + conv.Audit.MigrationType = migration.MigrationData_SCHEMA_AND_DATA.Enum() +} + +func buildConvMySQL_TypeMatch(conv *internal.Conv) { + conv.SrcSchema = map[string]schema.Table{ + "t1": { + Name: "table1", + Id: "t1", + ColIds: []string{"c1", "c2", "c3"}, + CheckConstraints: []schema.CheckConstraint{ + { + Id: "ck1", + Name: "check_1", + Expr: "age > 0", + }, + { + Id: "ck1", + Name: "check_2", + Expr: "age < 99", + }, + }, + ColDefs: map[string]schema.Column{ + "c1": {Name: "a", Id: "c1", Type: schema.Type{Name: "json"}}, + "c2": {Name: "age", Id: "c2", Type: schema.Type{Name: "decimal"}}, + "c3": {Name: "c", Id: "c3", Type: schema.Type{Name: "datetime"}}, + }, + PrimaryKeys: []schema.Key{{ColId: "c1"}}}, + } + conv.SpSchema = map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + Id: "t1", + ColIds: []string{"c1", "c2", "c3"}, + CheckConstraints: []ddl.CheckConstraint{ + { + Id: "ck1", + Name: "check_1", + Expr: "age > 0", + }, + { + Id: "ck1", + Name: "check_2", + Expr: "age < 99", + }, + }, + ColDefs: map[string]ddl.ColumnDef{ + "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.JSON}}, + "c2": {Name: "age", Id: "c2", T: ddl.Type{Name: ddl.String}}, + "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Timestamp}}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, + }, + } + + conv.SchemaIssues = map[string]internal.TableIssues{ + "t1": { + ColumnLevelIssues: map[string][]internal.SchemaIssue{ + "c1": {internal.Widened}, + "c2": {internal.Time}, + }, + }, + } + conv.SyntheticPKeys["t2"] = internal.SyntheticPKey{"c20", 0} + conv.Audit.MigrationType = migration.MigrationData_SCHEMA_AND_DATA.Enum() +} diff --git a/webv2/routes.go b/webv2/routes.go index 191c1e1a5..6fc511320 100644 --- a/webv2/routes.go +++ b/webv2/routes.go @@ -75,6 +75,7 @@ func getRoutes() *mux.Router { router.HandleFunc("/spannerDefaultTypeMap", api.SpannerDefaultTypeMap).Methods("GET") router.HandleFunc("/autoGenMap", api.GetAutoGenMap).Methods("GET") router.HandleFunc("/getSequenceKind", api.GetSequenceKind).Methods("GET") + router.HandleFunc("/validateCheckConstraint", api.ValidateCheckConstraint).Methods("GET") router.HandleFunc("/setparent", api.SetParentTable).Methods("GET") router.HandleFunc("/removeParent", api.RemoveParentTable).Methods("POST") @@ -92,6 +93,7 @@ func getRoutes() *mux.Router { router.HandleFunc("/UpdateSequence", api.UpdateSequence).Methods("POST") router.HandleFunc("/update/fks", api.UpdateForeignKeys).Methods("POST") + router.HandleFunc("/update/cks", api.UpdateCheckConstraint).Methods("POST") router.HandleFunc("/update/indexes", api.UpdateIndexes).Methods("POST") // Session Management diff --git a/webv2/table/review_table_schema.go b/webv2/table/review_table_schema.go index 0ad6fb0ec..5e9635e31 100644 --- a/webv2/table/review_table_schema.go +++ b/webv2/table/review_table_schema.go @@ -108,6 +108,13 @@ func ReviewTableSchema(w http.ResponseWriter, r *http.Request) { return } } + oldName := conv.SrcSchema[tableId].ColDefs[colId].Name + + for i := range conv.SpSchema[tableId].CheckConstraints { + originalString := conv.SpSchema[tableId].CheckConstraints[i].Expr + updatedValue := strings.ReplaceAll(originalString, oldName, v.Rename) + conv.SpSchema[tableId].CheckConstraints[i].Expr = updatedValue + } interleaveTableSchema = reviewRenameColumn(v.Rename, tableId, colId, conv, interleaveTableSchema) @@ -148,7 +155,7 @@ func ReviewTableSchema(w http.ResponseWriter, r *http.Request) { } } - if !v.Removed && !v.Add && v.Rename== ""{ + if !v.Removed && !v.Add && v.Rename == "" { sequences := UpdateAutoGenCol(v.AutoGen, tableId, colId, conv) conv.SpSequences = sequences } diff --git a/webv2/table/update_table_schema.go b/webv2/table/update_table_schema.go index 6e03c709c..27d05ad47 100644 --- a/webv2/table/update_table_schema.go +++ b/webv2/table/update_table_schema.go @@ -19,6 +19,7 @@ import ( "fmt" "io/ioutil" "net/http" + "strings" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" @@ -55,6 +56,7 @@ type updateTable struct { // (3) Rename column. // (4) Add or Remove NotNull constraint. // (5) Update Spanner type. +// (6) Update Check constraints Name. func UpdateTableSchema(w http.ResponseWriter, r *http.Request) { reqBody, err := ioutil.ReadAll(r.Body) @@ -96,6 +98,14 @@ func UpdateTableSchema(w http.ResponseWriter, r *http.Request) { if v.Rename != "" && v.Rename != conv.SpSchema[tableId].ColDefs[colId].Name { + oldName := conv.SrcSchema[tableId].ColDefs[colId].Name + + for i := range conv.SpSchema[tableId].CheckConstraints { + originalString := conv.SpSchema[tableId].CheckConstraints[i].Expr + updatedValue := strings.ReplaceAll(originalString, oldName, v.Rename) + conv.SpSchema[tableId].CheckConstraints[i].Expr = updatedValue + } + renameColumn(v.Rename, tableId, colId, conv) }