Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default Value: Spanner Metadata Accessor #886

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package spannermetadataclient

import (
"context"
"fmt"
"sync"

sp "cloud.google.com/go/spanner"
)

var once sync.Once
var spannermetadataClient *sp.Client

var newClient = sp.NewClient

func GetOrCreateClient(ctx context.Context, dbURI string) (*sp.Client, error) {
var err error
if spannermetadataClient == nil {
once.Do(func() {
spannermetadataClient, err = newClient(ctx, dbURI)
})
if err != nil {
return nil, fmt.Errorf("failed to create spanner metadata database client: %v", err)
}
return spannermetadataClient, nil
}
return spannermetadataClient, nil
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package spannermetadataclient

import (
"context"
"fmt"
"os"
"sync"
"testing"

sp "cloud.google.com/go/spanner"
"github.com/GoogleCloudPlatform/spanner-migration-tool/logger"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"google.golang.org/api/option"
)

func init() {
logger.Log = zap.NewNop()
}

func TestMain(m *testing.M) {
res := m.Run()
os.Exit(res)
}

func resetTest() {
spannermetadataClient = nil
once = sync.Once{}
}

func TestGetOrCreateClient_Basic(t *testing.T) {
resetTest()
ctx := context.Background()
oldFunc := newClient
defer func() { newClient = oldFunc }()
newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) {
return &sp.Client{}, nil
}
client, err := GetOrCreateClient(ctx, "testURI")
assert.NotNil(t, client)
assert.Nil(t, err)
}

func TestGetOrCreateClient_OnlyOnceViaSync(t *testing.T) {
resetTest()
ctx := context.Background()
oldFunc := newClient
defer func() { newClient = oldFunc }()

newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) {
return &sp.Client{}, nil
}
client, err := GetOrCreateClient(ctx, "testURI")
assert.NotNil(t, client)
assert.Nil(t, err)
// Explicitly set the client to nil. Running GetOrCreateClient should not create a
// new client since sync would already be executed.
spannermetadataClient = nil
newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) {
return nil, fmt.Errorf("test error")
}
client, err = GetOrCreateClient(ctx, "testURI")
assert.Nil(t, client)
assert.Nil(t, err)
}

func TestGetOrCreateClient_OnlyOnceViaIf(t *testing.T) {
resetTest()
ctx := context.Background()
oldFunc := newClient
defer func() { newClient = oldFunc }()

newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) {
return &sp.Client{}, nil
}
oldC, err := GetOrCreateClient(ctx, "testURI")
assert.NotNil(t, oldC)
assert.Nil(t, err)

// Explicitly reset once. Running GetOrCreateClient should not create a
// new client the if condition should prevent it.
once = sync.Once{}
newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) {
return nil, fmt.Errorf("test error")
}
newC, err := GetOrCreateClient(ctx, "testURI")
assert.Equal(t, oldC, newC)
assert.Nil(t, err)
}

func TestGetOrCreateClient_Error(t *testing.T) {
resetTest()
ctx := context.Background()
oldFunc := newClient
defer func() { newClient = oldFunc }()

newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) {
return nil, fmt.Errorf("test error")
}
client, err := GetOrCreateClient(ctx, "testURI")
assert.Nil(t, client)
assert.NotNil(t, err)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package spannermetadataaccessor

import (
"context"
"fmt"

"cloud.google.com/go/spanner"
spannermetadataclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner/spannermetadataaccessor/clients"
"github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants"
"google.golang.org/api/iterator"
)

type SpannerMetadataAccessor interface {
// IsSpannerSupportedDefaultStatement checks if the given statement is supported by Spanner.
IsSpannerSupportedDefaultStatement(SpProjectId string, SpInstanceId string, statement string, coldatatype string) bool
}

type SpannerMetadataAccessorImpl struct{}

func (spm *SpannerMetadataAccessorImpl) IsSpannerSupportedDefaultStatement(SpProjectId string, SpInstanceId string, statement string, coldatatype string) bool {
db := getSpannerMetadataDbUri(SpProjectId, SpInstanceId)
if SpProjectId == "" || SpInstanceId == "" {
return false
}

ctx := context.Background()
spmClient, err := spannermetadataclient.GetOrCreateClient(ctx, db)
if err != nil {
return false
}

if spmClient == nil {
return false
}
stmt := spanner.Statement{
SQL: "SELECT CAST(" + statement + " AS " + coldatatype + ") AS statementValue",
}
iter := spmClient.Single().Query(ctx, stmt)
defer iter.Stop()
for {
_, err := iter.Next()
if err == iterator.Done {
return true
}
if err != nil {
return false
}

}

}

