Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: APIs for Backend Changes for Default Values #965

Merged
merged 5 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions conversion/conversion_from_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ type DataFromSourceImpl struct{}
func (sads *SchemaFromSourceImpl) schemaFromDatabase(migrationProjectId string, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, getInfo GetInfoInterface, processSchema common.ProcessSchemaInterface) (*internal.Conv, error) {
conv := internal.MakeConv()
conv.SpDialect = targetProfile.Conn.Sp.Dialect
conv.SpProjectId = targetProfile.Conn.Sp.Project
conv.SpInstanceId = targetProfile.Conn.Sp.Instance
conv.Source = sourceProfile.Driver
//handle fetching schema differently for sharded migrations, we only connect to the primary shard to
//fetch the schema. We reuse the SourceProfileConnection object for this purpose.
var infoSchema common.InfoSchema
Expand Down Expand Up @@ -159,6 +162,9 @@ func (sads *DataFromSourceImpl) dataFromCSV(ctx context.Context, sourceProfile p
return nil, fmt.Errorf("dbName is mandatory in target-profile for csv source")
}
conv.SpDialect = targetProfile.Conn.Sp.Dialect
conv.SpProjectId = targetProfile.Conn.Sp.Project
darshan-sj marked this conversation as resolved.
Show resolved Hide resolved
conv.SpInstanceId = targetProfile.Conn.Sp.Instance
conv.Source = sourceProfile.Driver
dialect, err := targetProfile.FetchTargetDialect(ctx)
if err != nil {
return nil, fmt.Errorf("could not fetch dialect: %v", err)
Expand Down
107 changes: 104 additions & 3 deletions expressions_api/expression_verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"sync"

spannerclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/client"
spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner"
"github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants"
"github.com/GoogleCloudPlatform/spanner-migration-tool/common/task"
Expand All @@ -18,22 +19,49 @@ const THREAD_POOL = 500
type ExpressionVerificationAccessor interface {
//Batch API which parallelizes expression verification calls
VerifyExpressions(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput
RefreshSpannerClient(ctx context.Context, project string, instance string) error
darshan-sj marked this conversation as resolved.
Show resolved Hide resolved
}

type ExpressionVerificationAccessorImpl struct {
SpannerAccessor *spanneraccessor.SpannerAccessorImpl
}

func NewExpressionVerificationAccessorImpl(ctx context.Context, project string, instance string) (*ExpressionVerificationAccessorImpl, error) {
spannerAccessor, err := spanneraccessor.NewSpannerAccessorClientImplWithSpannerClient(ctx, fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, "smt-staging-db"))
if err != nil {
return nil, err
var spannerAccessor *spanneraccessor.SpannerAccessorImpl
var err error
if project != "" && instance != "" {
spannerAccessor, err = spanneraccessor.NewSpannerAccessorClientImplWithSpannerClient(ctx, fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, "smt-staging-db"))
if err != nil {
return nil, err
}
} else {
spannerAccessor, err = spanneraccessor.NewSpannerAccessorClientImpl(ctx)
asthamohta marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, err
}
}
return &ExpressionVerificationAccessorImpl{
SpannerAccessor: spannerAccessor,
}, nil
}

type DDLVerifier interface {
asthamohta marked this conversation as resolved.
Show resolved Hide resolved
VerifySpannerDDL(conv *internal.Conv, expressionDetails []internal.ExpressionDetail) (internal.VerifyExpressionsOutput, error)
GetSourceExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail
GetSpannerExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail
RefreshSpannerClient(ctx context.Context, project string, instance string) error
}
type DDLVerifierImpl struct {
Expressions ExpressionVerificationAccessor
}

func NewDDLVerifierImpl(ctx context.Context, project string, instance string) (*DDLVerifierImpl, error) {
expVerifier, err := NewExpressionVerificationAccessorImpl(ctx, project, instance)
return &DDLVerifierImpl{
Expressions: expVerifier,
}, err
}

func (ev *ExpressionVerificationAccessorImpl) VerifyExpressions(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput {
err := ev.validateRequest(verifyExpressionsInput)
if err != nil {
Expand Down Expand Up @@ -79,6 +107,15 @@ func (ev *ExpressionVerificationAccessorImpl) VerifyExpressions(ctx context.Cont
return verifyExpressionsOutput
}

func (ev *ExpressionVerificationAccessorImpl) RefreshSpannerClient(ctx context.Context, project string, instance string) error {
spannerClient, err := spannerclient.NewSpannerClientImpl(ctx, fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, "smt-staging-db"))
darshan-sj marked this conversation as resolved.
Show resolved Hide resolved
asthamohta marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return err
}
ev.SpannerAccessor.SpannerClient = spannerClient
return nil
}

func (ev *ExpressionVerificationAccessorImpl) verifyExpressionInternal(expressionDetail internal.ExpressionDetail, mutex *sync.Mutex) task.TaskResult[internal.ExpressionVerificationOutput] {
var sqlStatement string
switch expressionDetail.Type {
Expand Down Expand Up @@ -129,3 +166,67 @@ func (ev *ExpressionVerificationAccessorImpl) removeExpressions(inputConv *inter
}
return convCopy, nil
}

func (ddlv *DDLVerifierImpl) VerifySpannerDDL(conv *internal.Conv, expressionDetails []internal.ExpressionDetail) (internal.VerifyExpressionsOutput, error) {
ctx := context.Background()
verifyExpressionsInput := internal.VerifyExpressionsInput{
Conv: conv,
Source: conv.Source,
ExpressionDetailList: expressionDetails,
}
verificationResults := ddlv.Expressions.VerifyExpressions(ctx, verifyExpressionsInput)

return verificationResults, verificationResults.Err
}

func (ddlv *DDLVerifierImpl) GetSourceExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail {
expressionDetails := []internal.ExpressionDetail{}
// Collect default values for verification
for _, tableId := range tableIds {
srcTable := conv.SrcSchema[tableId]
for _, srcColId := range srcTable.ColIds {
srcCol := srcTable.ColDefs[srcColId]
if srcCol.DefaultValue.IsPresent {
defaultValueExp := internal.ExpressionDetail{
ReferenceElement: internal.ReferenceElement{
Name: conv.SpSchema[tableId].ColDefs[srcColId].T.Name,
},
ExpressionId: srcCol.DefaultValue.Value.ExpressionId,
Expression: srcCol.DefaultValue.Value.Statement,
Type: "DEFAULT",
Metadata: map[string]string{"TableId": tableId, "ColId": srcColId},
}
expressionDetails = append(expressionDetails, defaultValueExp)
}
}
}
return expressionDetails
}

func (ddlv *DDLVerifierImpl) GetSpannerExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail {
expressionDetails := []internal.ExpressionDetail{}
// Collect default values for verification
for _, tableId := range tableIds {
spTable := conv.SpSchema[tableId]
for _, spColId := range spTable.ColIds {
spCol := spTable.ColDefs[spColId]
if spCol.DefaultValue.IsPresent {
darshan-sj marked this conversation as resolved.
Show resolved Hide resolved
defaultValueExp := internal.ExpressionDetail{
ReferenceElement: internal.ReferenceElement{
Name: conv.SpSchema[tableId].ColDefs[spColId].T.Name,
},
ExpressionId: spCol.DefaultValue.Value.ExpressionId,
Expression: spCol.DefaultValue.Value.Statement,
Type: "DEFAULT",
Metadata: map[string]string{"TableId": tableId, "ColId": spColId},
}
expressionDetails = append(expressionDetails, defaultValueExp)
}
}
}
return expressionDetails
}

func (ddlv *DDLVerifierImpl) RefreshSpannerClient(ctx context.Context, project string, instance string) error {
return ddlv.Expressions.RefreshSpannerClient(ctx, project, instance)
}
187 changes: 185 additions & 2 deletions expressions_api/expression_verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (
"github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api"
"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
"github.com/GoogleCloudPlatform/spanner-migration-tool/logger"
"github.com/GoogleCloudPlatform/spanner-migration-tool/schema"
"github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl"
"github.com/googleapis/gax-go/v2"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
Expand All @@ -32,8 +34,8 @@ func TestVerifyExpressions(t *testing.T) {
conv := internal.MakeConv()
ReadSessionFile(conv, "../../test_data/session_expression_verify.json")
input := internal.VerifyExpressionsInput{
Conv: conv,
Source: "mysql",
Conv: conv,
Source: "mysql",
ExpressionDetailList: []internal.ExpressionDetail{
{
Expression: "id > 10",
Expand Down Expand Up @@ -297,3 +299,184 @@ func ReadSessionFile(conv *internal.Conv, sessionJSON string) error {
}
return nil
}

func TestVerifySpannerDDL(t *testing.T) {
conv := *internal.MakeConv()
testCases := []struct {
name string
conv internal.Conv
expressionDetails []internal.ExpressionDetail
verifyExpressionMock expressions_api.MockExpressionVerificationAccessor
errorExpected bool
}{
{
name: "no error flow",
conv: conv,
expressionDetails: []internal.ExpressionDetail{},
verifyExpressionMock: expressions_api.MockExpressionVerificationAccessor{
VerifyExpressionsMock: func(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput {
return internal.VerifyExpressionsOutput{
ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{},
Err: nil,
}
},
},
errorExpected: false,
},
{
name: "error flow",
conv: conv,
expressionDetails: []internal.ExpressionDetail{},
verifyExpressionMock: expressions_api.MockExpressionVerificationAccessor{
VerifyExpressionsMock: func(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput {
return internal.VerifyExpressionsOutput{
ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{},
Err: fmt.Errorf("error"),
}
},
},
errorExpected: true,
},
}

for _, tc := range testCases {
ddlV := expressions_api.DDLVerifierImpl{
Expressions: &tc.verifyExpressionMock,
}
_, err := ddlV.VerifySpannerDDL(&tc.conv, tc.expressionDetails)
assert.Equal(t, tc.errorExpected, err != nil)
}
}

func TestGetSourceExpressionDetails(t *testing.T) {
conv := internal.MakeConv()
conv.SrcSchema = map[string]schema.Table{
"table1": {
ColIds: []string{"col1", "col2"},
ColDefs: map[string]schema.Column{
"col1": {
DefaultValue: ddl.DefaultValue{
IsPresent: true,
Value: ddl.Expression{
ExpressionId: "expr1",
Statement: "SELECT 1",
},
},
},
"col2": {
DefaultValue: ddl.DefaultValue{},
},
},
},
}
conv.SpSchema = ddl.Schema{
"table1": {
ColDefs: map[string]ddl.ColumnDef{
"col1": {
T: ddl.Type{
Name: "INT64",
},
},
},
},
}

testCases := []struct {
name string
conv *internal.Conv
tableIds []string
expectedDetails []internal.ExpressionDetail
}{
{
name: "single table with default value",
conv: conv,
tableIds: []string{"table1"},
expectedDetails: []internal.ExpressionDetail{
{
ReferenceElement: internal.ReferenceElement{
Name: "INT64",
},
ExpressionId: "expr1",
Expression: "SELECT 1",
Type: "DEFAULT",
Metadata: map[string]string{"TableId": "table1", "ColId": "col1"},
},
},
},
{
name: "no tables",
conv: conv,
tableIds: []string{},
expectedDetails: []internal.ExpressionDetail{},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ddlv := &expressions_api.DDLVerifierImpl{}
actualDetails := ddlv.GetSourceExpressionDetails(tc.conv, tc.tableIds)
assert.Equal(t, tc.expectedDetails, actualDetails)
})
}
}

func TestGetSpannerExpressionDetails(t *testing.T) {
conv := internal.MakeConv()
conv.SpSchema = ddl.Schema{
"table1": {
ColIds: []string{"col1", "col2"},
ColDefs: map[string]ddl.ColumnDef{
"col1": {
DefaultValue: ddl.DefaultValue{
IsPresent: true,
Value: ddl.Expression{
ExpressionId: "expr1",
Statement: "SELECT 1",
},
},
},
"col2": {
DefaultValue: ddl.DefaultValue{},
},
},
},
}

testCases := []struct {
name string
conv *internal.Conv
tableIds []string
expectedDetails []internal.ExpressionDetail
}{
{
name: "single table with default value",
conv: conv,
tableIds: []string{"table1"},
expectedDetails: []internal.ExpressionDetail{
{
ReferenceElement: internal.ReferenceElement{
Name: conv.SpSchema["table1"].ColDefs["col1"].T.Name,
},
ExpressionId: "expr1",
Expression: "SELECT 1",
Type: "DEFAULT",
Metadata: map[string]string{"TableId": "table1", "ColId": "col1"},
},
},
},
{
name: "no tables",
conv: conv,
tableIds: []string{},
expectedDetails: []internal.ExpressionDetail{},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ddlv := &expressions_api.DDLVerifierImpl{}
actualDetails := ddlv.GetSpannerExpressionDetails(tc.conv, tc.tableIds)
assert.Equal(t, tc.expectedDetails, actualDetails)
})
}
}
Loading
Loading