func getSpannerMetadataDbUri(projectId string, instanceId string) string {
return fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectId, instanceId, constants.METADATA_DB)
}
4 changes: 3 additions & 1 deletion cmd/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func (cmd *DataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface
}

// validateExistingDb validates that the existing spanner schema is in accordance with the one specified in the session file.
func validateExistingDb(ctx context.Context, spDialect, dbURI string, adminClient *database.DatabaseAdminClient, client *sp.Client, conv *internal.Conv) error {
func validateExistingDb(SpProjectId string, SpInstanceId string, ctx context.Context, spDialect, dbURI string, adminClient *database.DatabaseAdminClient, client *sp.Client, conv *internal.Conv) error {
NirnayaSindhuSuthari marked this conversation as resolved.
Show resolved Hide resolved
adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx)
if err != nil {
return err
Expand All @@ -230,6 +230,8 @@ func validateExistingDb(ctx context.Context, spDialect, dbURI string, adminClien
}
spannerConv := internal.MakeConv()
spannerConv.SpDialect = spDialect
spannerConv.SpProjectId = SpProjectId
spannerConv.SpInstanceId = SpInstanceId
err = utils.ReadSpannerSchema(ctx, spannerConv, client)
if err != nil {
err = fmt.Errorf("can't read spanner schema: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion cmd/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func migrateData(ctx context.Context, migrationProjectId string, targetProfile p
err error
)
if !sourceProfile.UseTargetSchema() {
err = validateExistingDb(ctx, conv.SpDialect, dbURI, adminClient, client, conv)
err = validateExistingDb(conv.SpProjectId, conv.SpInstanceId, ctx, conv.SpDialect, dbURI, adminClient, client, conv)
if err != nil {
err = fmt.Errorf("error while validating existing database: %v", err)
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion conversion/conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (ci *ConvImpl) SchemaConv(migrationProjectId string, sourceProfile profiles
case constants.POSTGRES, constants.MYSQL, constants.DYNAMODB, constants.SQLSERVER, constants.ORACLE:
return schemaFromSource.schemaFromDatabase(migrationProjectId, sourceProfile, targetProfile, &GetInfoImpl{}, &common.ProcessSchemaImpl{})
case constants.PGDUMP, constants.MYSQLDUMP:
return schemaFromSource.SchemaFromDump(sourceProfile.Driver, targetProfile.Conn.Sp.Dialect, ioHelper, &ProcessDumpByDialectImpl{})
return schemaFromSource.SchemaFromDump(targetProfile.Conn.Sp.Project, targetProfile.Conn.Sp.Instance, sourceProfile.Driver, targetProfile.Conn.Sp.Dialect, ioHelper, &ProcessDumpByDialectImpl{})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We had discussed that you will pass the whole targetProfile here right? Instead of 3 params

default:
return nil, fmt.Errorf("schema conversion for driver %s not supported", sourceProfile.Driver)
}
Expand Down
10 changes: 8 additions & 2 deletions conversion/conversion_from_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import (

type SchemaFromSourceInterface interface {
schemaFromDatabase(migrationProjectId string, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, getInfo GetInfoInterface, processSchema common.ProcessSchemaInterface) (*internal.Conv, error)
SchemaFromDump(driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error)
SchemaFromDump(SpProjectId string, SpInstanceId string, driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error)
}

type SchemaFromSourceImpl struct{}
Expand All @@ -54,6 +54,8 @@ type DataFromSourceImpl struct{}
func (sads *SchemaFromSourceImpl) schemaFromDatabase(migrationProjectId string, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, getInfo GetInfoInterface, processSchema common.ProcessSchemaInterface) (*internal.Conv, error) {
conv := internal.MakeConv()
conv.SpDialect = targetProfile.Conn.Sp.Dialect
conv.SpProjectId = targetProfile.Conn.Sp.Project
conv.SpInstanceId = targetProfile.Conn.Sp.Instance
//handle fetching schema differently for sharded migrations, we only connect to the primary shard to
//fetch the schema. We reuse the SourceProfileConnection object for this purpose.
var infoSchema common.InfoSchema
Expand Down Expand Up @@ -99,7 +101,7 @@ func (sads *SchemaFromSourceImpl) schemaFromDatabase(migrationProjectId string,
return conv, processSchema.ProcessSchema(conv, infoSchema, common.DefaultWorkers, additionalSchemaAttributes, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{})
}

func (sads *SchemaFromSourceImpl) SchemaFromDump(driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error) {
func (sads *SchemaFromSourceImpl) SchemaFromDump(SpProjectId string, SpInstanceId string, driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error) {
f, n, err := getSeekable(ioHelper.In)
if err != nil {
utils.PrintSeekError(driver, err, ioHelper.Out)
Expand All @@ -109,6 +111,8 @@ func (sads *SchemaFromSourceImpl) SchemaFromDump(driver string, spDialect string
ioHelper.BytesRead = n
conv := internal.MakeConv()
conv.SpDialect = spDialect
conv.SpProjectId = SpProjectId
conv.SpInstanceId = SpInstanceId
p := internal.NewProgress(n, "Generating schema", internal.Verbose(), false, int(internal.SchemaCreationInProgress))
r := internal.NewReader(bufio.NewReader(f), p)
conv.SetSchemaMode() // Build schema and ignore data in dump.
Expand Down Expand Up @@ -159,6 +163,8 @@ func (sads *DataFromSourceImpl) dataFromCSV(ctx context.Context, sourceProfile p
return nil, fmt.Errorf("dbName is mandatory in target-profile for csv source")
}
conv.SpDialect = targetProfile.Conn.Sp.Dialect
conv.SpProjectId = targetProfile.Conn.Sp.Project
conv.SpInstanceId = targetProfile.Conn.Sp.Instance
dialect, err := targetProfile.FetchTargetDialect(ctx)
if err != nil {
return nil, fmt.Errorf("could not fetch dialect: %v", err)
Expand Down
4 changes: 2 additions & 2 deletions conversion/mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ func (msads *MockSchemaFromSource) schemaFromDatabase(migrationProjectId string,
args := msads.Called(migrationProjectId, sourceProfile, targetProfile, getInfo, processSchema)
return args.Get(0).(*internal.Conv), args.Error(1)
}
func (msads *MockSchemaFromSource) SchemaFromDump(driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error) {
args := msads.Called(driver, spDialect, ioHelper, processDump)
func (msads *MockSchemaFromSource) SchemaFromDump(SpProjectId string, SpInstanceId string, driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error) {
args := msads.Called(SpProjectId, SpInstanceId, driver, spDialect, ioHelper, processDump)
return args.Get(0).(*internal.Conv), args.Error(1)
}

Expand Down
2 changes: 2 additions & 0 deletions internal/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ type Conv struct {
Stats stats `json:"-"`
TimezoneOffset string // Timezone offset for timestamp conversion.
SpDialect string // The dialect of the spanner database to which Spanner migration tool is writing.
SpProjectId string // The projectId of the spanner database to which Spanner migration tool is writing.
SpInstanceId string // The instanceId of the spanner database to which Spanner migration tool is writing.
UniquePKey map[string][]string // Maps Spanner table name to unique column name being used as primary key (if needed).
Audit Audit `json:"-"` // Stores the audit information for the database conversion
Rules []Rule // Stores applied rules during schema conversion
Expand Down
27 changes: 21 additions & 6 deletions sources/common/toddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"strconv"
"unicode"

spannermetadataaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner/spannermetadataaccessor"
"github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants"
"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
"github.com/GoogleCloudPlatform/spanner-migration-tool/schema"
Expand Down Expand Up @@ -146,13 +147,27 @@ func (ss *SchemaToSpannerImpl) SchemaToSpannerDDLHelper(conv *internal.Conv, tod
columnLevelIssues[srcColId] = issues
}

defaultVal := ddl.DefaultValue{
IsPresent: false,
Value: "",
}

if srcCol.DefaultValue.IsPresent {
spM := spannermetadataaccessor.SpannerMetadataAccessorImpl{}
defaultVal.IsPresent = spM.IsSpannerSupportedDefaultStatement(conv.SpProjectId, conv.SpInstanceId, srcCol.DefaultValue.Value, ty.Name)
}
if defaultVal.IsPresent {
defaultVal.Value = srcCol.DefaultValue.Value
}

spColDef[srcColId] = ddl.ColumnDef{
Name: colName,
T: ty,
NotNull: isNotNull,
Comment: "From: " + quoteIfNeeded(srcCol.Name) + " " + srcCol.Type.Print(),
Id: srcColId,
AutoGen: *autoGenCol,
Name: colName,
T: ty,
NotNull: isNotNull,
Comment: "From: " + quoteIfNeeded(srcCol.Name) + " " + srcCol.Type.Print(),
Id: srcColId,
AutoGen: *autoGenCol,
DefaultValue: defaultVal,
NirnayaSindhuSuthari marked this conversation as resolved.
Show resolved Hide resolved
}
if !checkIfColumnIsPartOfPK(srcColId, srcTable.PrimaryKeys) {
totalNonKeyColumnSize += getColumnSize(ty.Name, ty.Len)
Expand Down
Loading
Loading