diff --git a/Makefile b/Makefile index 6b0f8744cbf..f90016290ff 100644 --- a/Makefile +++ b/Makefile @@ -103,7 +103,7 @@ parser: make -C go/vt/sqlparser visitor: - go generate go/vt/sqlparser/rewriter.go + go run ./go/tools/asthelpergen -in ./go/vt/sqlparser -iface vitess.io/vitess/go/vt/sqlparser.SQLNode -except "*ColName" sizegen: go run go/tools/sizegen/sizegen.go \ diff --git a/go/tools/asthelpergen/asthelpergen.go b/go/tools/asthelpergen/asthelpergen.go new file mode 100644 index 00000000000..0bf452fccaa --- /dev/null +++ b/go/tools/asthelpergen/asthelpergen.go @@ -0,0 +1,291 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "bytes" + "flag" + "fmt" + "go/types" + "io/ioutil" + "log" + "path" + "strings" + + "github.com/dave/jennifer/jen" + "golang.org/x/tools/go/packages" +) + +const licenseFileHeader = `Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.` + +type generator interface { + visitStruct(t types.Type, stroct *types.Struct) error + visitInterface(t types.Type, iface *types.Interface) error + visitSlice(t types.Type, slice *types.Slice) error + createFile(pkgName string) (string, *jen.File) +} + +// astHelperGen finds implementations of the given interface, +// and uses the supplied `generator`s to produce the output code +type astHelperGen struct { + DebugTypes bool + mod *packages.Module + sizes types.Sizes + namedIface *types.Named + iface *types.Interface + gens []generator +} + +func newGenerator(mod *packages.Module, sizes types.Sizes, named *types.Named, generators ...generator) *astHelperGen { + return &astHelperGen{ + DebugTypes: true, + mod: mod, + sizes: sizes, + namedIface: named, + iface: named.Underlying().(*types.Interface), + gens: generators, + } +} + +func findImplementations(scope *types.Scope, iff *types.Interface, impl func(types.Type) error) error { + for _, name := range scope.Names() { + obj := scope.Lookup(name) + if _, ok := obj.(*types.TypeName); !ok { + continue + } + baseType := obj.Type() + if types.Implements(baseType, iff) { + err := impl(baseType) + if err != nil { + return err + } + continue + } + pointerT := types.NewPointer(baseType) + if types.Implements(pointerT, iff) { + err := impl(pointerT) + if err != nil { + return err + } + continue + } + } + return nil +} + +func (gen *astHelperGen) visitStruct(t types.Type, stroct *types.Struct) error { + for _, g := range gen.gens { + err := g.visitStruct(t, stroct) + if err != nil { + return err + } + } + return nil +} + +func (gen *astHelperGen) visitSlice(t types.Type, slice *types.Slice) error { + for _, g := range gen.gens { + err := g.visitSlice(t, slice) + if err != nil { + return err + } + } + return nil +} + +func (gen *astHelperGen) visitInterface(t types.Type, iface *types.Interface) error { + for _, g := range gen.gens { + err := g.visitInterface(t, iface) + if err != nil { + return err + } + } + return nil +} + +// GenerateCode is the main loop where we build up the code per file. +func (gen *astHelperGen) GenerateCode() (map[string]*jen.File, error) { + pkg := gen.namedIface.Obj().Pkg() + iface, ok := gen.iface.Underlying().(*types.Interface) + if !ok { + return nil, fmt.Errorf("expected interface, but got %T", gen.iface) + } + + err := findImplementations(pkg.Scope(), iface, func(t types.Type) error { + switch n := t.Underlying().(type) { + case *types.Struct: + return gen.visitStruct(t, n) + case *types.Slice: + return gen.visitSlice(t, n) + case *types.Pointer: + strct, isStrct := n.Elem().Underlying().(*types.Struct) + if isStrct { + return gen.visitStruct(t, strct) + } + case *types.Interface: + return gen.visitInterface(t, n) + default: + // do nothing + } + return nil + }) + + if err != nil { + return nil, err + } + + result := map[string]*jen.File{} + for _, g := range gen.gens { + file, code := g.createFile(pkg.Name()) + fullPath := path.Join(gen.mod.Dir, strings.TrimPrefix(pkg.Path(), gen.mod.Path), file) + result[fullPath] = code + } + + return result, nil +} + +type typePaths []string + +func (t *typePaths) String() string { + return fmt.Sprintf("%v", *t) +} + +func (t *typePaths) Set(path string) error { + *t = append(*t, path) + return nil +} + +func main() { + var patterns typePaths + var generate, except string + var verify bool + + flag.Var(&patterns, "in", "Go packages to load the generator") + flag.StringVar(&generate, "iface", "", "Root interface generate rewriter for") + flag.BoolVar(&verify, "verify", false, "ensure that the generated files are correct") + flag.StringVar(&except, "except", "", "don't deep clone these types") + flag.Parse() + + result, err := GenerateASTHelpers(patterns, generate, except) + if err != nil { + log.Fatal(err) + } + + if verify { + for _, err := range VerifyFilesOnDisk(result) { + log.Fatal(err) + } + log.Printf("%d files OK", len(result)) + } else { + for fullPath, file := range result { + if err := file.Save(fullPath); err != nil { + log.Fatalf("failed to save file to '%s': %v", fullPath, err) + } + log.Printf("saved '%s'", fullPath) + } + } +} + +// VerifyFilesOnDisk compares the generated results from the codegen against the files that +// currently exist on disk and returns any mismatches +func VerifyFilesOnDisk(result map[string]*jen.File) (errors []error) { + for fullPath, file := range result { + existing, err := ioutil.ReadFile(fullPath) + if err != nil { + errors = append(errors, fmt.Errorf("missing file on disk: %s (%w)", fullPath, err)) + continue + } + + var buf bytes.Buffer + if err := file.Render(&buf); err != nil { + errors = append(errors, fmt.Errorf("render error for '%s': %w", fullPath, err)) + continue + } + + if !bytes.Equal(existing, buf.Bytes()) { + errors = append(errors, fmt.Errorf("'%s' has changed", fullPath)) + continue + } + } + return errors +} + +// GenerateASTHelpers loads the input code, constructs the necessary generators, +// and generates the rewriter and clone methods for the AST +func GenerateASTHelpers(packagePatterns []string, rootIface, exceptCloneType string) (map[string]*jen.File, error) { + loaded, err := packages.Load(&packages.Config{ + Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesSizes | packages.NeedTypesInfo | packages.NeedDeps | packages.NeedImports | packages.NeedModule, + Logf: log.Printf, + }, packagePatterns...) + + if err != nil { + return nil, err + } + + scopes := make(map[string]*types.Scope) + for _, pkg := range loaded { + scopes[pkg.PkgPath] = pkg.Types.Scope() + } + + pos := strings.LastIndexByte(rootIface, '.') + if pos < 0 { + return nil, fmt.Errorf("unexpected input type: %s", rootIface) + } + + pkgname := rootIface[:pos] + typename := rootIface[pos+1:] + + scope := scopes[pkgname] + if scope == nil { + return nil, fmt.Errorf("no scope found for type '%s'", rootIface) + } + + tt := scope.Lookup(typename) + if tt == nil { + return nil, fmt.Errorf("no type called '%s' found in '%s'", typename, pkgname) + } + + nt := tt.Type().(*types.Named) + + iface := nt.Underlying().(*types.Interface) + + interestingType := func(t types.Type) bool { + return types.Implements(t, iface) + } + rewriter := newRewriterGen(interestingType, nt.Obj().Name()) + clone := newCloneGen(iface, scope, exceptCloneType) + + generator := newGenerator(loaded[0].Module, loaded[0].TypesSizes, nt, rewriter, clone) + it, err := generator.GenerateCode() + if err != nil { + return nil, err + } + + return it, nil +} diff --git a/go/tools/asthelpergen/asthelpergen_test.go b/go/tools/asthelpergen/asthelpergen_test.go new file mode 100644 index 00000000000..30dc2eb61a5 --- /dev/null +++ b/go/tools/asthelpergen/asthelpergen_test.go @@ -0,0 +1,43 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFullGeneration(t *testing.T) { + result, err := GenerateASTHelpers([]string{"./integration/..."}, "vitess.io/vitess/go/tools/asthelpergen/integration.AST", "*NoCloneType") + require.NoError(t, err) + + verifyErrors := VerifyFilesOnDisk(result) + require.Empty(t, verifyErrors) + + for _, file := range result { + contents := fmt.Sprintf("%#v", file) + require.Contains(t, contents, "http://www.apache.org/licenses/LICENSE-2.0") + applyIdx := strings.Index(contents, "func (a *application) apply(parent, node AST, replacer replacerFunc)") + cloneIdx := strings.Index(contents, "CloneAST(in AST) AST") + if applyIdx == 0 && cloneIdx == 0 { + t.Fatalf("file doesn't contain expected contents") + } + } +} diff --git a/go/tools/asthelpergen/clone_gen.go b/go/tools/asthelpergen/clone_gen.go new file mode 100644 index 00000000000..066a4c82923 --- /dev/null +++ b/go/tools/asthelpergen/clone_gen.go @@ -0,0 +1,360 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "fmt" + "go/types" + + "vitess.io/vitess/go/vt/log" + + "github.com/dave/jennifer/jen" +) + +// cloneGen creates the deep clone methods for the AST. It works by discovering the types that it needs to support, +// starting from a root interface type. While creating the clone method for this root interface, more types that need +// to be cloned are discovered. This continues type by type until all necessary types have been traversed. +type cloneGen struct { + methods []jen.Code + iface *types.Interface + scope *types.Scope + todo []types.Type + exceptType string +} + +var _ generator = (*cloneGen)(nil) + +func newCloneGen(iface *types.Interface, scope *types.Scope, exceptType string) *cloneGen { + return &cloneGen{ + iface: iface, + scope: scope, + exceptType: exceptType, + } +} + +func (c *cloneGen) visitStruct(types.Type, *types.Struct) error { + return nil +} + +func (c *cloneGen) visitSlice(types.Type, *types.Slice) error { + return nil +} + +func (c *cloneGen) visitInterface(t types.Type, _ *types.Interface) error { + c.todo = append(c.todo, t) + return nil +} + +const cloneName = "Clone" + +func (c *cloneGen) addFunc(name string, code jen.Code) { + c.methods = append(c.methods, jen.Comment(name+" creates a deep clone of the input."), code) +} + +// readValueOfType produces code to read the expression of type `t`, and adds the type to the todo-list +func (c *cloneGen) readValueOfType(t types.Type, expr jen.Code) jen.Code { + switch t.Underlying().(type) { + case *types.Basic: + return expr + case *types.Interface: + if types.TypeString(t, noQualifier) == "interface{}" { + // these fields have to be taken care of manually + return expr + } + } + c.todo = append(c.todo, t) + return jen.Id(cloneName + printableTypeName(t)).Call(expr) +} + +func (c *cloneGen) makeStructCloneMethod(t types.Type) error { + receiveType := types.TypeString(t, noQualifier) + funcName := "Clone" + printableTypeName(t) + c.addFunc(funcName, + jen.Func().Id(funcName).Call(jen.Id("n").Id(receiveType)).Id(receiveType).Block( + jen.Return(jen.Op("*").Add(c.readValueOfType(types.NewPointer(t), jen.Op("&").Id("n")))), + )) + return nil +} + +func (c *cloneGen) makeSliceCloneMethod(t types.Type, slice *types.Slice) error { + typeString := types.TypeString(t, noQualifier) + name := printableTypeName(t) + funcName := cloneName + name + + c.addFunc(funcName, + //func (n Bytes) Clone() Bytes { + jen.Func().Id(funcName).Call(jen.Id("n").Id(typeString)).Id(typeString).Block( + // res := make(Bytes, len(n)) + jen.Id("res").Op(":=").Id("make").Call(jen.Id(typeString), jen.Lit(0), jen.Id("len").Call(jen.Id("n"))), + c.copySliceElement(slice.Elem()), + // return res + jen.Return(jen.Id("res")), + )) + return nil +} + +func (c *cloneGen) copySliceElement(elType types.Type) jen.Code { + if isBasic(elType) { + // copy(res, n) + return jen.Id("copy").Call(jen.Id("res"), jen.Id("n")) + } + + //for _, x := range n { + // res = append(res, CloneAST(x)) + //} + c.todo = append(c.todo, elType) + return jen.For(jen.List(jen.Op("_"), jen.Id("x"))).Op(":=").Range().Id("n").Block( + jen.Id("res").Op("=").Id("append").Call(jen.Id("res"), c.readValueOfType(elType, jen.Id("x"))), + ) +} + +func (c *cloneGen) makeInterfaceCloneMethod(t types.Type, iface *types.Interface) error { + + //func CloneAST(in AST) AST { + // if in == nil { + // return nil + //} + // switch in := in.(type) { + //case *RefContainer: + // return in.CloneRefOfRefContainer() + //} + // // this should never happen + // return nil + //} + + typeString := types.TypeString(t, noQualifier) + typeName := printableTypeName(t) + + stmts := []jen.Code{ifNilReturnNil("in")} + + var cases []jen.Code + _ = findImplementations(c.scope, iface, func(t types.Type) error { + typeString := types.TypeString(t, noQualifier) + + // case Type: return CloneType(in) + block := jen.Case(jen.Id(typeString)).Block(jen.Return(c.readValueOfType(t, jen.Id("in")))) + switch t := t.(type) { + case *types.Pointer: + _, isIface := t.Elem().(*types.Interface) + if !isIface { + cases = append(cases, block) + } + + case *types.Named: + _, isIface := t.Underlying().(*types.Interface) + if !isIface { + cases = append(cases, block) + } + + default: + log.Errorf("unexpected type encountered: %s", typeString) + } + + return nil + }) + + cases = append(cases, + jen.Default().Block( + jen.Comment("this should never happen"), + jen.Return(jen.Nil()), + )) + + // switch n := node.(type) { + stmts = append(stmts, jen.Switch(jen.Id("in").Op(":=").Id("in").Assert(jen.Id("type")).Block( + cases..., + ))) + + funcName := cloneName + typeName + funcDecl := jen.Func().Id(funcName).Call(jen.Id("in").Id(typeString)).Id(typeString).Block(stmts...) + c.addFunc(funcName, funcDecl) + return nil +} + +func (c *cloneGen) makePtrCloneMethod(t types.Type, ptr *types.Pointer) error { + receiveType := types.TypeString(t, noQualifier) + + funcName := "Clone" + printableTypeName(t) + c.addFunc(funcName, + jen.Func().Id(funcName).Call(jen.Id("n").Id(receiveType)).Id(receiveType).Block( + ifNilReturnNil("n"), + jen.Id("out").Op(":=").Add(c.readValueOfType(ptr.Elem(), jen.Op("*").Id("n"))), + jen.Return(jen.Op("&").Id("out")), + )) + + return nil +} + +func (c *cloneGen) createFile(pkgName string) (string, *jen.File) { + out := jen.NewFile(pkgName) + out.HeaderComment(licenseFileHeader) + out.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.") + alreadyDone := map[string]bool{} + for len(c.todo) > 0 { + t := c.todo[0] + underlying := t.Underlying() + typeName := printableTypeName(t) + c.todo = c.todo[1:] + + if alreadyDone[typeName] { + continue + } + + if c.tryInterface(underlying, t) || + c.trySlice(underlying, t) || + c.tryStruct(underlying, t) || + c.tryPtr(underlying, t) { + alreadyDone[typeName] = true + continue + } + + log.Errorf("don't know how to handle %s %T", typeName, underlying) + } + + for _, method := range c.methods { + out.Add(method) + } + + return "clone.go", out +} + +func ifNilReturnNil(id string) *jen.Statement { + return jen.If(jen.Id(id).Op("==").Nil()).Block(jen.Return(jen.Nil())) +} + +func isBasic(t types.Type) bool { + _, x := t.Underlying().(*types.Basic) + return x +} + +func (c *cloneGen) tryStruct(underlying, t types.Type) bool { + _, ok := underlying.(*types.Struct) + if !ok { + return false + } + + err := c.makeStructCloneMethod(t) + if err != nil { + panic(err) // todo + } + return true +} +func (c *cloneGen) tryPtr(underlying, t types.Type) bool { + ptr, ok := underlying.(*types.Pointer) + if !ok { + return false + } + + if strct, isStruct := ptr.Elem().Underlying().(*types.Struct); isStruct { + c.makePtrToStructCloneMethod(t, strct) + return true + } + + err := c.makePtrCloneMethod(t, ptr) + if err != nil { + panic(err) // todo + } + return true +} + +func (c *cloneGen) makePtrToStructCloneMethod(t types.Type, strct *types.Struct) { + receiveType := types.TypeString(t, noQualifier) + funcName := "Clone" + printableTypeName(t) + + //func CloneRefOfType(n *Type) *Type + funcDeclaration := jen.Func().Id(funcName).Call(jen.Id("n").Id(receiveType)).Id(receiveType) + + if receiveType == c.exceptType { + c.addFunc(funcName, funcDeclaration.Block( + jen.Return(jen.Id("n")), + )) + return + } + + var fields []jen.Code + for i := 0; i < strct.NumFields(); i++ { + field := strct.Field(i) + if isBasic(field.Type()) || field.Name() == "_" { + continue + } + // out.Field = CloneType(n.Field) + fields = append(fields, + jen.Id("out").Dot(field.Name()).Op("=").Add(c.readValueOfType(field.Type(), jen.Id("n").Dot(field.Name())))) + } + + stmts := []jen.Code{ + // if n == nil { return nil } + ifNilReturnNil("n"), + // out := *n + jen.Id("out").Op(":=").Op("*").Id("n"), + } + + // handle all fields with CloneAble types + stmts = append(stmts, fields...) + + stmts = append(stmts, + // return &out + jen.Return(jen.Op("&").Id("out")), + ) + + c.addFunc(funcName, + funcDeclaration.Block(stmts...), + ) +} + +func (c *cloneGen) tryInterface(underlying, t types.Type) bool { + iface, ok := underlying.(*types.Interface) + if !ok { + return false + } + + err := c.makeInterfaceCloneMethod(t, iface) + if err != nil { + panic(err) // todo + } + return true +} + +func (c *cloneGen) trySlice(underlying, t types.Type) bool { + slice, ok := underlying.(*types.Slice) + if !ok { + return false + } + + err := c.makeSliceCloneMethod(t, slice) + if err != nil { + panic(err) // todo + } + return true +} + +// printableTypeName returns a string that can be used as a valid golang identifier +func printableTypeName(t types.Type) string { + switch t := t.(type) { + case *types.Pointer: + return "RefOf" + printableTypeName(t.Elem()) + case *types.Slice: + return "SliceOf" + printableTypeName(t.Elem()) + case *types.Named: + return t.Obj().Name() + case *types.Basic: + return t.Name() + case *types.Interface: + return t.String() + default: + panic(fmt.Sprintf("unknown type %T %v", t, t)) + } +} diff --git a/go/tools/asthelpergen/integration/clone.go b/go/tools/asthelpergen/integration/clone.go new file mode 100644 index 00000000000..84d1c2b7537 --- /dev/null +++ b/go/tools/asthelpergen/integration/clone.go @@ -0,0 +1,213 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +// Code generated by ASTHelperGen. DO NOT EDIT. + +package integration + +// CloneAST creates a deep clone of the input. +func CloneAST(in AST) AST { + if in == nil { + return nil + } + switch in := in.(type) { + case BasicType: + return in + case Bytes: + return CloneBytes(in) + case InterfaceContainer: + return CloneInterfaceContainer(in) + case InterfaceSlice: + return CloneInterfaceSlice(in) + case *Leaf: + return CloneRefOfLeaf(in) + case LeafSlice: + return CloneLeafSlice(in) + case *NoCloneType: + return CloneRefOfNoCloneType(in) + case *RefContainer: + return CloneRefOfRefContainer(in) + case *RefSliceContainer: + return CloneRefOfRefSliceContainer(in) + case *SubImpl: + return CloneRefOfSubImpl(in) + case ValueContainer: + return CloneValueContainer(in) + case ValueSliceContainer: + return CloneValueSliceContainer(in) + default: + // this should never happen + return nil + } +} + +// CloneSubIface creates a deep clone of the input. +func CloneSubIface(in SubIface) SubIface { + if in == nil { + return nil + } + switch in := in.(type) { + case *SubImpl: + return CloneRefOfSubImpl(in) + default: + // this should never happen + return nil + } +} + +// CloneBytes creates a deep clone of the input. +func CloneBytes(n Bytes) Bytes { + res := make(Bytes, 0, len(n)) + copy(res, n) + return res +} + +// CloneInterfaceContainer creates a deep clone of the input. +func CloneInterfaceContainer(n InterfaceContainer) InterfaceContainer { + return *CloneRefOfInterfaceContainer(&n) +} + +// CloneInterfaceSlice creates a deep clone of the input. +func CloneInterfaceSlice(n InterfaceSlice) InterfaceSlice { + res := make(InterfaceSlice, 0, len(n)) + for _, x := range n { + res = append(res, CloneAST(x)) + } + return res +} + +// CloneRefOfLeaf creates a deep clone of the input. +func CloneRefOfLeaf(n *Leaf) *Leaf { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneLeafSlice creates a deep clone of the input. +func CloneLeafSlice(n LeafSlice) LeafSlice { + res := make(LeafSlice, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfLeaf(x)) + } + return res +} + +// CloneRefOfNoCloneType creates a deep clone of the input. +func CloneRefOfNoCloneType(n *NoCloneType) *NoCloneType { + return n +} + +// CloneRefOfRefContainer creates a deep clone of the input. +func CloneRefOfRefContainer(n *RefContainer) *RefContainer { + if n == nil { + return nil + } + out := *n + out.ASTType = CloneAST(n.ASTType) + out.ASTImplementationType = CloneRefOfLeaf(n.ASTImplementationType) + return &out +} + +// CloneRefOfRefSliceContainer creates a deep clone of the input. +func CloneRefOfRefSliceContainer(n *RefSliceContainer) *RefSliceContainer { + if n == nil { + return nil + } + out := *n + out.ASTElements = CloneSliceOfAST(n.ASTElements) + out.NotASTElements = CloneSliceOfint(n.NotASTElements) + out.ASTImplementationElements = CloneSliceOfRefOfLeaf(n.ASTImplementationElements) + return &out +} + +// CloneRefOfSubImpl creates a deep clone of the input. +func CloneRefOfSubImpl(n *SubImpl) *SubImpl { + if n == nil { + return nil + } + out := *n + out.inner = CloneSubIface(n.inner) + return &out +} + +// CloneValueContainer creates a deep clone of the input. +func CloneValueContainer(n ValueContainer) ValueContainer { + return *CloneRefOfValueContainer(&n) +} + +// CloneValueSliceContainer creates a deep clone of the input. +func CloneValueSliceContainer(n ValueSliceContainer) ValueSliceContainer { + return *CloneRefOfValueSliceContainer(&n) +} + +// CloneRefOfInterfaceContainer creates a deep clone of the input. +func CloneRefOfInterfaceContainer(n *InterfaceContainer) *InterfaceContainer { + if n == nil { + return nil + } + out := *n + out.v = n.v + return &out +} + +// CloneSliceOfAST creates a deep clone of the input. +func CloneSliceOfAST(n []AST) []AST { + res := make([]AST, 0, len(n)) + for _, x := range n { + res = append(res, CloneAST(x)) + } + return res +} + +// CloneSliceOfint creates a deep clone of the input. +func CloneSliceOfint(n []int) []int { + res := make([]int, 0, len(n)) + copy(res, n) + return res +} + +// CloneSliceOfRefOfLeaf creates a deep clone of the input. +func CloneSliceOfRefOfLeaf(n []*Leaf) []*Leaf { + res := make([]*Leaf, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfLeaf(x)) + } + return res +} + +// CloneRefOfValueContainer creates a deep clone of the input. +func CloneRefOfValueContainer(n *ValueContainer) *ValueContainer { + if n == nil { + return nil + } + out := *n + out.ASTType = CloneAST(n.ASTType) + out.ASTImplementationType = CloneRefOfLeaf(n.ASTImplementationType) + return &out +} + +// CloneRefOfValueSliceContainer creates a deep clone of the input. +func CloneRefOfValueSliceContainer(n *ValueSliceContainer) *ValueSliceContainer { + if n == nil { + return nil + } + out := *n + out.ASTElements = CloneSliceOfAST(n.ASTElements) + out.NotASTElements = CloneSliceOfint(n.NotASTElements) + out.ASTImplementationElements = CloneSliceOfRefOfLeaf(n.ASTImplementationElements) + return &out +} diff --git a/go/tools/asthelpergen/integration/integration_clone_test.go b/go/tools/asthelpergen/integration/integration_clone_test.go new file mode 100644 index 00000000000..f7adf9e7eef --- /dev/null +++ b/go/tools/asthelpergen/integration/integration_clone_test.go @@ -0,0 +1,66 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package integration + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCloneLeaf(t *testing.T) { + leaf1 := &Leaf{1} + clone := CloneRefOfLeaf(leaf1) + assert.Equal(t, leaf1, clone) + leaf1.v = 5 + assert.NotEqual(t, leaf1, clone) +} + +func TestClone2(t *testing.T) { + container := &RefContainer{ + ASTType: &RefContainer{}, + NotASTType: 0, + ASTImplementationType: &Leaf{2}, + } + clone := CloneRefOfRefContainer(container) + assert.Equal(t, container, clone) + container.ASTImplementationType.v = 5 + assert.NotEqual(t, container, clone) +} + +func TestTypeException(t *testing.T) { + l1 := &Leaf{1} + nc := &NoCloneType{1} + + slice := InterfaceSlice{ + l1, + nc, + } + + clone := CloneAST(slice) + + // change the original values + l1.v = 99 + nc.v = 99 + + expected := InterfaceSlice{ + &Leaf{1}, // the change is not seen + &NoCloneType{99}, // since this type is not cloned, we do see the change + } + + assert.Equal(t, expected, clone) +} diff --git a/go/tools/asthelpergen/integration/integration_rewriter_test.go b/go/tools/asthelpergen/integration/integration_rewriter_test.go new file mode 100644 index 00000000000..1189974a79c --- /dev/null +++ b/go/tools/asthelpergen/integration/integration_rewriter_test.go @@ -0,0 +1,389 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package integration + +import ( + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestVisitRefContainer(t *testing.T) { + leaf1 := &Leaf{1} + leaf2 := &Leaf{2} + container := &RefContainer{ASTType: leaf1, ASTImplementationType: leaf2} + containerContainer := &RefContainer{ASTType: container} + + tv := &testVisitor{} + + Rewrite(containerContainer, tv.pre, tv.post) + + expected := []step{ + Pre{containerContainer}, + Pre{container}, + Pre{leaf1}, + Post{leaf1}, + Pre{leaf2}, + Post{leaf2}, + Post{container}, + Post{containerContainer}, + } + tv.assertEquals(t, expected) +} + +func TestVisitValueContainer(t *testing.T) { + leaf1 := &Leaf{1} + leaf2 := &Leaf{2} + container := ValueContainer{ASTType: leaf1, ASTImplementationType: leaf2} + containerContainer := ValueContainer{ASTType: container} + + tv := &testVisitor{} + + Rewrite(containerContainer, tv.pre, tv.post) + + expected := []step{ + Pre{containerContainer}, + Pre{container}, + Pre{leaf1}, + Post{leaf1}, + Pre{leaf2}, + Post{leaf2}, + Post{container}, + Post{containerContainer}, + } + tv.assertEquals(t, expected) +} + +func TestVisitRefSliceContainer(t *testing.T) { + leaf1 := &Leaf{1} + leaf2 := &Leaf{2} + leaf3 := &Leaf{3} + leaf4 := &Leaf{4} + container := &RefSliceContainer{ASTElements: []AST{leaf1, leaf2}, ASTImplementationElements: []*Leaf{leaf3, leaf4}} + containerContainer := &RefSliceContainer{ASTElements: []AST{container}} + + tv := &testVisitor{} + + Rewrite(containerContainer, tv.pre, tv.post) + + tv.assertEquals(t, []step{ + Pre{containerContainer}, + Pre{container}, + Pre{leaf1}, + Post{leaf1}, + Pre{leaf2}, + Post{leaf2}, + Pre{leaf3}, + Post{leaf3}, + Pre{leaf4}, + Post{leaf4}, + Post{container}, + Post{containerContainer}, + }) +} + +func TestVisitValueSliceContainer(t *testing.T) { + leaf1 := &Leaf{1} + leaf2 := &Leaf{2} + leaf3 := &Leaf{3} + leaf4 := &Leaf{4} + container := ValueSliceContainer{ASTElements: []AST{leaf1, leaf2}, ASTImplementationElements: []*Leaf{leaf3, leaf4}} + containerContainer := ValueSliceContainer{ASTElements: []AST{container}} + + tv := &testVisitor{} + + Rewrite(containerContainer, tv.pre, tv.post) + + tv.assertEquals(t, []step{ + Pre{containerContainer}, + Pre{container}, + Pre{leaf1}, + Post{leaf1}, + Pre{leaf2}, + Post{leaf2}, + Pre{leaf3}, + Post{leaf3}, + Pre{leaf4}, + Post{leaf4}, + Post{container}, + Post{containerContainer}, + }) +} + +func TestVisitInterfaceSlice(t *testing.T) { + leaf1 := &Leaf{2} + astType := &RefContainer{NotASTType: 12} + implementationType := &Leaf{2} + + leaf2 := &Leaf{3} + refContainer := &RefContainer{ + ASTType: astType, + ASTImplementationType: implementationType, + } + ast := InterfaceSlice{ + refContainer, + leaf1, + leaf2, + } + + tv := &testVisitor{} + + Rewrite(ast, tv.pre, tv.post) + + tv.assertEquals(t, []step{ + Pre{ast}, + Pre{refContainer}, + Pre{astType}, + Post{astType}, + Pre{implementationType}, + Post{implementationType}, + Post{refContainer}, + Pre{leaf1}, + Post{leaf1}, + Pre{leaf2}, + Post{leaf2}, + Post{ast}, + }) +} + +func TestVisitRefContainerReplace(t *testing.T) { + ast := &RefContainer{ + ASTType: &RefContainer{NotASTType: 12}, + ASTImplementationType: &Leaf{2}, + } + + // rewrite field of type AST + Rewrite(ast, func(cursor *Cursor) bool { + leaf, ok := cursor.node.(*RefContainer) + if ok && leaf.NotASTType == 12 { + cursor.Replace(&Leaf{99}) + } + return true + }, nil) + + assert.Equal(t, &RefContainer{ + ASTType: &Leaf{99}, + ASTImplementationType: &Leaf{2}, + }, ast) + + Rewrite(ast, rewriteLeaf(2, 55), nil) + + assert.Equal(t, &RefContainer{ + ASTType: &Leaf{99}, + ASTImplementationType: &Leaf{55}, + }, ast) +} + +func TestVisitValueContainerReplace(t *testing.T) { + ast := ValueContainer{ + ASTType: ValueContainer{NotASTType: 12}, + ASTImplementationType: &Leaf{2}, + } + + defer func() { + if r := recover(); r != nil { + assert.Contains(t, r, "ValueContainer ASTType") + } + }() + + Rewrite(ast, func(cursor *Cursor) bool { + leaf, ok := cursor.node.(ValueContainer) + if ok && leaf.NotASTType == 12 { + cursor.Replace(&Leaf{99}) + } + return true + }, nil) + + t.Fatalf("should not get here") +} + +func TestVisitValueContainerReplace2(t *testing.T) { + ast := ValueContainer{ + ASTType: ValueContainer{NotASTType: 12}, + ASTImplementationType: &Leaf{2}, + } + + defer func() { + if r := recover(); r != nil { + assert.Contains(t, r, "ValueContainer ASTImplementationType") + } + }() + + Rewrite(ast, rewriteLeaf(2, 10), nil) + + t.Fatalf("should not get here") +} + +func rewriteLeaf(from, to int) func(*Cursor) bool { + return func(cursor *Cursor) bool { + leaf, ok := cursor.node.(*Leaf) + if ok && leaf.v == from { + cursor.Replace(&Leaf{to}) + } + return true + } +} + +func TestRefSliceContainerReplace(t *testing.T) { + ast := &RefSliceContainer{ + ASTElements: []AST{&Leaf{1}, &Leaf{2}}, + ASTImplementationElements: []*Leaf{{3}, {4}}, + } + + Rewrite(ast, rewriteLeaf(2, 42), nil) + + assert.Equal(t, &RefSliceContainer{ + ASTElements: []AST{&Leaf{1}, &Leaf{42}}, + ASTImplementationElements: []*Leaf{{3}, {4}}, + }, ast) + + Rewrite(ast, rewriteLeaf(3, 88), nil) + + assert.Equal(t, &RefSliceContainer{ + ASTElements: []AST{&Leaf{1}, &Leaf{42}}, + ASTImplementationElements: []*Leaf{{88}, {4}}, + }, ast) +} + +type step interface { + String() string +} +type Pre struct { + el AST +} + +func (r Pre) String() string { + return fmt.Sprintf("Pre(%s)", r.el.String()) +} +func (r Post) String() string { + return fmt.Sprintf("Pre(%s)", r.el.String()) +} + +type Post struct { + el AST +} + +type testVisitor struct { + walk []step +} + +func (tv *testVisitor) pre(cursor *Cursor) bool { + tv.walk = append(tv.walk, Pre{el: cursor.Node()}) + return true +} +func (tv *testVisitor) post(cursor *Cursor) bool { + tv.walk = append(tv.walk, Post{el: cursor.Node()}) + return true +} +func (tv *testVisitor) assertEquals(t *testing.T, expected []step) { + t.Helper() + var lines []string + error := false + expectedSize := len(expected) + for i, step := range tv.walk { + if expectedSize <= i { + t.Errorf("❌️ - Expected less elements %v", tv.walk[i:]) + break + } else { + e := expected[i] + if reflect.DeepEqual(e, step) { + a := "✔️ - " + e.String() + if error { + fmt.Println(a) + } else { + lines = append(lines, a) + } + } else { + if !error { + // first error we see. + error = true + for _, line := range lines { + fmt.Println(line) + } + } + t.Errorf("❌️ - Expected: %s Got: %s\n", e.String(), step.String()) + } + } + } + walkSize := len(tv.walk) + if expectedSize > walkSize { + t.Errorf("❌️ - Expected more elements %v", expected[walkSize:]) + } + +} + +// below follows two different ways of creating the replacement method for slices, and benchmark +// between them. Diff seems to be very small, so I'll use the most readable form +type replaceA int + +func (r *replaceA) replace(newNode, container AST) { + container.(InterfaceSlice)[int(*r)] = newNode.(AST) +} + +func (r *replaceA) inc() { + *r++ +} + +func replaceB(idx int) func(AST, AST) { + return func(newNode, container AST) { + container.(InterfaceSlice)[idx] = newNode.(AST) + } +} + +func BenchmarkSliceReplacerA(b *testing.B) { + islice := make(InterfaceSlice, 20) + for i := range islice { + islice[i] = &Leaf{i} + } + a := &application{ + pre: func(c *Cursor) bool { + return true + }, + post: nil, + cursor: Cursor{}, + } + + for i := 0; i < b.N; i++ { + replacer := replaceA(0) + for _, el := range islice { + a.apply(islice, el, replacer.replace) + replacer.inc() + } + } +} + +func BenchmarkSliceReplacerB(b *testing.B) { + islice := make(InterfaceSlice, 20) + for i := range islice { + islice[i] = &Leaf{i} + } + a := &application{ + pre: func(c *Cursor) bool { + return true + }, + post: nil, + cursor: Cursor{}, + } + + for i := 0; i < b.N; i++ { + for x, el := range islice { + a.apply(islice, el, replaceB(x)) + } + } +} diff --git a/go/tools/asthelpergen/integration/rewriter.go b/go/tools/asthelpergen/integration/rewriter.go new file mode 100644 index 00000000000..ace51dc2937 --- /dev/null +++ b/go/tools/asthelpergen/integration/rewriter.go @@ -0,0 +1,90 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +// Code generated by ASTHelperGen. DO NOT EDIT. + +package integration + +func (a *application) apply(parent, node AST, replacer replacerFunc) { + if node == nil || isNilValue(node) { + return + } + saved := a.cursor + a.cursor.replacer = replacer + a.cursor.node = node + a.cursor.parent = parent + if a.pre != nil && !a.pre(&a.cursor) { + a.cursor = saved + return + } + switch n := node.(type) { + case Bytes: + case InterfaceContainer: + case InterfaceSlice: + for x, el := range n { + a.apply(node, el, func(newNode, container AST) { + container.(InterfaceSlice)[x] = newNode.(AST) + }) + } + case *Leaf: + case LeafSlice: + for x, el := range n { + a.apply(node, el, func(newNode, container AST) { + container.(LeafSlice)[x] = newNode.(*Leaf) + }) + } + case *NoCloneType: + case *RefContainer: + a.apply(node, n.ASTType, func(newNode, parent AST) { + parent.(*RefContainer).ASTType = newNode.(AST) + }) + a.apply(node, n.ASTImplementationType, func(newNode, parent AST) { + parent.(*RefContainer).ASTImplementationType = newNode.(*Leaf) + }) + case *RefSliceContainer: + for x, el := range n.ASTElements { + a.apply(node, el, func(newNode, container AST) { + container.(*RefSliceContainer).ASTElements[x] = newNode.(AST) + }) + } + for x, el := range n.ASTImplementationElements { + a.apply(node, el, func(newNode, container AST) { + container.(*RefSliceContainer).ASTImplementationElements[x] = newNode.(*Leaf) + }) + } + case *SubImpl: + a.apply(node, n.inner, func(newNode, parent AST) { + parent.(*SubImpl).inner = newNode.(SubIface) + }) + case ValueContainer: + a.apply(node, n.ASTType, replacePanic("ValueContainer ASTType")) + a.apply(node, n.ASTImplementationType, replacePanic("ValueContainer ASTImplementationType")) + case ValueSliceContainer: + for x, el := range n.ASTElements { + a.apply(node, el, func(newNode, container AST) { + container.(ValueSliceContainer).ASTElements[x] = newNode.(AST) + }) + } + for x, el := range n.ASTImplementationElements { + a.apply(node, el, func(newNode, container AST) { + container.(ValueSliceContainer).ASTImplementationElements[x] = newNode.(*Leaf) + }) + } + } + if a.post != nil && !a.post(&a.cursor) { + panic(abort) + } + a.cursor = saved +} diff --git a/go/tools/asthelpergen/integration/test_helpers.go b/go/tools/asthelpergen/integration/test_helpers.go new file mode 100644 index 00000000000..3a2da19be80 --- /dev/null +++ b/go/tools/asthelpergen/integration/test_helpers.go @@ -0,0 +1,97 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package integration + +import ( + "reflect" + "strings" +) + +// ast type helpers + +func sliceStringAST(els ...AST) string { + result := make([]string, len(els)) + for i, el := range els { + result[i] = el.String() + } + return strings.Join(result, ", ") +} +func sliceStringLeaf(els ...*Leaf) string { + result := make([]string, len(els)) + for i, el := range els { + result[i] = el.String() + } + return strings.Join(result, ", ") +} + +// the methods below are what the generated code expected to be there in the package + +type application struct { + pre, post ApplyFunc + cursor Cursor +} + +type ApplyFunc func(*Cursor) bool + +type Cursor struct { + parent AST + replacer replacerFunc + node AST +} + +// Node returns the current Node. +func (c *Cursor) Node() AST { return c.node } + +// Parent returns the parent of the current Node. +func (c *Cursor) Parent() AST { return c.parent } + +// Replace replaces the current node in the parent field with this new object. The use needs to make sure to not +// replace the object with something of the wrong type, or the visitor will panic. +func (c *Cursor) Replace(newNode AST) { + c.replacer(newNode, c.parent) + c.node = newNode +} + +type replacerFunc func(newNode, parent AST) + +func isNilValue(i interface{}) bool { + valueOf := reflect.ValueOf(i) + kind := valueOf.Kind() + isNullable := kind == reflect.Ptr || kind == reflect.Array || kind == reflect.Slice + return isNullable && valueOf.IsNil() +} + +var abort = new(int) // singleton, to signal termination of Apply + +func Rewrite(node AST, pre, post ApplyFunc) (result AST) { + parent := &struct{ AST }{node} + + a := &application{ + pre: pre, + post: post, + cursor: Cursor{}, + } + + a.apply(parent.AST, node, nil) + return parent.AST +} + +func replacePanic(msg string) func(newNode, parent AST) { + return func(newNode, parent AST) { + panic("Tried replacing a field of a value type. This is not supported. " + msg) + } +} diff --git a/go/tools/asthelpergen/integration/types.go b/go/tools/asthelpergen/integration/types.go new file mode 100644 index 00000000000..1a89d78c3a2 --- /dev/null +++ b/go/tools/asthelpergen/integration/types.go @@ -0,0 +1,172 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//nolint +package integration + +import ( + "fmt" + "strings" +) + +/* +These types are used to test the rewriter generator against these types. +To recreate them, just run: + +go run go/tools/asthelpergen -in ./go/tools/asthelpergen/integration -iface vitess.io/vitess/go/tools/asthelpergen/integration.AST +*/ +// AST is the interface all interface types implement +type AST interface { + String() string +} + +// Empty struct impl of the iface +type Leaf struct { + v int +} + +func (l *Leaf) String() string { + if l == nil { + return "nil" + } + return fmt.Sprintf("Leaf(%d)", l.v) +} + +// Container implements the interface ByRef +type RefContainer struct { + ASTType AST + NotASTType int + ASTImplementationType *Leaf +} + +func (r *RefContainer) String() string { + if r == nil { + return "nil" + } + var astType = "" + if r.ASTType == nil { + astType = "nil" + } else { + astType = r.ASTType.String() + } + return fmt.Sprintf("RefContainer{%s, %d, %s}", astType, r.NotASTType, r.ASTImplementationType.String()) +} + +// Container implements the interface ByRef +type RefSliceContainer struct { + ASTElements []AST + NotASTElements []int + ASTImplementationElements []*Leaf +} + +func (r *RefSliceContainer) String() string { + return fmt.Sprintf("RefSliceContainer{%s, %s, %s}", sliceStringAST(r.ASTElements...), "r.NotASTType", sliceStringLeaf(r.ASTImplementationElements...)) +} + +// Container implements the interface ByValue +type ValueContainer struct { + ASTType AST + NotASTType int + ASTImplementationType *Leaf +} + +func (r ValueContainer) String() string { + return fmt.Sprintf("ValueContainer{%s, %d, %s}", r.ASTType.String(), r.NotASTType, r.ASTImplementationType.String()) +} + +// Container implements the interface ByValue +type ValueSliceContainer struct { + ASTElements []AST + NotASTElements []int + ASTImplementationElements []*Leaf +} + +func (r ValueSliceContainer) String() string { + return fmt.Sprintf("ValueSliceContainer{%s, %s, %s}", sliceStringAST(r.ASTElements...), "r.NotASTType", sliceStringLeaf(r.ASTImplementationElements...)) +} + +// We need to support these types - a slice of AST elements can implement the interface +type InterfaceSlice []AST + +func (r InterfaceSlice) String() string { + var elements []string + for _, el := range r { + elements = append(elements, el.String()) + } + + return "[" + strings.Join(elements, ", ") + "]" +} + +// We need to support these types - a slice of AST elements can implement the interface +type Bytes []byte + +func (r Bytes) String() string { + return string(r) +} + +type LeafSlice []*Leaf + +func (r LeafSlice) String() string { + var elements []string + for _, el := range r { + elements = append(elements, el.String()) + } + return strings.Join(elements, ", ") +} + +type BasicType int + +func (r BasicType) String() string { + return fmt.Sprintf("int(%d)", r) +} + +const ( + // these consts are here to try to trick the generator + thisIsNotAType BasicType = 1 + thisIsNotAType2 BasicType = 2 +) + +// We want to support all types that are used as field types, which can include interfaces. +// Example would be sqlparser.Expr that implements sqlparser.SQLNode +type SubIface interface { + AST + iface() +} + +type SubImpl struct { + inner SubIface +} + +func (r *SubImpl) String() string { + return "SubImpl" +} +func (r *SubImpl) iface() {} + +type InterfaceContainer struct { + v interface{} +} + +func (r InterfaceContainer) String() string { + return fmt.Sprintf("%v", r.v) +} + +type NoCloneType struct { + v int +} + +func (r *NoCloneType) String() string { + return fmt.Sprintf("NoClone(%d)", r.v) +} diff --git a/go/tools/asthelpergen/rewriter_gen.go b/go/tools/asthelpergen/rewriter_gen.go new file mode 100644 index 00000000000..a9e23a98a02 --- /dev/null +++ b/go/tools/asthelpergen/rewriter_gen.go @@ -0,0 +1,208 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "go/types" + + "github.com/dave/jennifer/jen" +) + +type rewriterGen struct { + cases []jen.Code + interestingType func(types.Type) bool + ifaceName string +} + +func newRewriterGen(f func(types.Type) bool, name string) *rewriterGen { + return &rewriterGen{interestingType: f, ifaceName: name} +} + +var noQualifier = func(p *types.Package) string { + return "" +} + +func (r *rewriterGen) visitStruct(t types.Type, stroct *types.Struct) error { + typeString := types.TypeString(t, noQualifier) + typeName := printableTypeName(t) + var caseStmts []jen.Code + for i := 0; i < stroct.NumFields(); i++ { + field := stroct.Field(i) + if r.interestingType(field.Type()) { + if _, ok := t.(*types.Pointer); ok { + function := r.createReplaceMethod(typeString, field) + caseStmts = append(caseStmts, caseStmtFor(field, function)) + } else { + caseStmts = append(caseStmts, casePanicStmtFor(field, typeName+" "+field.Name())) + } + } + sliceT, ok := field.Type().(*types.Slice) + if ok && r.interestingType(sliceT.Elem()) { // we have a field containing a slice of interesting elements + function := r.createReplacementMethod(t, sliceT.Elem(), jen.Dot(field.Name())) + caseStmts = append(caseStmts, caseStmtForSliceField(field, function)) + } + } + r.cases = append(r.cases, jen.Case(jen.Id(typeString)).Block(caseStmts...)) + return nil +} + +func (r *rewriterGen) visitInterface(types.Type, *types.Interface) error { + return nil // rewriter doesn't deal with interfaces +} + +func (r *rewriterGen) visitSlice(t types.Type, slice *types.Slice) error { + typeString := types.TypeString(t, noQualifier) + + var stmts []jen.Code + if r.interestingType(slice.Elem()) { + function := r.createReplacementMethod(t, slice.Elem(), jen.Empty()) + stmts = append(stmts, caseStmtForSlice(function)) + } + r.cases = append(r.cases, jen.Case(jen.Id(typeString)).Block(stmts...)) + return nil +} + +func caseStmtFor(field *types.Var, expr jen.Code) *jen.Statement { + // a.apply(node, node.Field, replacerMethod) + return jen.Id("a").Dot("apply").Call(jen.Id("node"), jen.Id("n").Dot(field.Name()), expr) +} + +func casePanicStmtFor(field *types.Var, name string) *jen.Statement { + return jen.Id("a").Dot("apply").Call(jen.Id("node"), jen.Id("n").Dot(field.Name()), jen.Id("replacePanic").Call(jen.Lit(name))) +} + +func caseStmtForSlice(function *jen.Statement) jen.Code { + return jen.For(jen.List(jen.Op("x"), jen.Id("el"))).Op(":=").Range().Id("n").Block( + jen.Id("a").Dot("apply").Call( + jen.Id("node"), + jen.Id("el"), + function, + ), + ) +} + +func caseStmtForSliceField(field *types.Var, function *jen.Statement) jen.Code { + //for x, el := range n { + return jen.For(jen.List(jen.Op("x"), jen.Id("el"))).Op(":=").Range().Id("n").Dot(field.Name()).Block( + jen.Id("a").Dot("apply").Call( + // a.apply(node, el, replaceInterfaceSlice(x)) + jen.Id("node"), + jen.Id("el"), + function, + ), + ) +} + +func (r *rewriterGen) structCase(name string, stroct *types.Struct) (jen.Code, error) { + var stmts []jen.Code + for i := 0; i < stroct.NumFields(); i++ { + field := stroct.Field(i) + if r.interestingType(field.Type()) { + stmts = append(stmts, jen.Id("a").Dot("apply").Call(jen.Id("node"), jen.Id("n").Dot(field.Name()), jen.Nil())) + } + } + return jen.Case(jen.Op("*").Id(name)).Block(stmts...), nil +} + +func (r *rewriterGen) createReplaceMethod(structType string, field *types.Var) jen.Code { + return jen.Func().Params( + jen.Id("newNode"), + jen.Id("parent").Id(r.ifaceName), + ).Block( + jen.Id("parent").Assert(jen.Id(structType)).Dot(field.Name()).Op("=").Id("newNode").Assert(jen.Id(types.TypeString(field.Type(), noQualifier))), + ) +} + +func (r *rewriterGen) createReplacementMethod(container, elem types.Type, x jen.Code) *jen.Statement { + /* + func replacer(idx int) func(AST, AST) { + return func(newnode, container AST) { + container.(InterfaceSlice)[idx] = newnode.(AST) + } + } + + */ + return jen.Func().Params(jen.List(jen.Id("newNode"), jen.Id("container")).Id(r.ifaceName)).Block( + jen.Id("container").Assert(jen.Id(types.TypeString(container, noQualifier))).Add(x).Index(jen.Id("x")).Op("="). + Id("newNode").Assert(jen.Id(types.TypeString(elem, noQualifier))), + ) +} + +func (r *rewriterGen) createFile(pkgName string) (string, *jen.File) { + out := jen.NewFile(pkgName) + out.HeaderComment(licenseFileHeader) + out.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.") + + out.Add( + // func (a *application) apply(parent, node SQLNode, replacer replacerFunc) { + jen.Func().Params( + jen.Id("a").Op("*").Id("application"), + ).Id("apply").Params( + jen.Id("parent"), + jen.Id("node").Id(r.ifaceName), + jen.Id("replacer").Id("replacerFunc"), + ).Block( + /* + if node == nil || isNilValue(node) { + return + } + */ + jen.If( + jen.Id("node").Op("==").Nil().Op("||"). + Id("isNilValue").Call(jen.Id("node"))).Block( + jen.Return(), + ), + /* + saved := a.cursor + a.cursor.replacer = replacer + a.cursor.node = node + a.cursor.parent = parent + */ + jen.Id("saved").Op(":=").Id("a").Dot("cursor"), + jen.Id("a").Dot("cursor").Dot("replacer").Op("=").Id("replacer"), + jen.Id("a").Dot("cursor").Dot("node").Op("=").Id("node"), + jen.Id("a").Dot("cursor").Dot("parent").Op("=").Id("parent"), + jen.If( + jen.Id("a").Dot("pre").Op("!=").Nil().Op("&&"). + Op("!").Id("a").Dot("pre").Call(jen.Op("&").Id("a").Dot("cursor"))).Block( + jen.Id("a").Dot("cursor").Op("=").Id("saved"), + jen.Return(), + ), + + // switch n := node.(type) { + jen.Switch(jen.Id("n").Op(":=").Id("node").Assert(jen.Id("type")).Block( + r.cases..., + )), + + /* + if a.post != nil && !a.post(&a.cursor) { + panic(abort) + } + */ + jen.If( + jen.Id("a").Dot("post").Op("!=").Nil().Op("&&"). + Op("!").Id("a").Dot("post").Call(jen.Op("&").Id("a").Dot("cursor"))).Block( + jen.Id("panic").Call(jen.Id("abort")), + ), + + // a.cursor = saved + jen.Id("a").Dot("cursor").Op("=").Id("saved"), + ), + ) + + return "rewriter.go", out +} diff --git a/go/tools/sizegen/sizegen.go b/go/tools/sizegen/sizegen.go index 6281cd1485e..a02c2aa69a9 100644 --- a/go/tools/sizegen/sizegen.go +++ b/go/tools/sizegen/sizegen.go @@ -170,14 +170,6 @@ func findImplementations(scope *types.Scope, iff *types.Interface, impl func(typ } } -func (sizegen *sizegen) generateKnownInterface(pkg *types.Package, iff *types.Interface) { - findImplementations(pkg.Scope(), iff, func(tt types.Type) { - if named, ok := tt.(*types.Named); ok { - sizegen.generateKnownType(named) - } - }) -} - func (sizegen *sizegen) finalize() map[string]*jen.File { var complete bool diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 92c3300ccd1..a146d4b792d 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -1425,7 +1425,7 @@ type ( func (*StarExpr) iSelectExpr() {} func (*AliasedExpr) iSelectExpr() {} -func (Nextval) iSelectExpr() {} +func (*Nextval) iSelectExpr() {} // Columns represents an insert column list. type Columns []ColIdent @@ -1538,7 +1538,6 @@ type ( Expr interface { iExpr() SQLNode - Clone() Expr } // AndExpr represents an AND expression. @@ -2598,7 +2597,7 @@ func (node *AliasedExpr) Format(buf *TrackedBuffer) { } // Format formats the node. -func (node Nextval) Format(buf *TrackedBuffer) { +func (node *Nextval) Format(buf *TrackedBuffer) { buf.astPrintf(node, "next %v values", node.Expr) } diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 07fea6767d5..bd5773d438a 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -56,7 +56,10 @@ func Walk(visit Visit, nodes ...SQLNode) error { return err == nil // now we can abort the traversal if an error was found } - Rewrite(node, pre, post) + _, rewriterErr := Rewrite(node, pre, post) + if rewriterErr != nil { + return rewriterErr + } if err != nil { return err } @@ -391,10 +394,15 @@ func NewWhere(typ WhereType, expr Expr) *Where { // and replaces it with to. If from matches root, // then to is returned. func ReplaceExpr(root, from, to Expr) Expr { - tmp := Rewrite(root, replaceExpr(from, to), nil) + tmp, err := Rewrite(root, replaceExpr(from, to), nil) + if err != nil { + log.Errorf("Failed to rewrite expression. Rewriter returned an error: %s", err.Error()) + return from + } + expr, success := tmp.(Expr) if !success { - log.Errorf("Failed to rewrite expression. Rewriter returned a non-expression: " + String(tmp)) + log.Errorf("Failed to rewrite expression. Rewriter returned a non-expression: %s", String(tmp)) return from } @@ -1330,315 +1338,3 @@ const ( // DoubleAt represnts @@ DoubleAt ) - -func nilOrClone(in Expr) Expr { - if in == nil { - return nil - } - return in.Clone() -} - -// Clone implements the Expr interface -func (node *Subquery) Clone() Expr { - if node == nil { - return nil - } - panic("Subquery cloning not supported") -} - -// Clone implements the Expr interface -func (node *AndExpr) Clone() Expr { - if node == nil { - return nil - } - return &AndExpr{ - Left: nilOrClone(node.Left), - Right: nilOrClone(node.Right), - } -} - -// Clone implements the Expr interface -func (node *OrExpr) Clone() Expr { - if node == nil { - return nil - } - return &OrExpr{ - Left: nilOrClone(node.Left), - Right: nilOrClone(node.Right), - } -} - -// Clone implements the Expr interface -func (node *XorExpr) Clone() Expr { - if node == nil { - return nil - } - return &XorExpr{ - Left: nilOrClone(node.Left), - Right: nilOrClone(node.Right), - } -} - -// Clone implements the Expr interface -func (node *NotExpr) Clone() Expr { - if node == nil { - return nil - } - return &NotExpr{ - Expr: nilOrClone(node), - } -} - -// Clone implements the Expr interface -func (node *ComparisonExpr) Clone() Expr { - if node == nil { - return nil - } - return &ComparisonExpr{ - Operator: node.Operator, - Left: nilOrClone(node.Left), - Right: nilOrClone(node.Right), - Escape: nilOrClone(node.Escape), - } -} - -// Clone implements the Expr interface -func (node *RangeCond) Clone() Expr { - if node == nil { - return nil - } - return &RangeCond{ - Operator: node.Operator, - Left: nilOrClone(node.Left), - From: nilOrClone(node.From), - To: nilOrClone(node.To), - } -} - -// Clone implements the Expr interface -func (node *IsExpr) Clone() Expr { - if node == nil { - return nil - } - return &IsExpr{ - Operator: node.Operator, - Expr: nilOrClone(node.Expr), - } -} - -// Clone implements the Expr interface -func (node *ExistsExpr) Clone() Expr { - if node == nil { - return nil - } - return &ExistsExpr{ - Subquery: nilOrClone(node.Subquery).(*Subquery), - } -} - -// Clone implements the Expr interface -func (node *Literal) Clone() Expr { - if node == nil { - return nil - } - return &Literal{} -} - -// Clone implements the Expr interface -func (node Argument) Clone() Expr { - if node == nil { - return nil - } - cpy := make(Argument, len(node)) - copy(cpy, node) - return cpy -} - -// Clone implements the Expr interface -func (node *NullVal) Clone() Expr { - if node == nil { - return nil - } - return &NullVal{} -} - -// Clone implements the Expr interface -func (node BoolVal) Clone() Expr { - return node -} - -// Clone implements the Expr interface -func (node *ColName) Clone() Expr { - return node -} - -// Clone implements the Expr interface -func (node ValTuple) Clone() Expr { - if node == nil { - return nil - } - cpy := make(ValTuple, len(node)) - copy(cpy, node) - return cpy -} - -// Clone implements the Expr interface -func (node ListArg) Clone() Expr { - if node == nil { - return nil - } - cpy := make(ListArg, len(node)) - copy(cpy, node) - return cpy -} - -// Clone implements the Expr interface -func (node *BinaryExpr) Clone() Expr { - if node == nil { - return nil - } - return &BinaryExpr{ - Operator: node.Operator, - Left: nilOrClone(node.Left), - Right: nilOrClone(node.Right), - } -} - -// Clone implements the Expr interface -func (node *UnaryExpr) Clone() Expr { - if node == nil { - return nil - } - return &UnaryExpr{ - Operator: node.Operator, - Expr: nilOrClone(node.Expr), - } -} - -// Clone implements the Expr interface -func (node *IntervalExpr) Clone() Expr { - if node == nil { - return nil - } - return &IntervalExpr{ - Expr: nilOrClone(node.Expr), - Unit: node.Unit, - } -} - -// Clone implements the Expr interface -func (node *CollateExpr) Clone() Expr { - if node == nil { - return nil - } - return &CollateExpr{ - Expr: nilOrClone(node.Expr), - Charset: node.Charset, - } -} - -// Clone implements the Expr interface -func (node *FuncExpr) Clone() Expr { - if node == nil { - return nil - } - panic("FuncExpr cloning not supported") -} - -// Clone implements the Expr interface -func (node *TimestampFuncExpr) Clone() Expr { - if node == nil { - return nil - } - return &TimestampFuncExpr{ - Name: node.Name, - Expr1: nilOrClone(node.Expr1), - Expr2: nilOrClone(node.Expr2), - Unit: node.Unit, - } -} - -// Clone implements the Expr interface -func (node *CurTimeFuncExpr) Clone() Expr { - if node == nil { - return nil - } - return &CurTimeFuncExpr{ - Name: node.Name, - Fsp: nilOrClone(node.Fsp), - } -} - -// Clone implements the Expr interface -func (node *CaseExpr) Clone() Expr { - if node == nil { - return nil - } - panic("CaseExpr cloning not supported") -} - -// Clone implements the Expr interface -func (node *ValuesFuncExpr) Clone() Expr { - if node == nil { - return nil - } - return &ValuesFuncExpr{ - Name: nilOrClone(node.Name).(*ColName), - } -} - -// Clone implements the Expr interface -func (node *ConvertExpr) Clone() Expr { - if node == nil { - return nil - } - panic("ConvertExpr cloning not supported") -} - -// Clone implements the Expr interface -func (node *SubstrExpr) Clone() Expr { - if node == nil { - return nil - } - return &SubstrExpr{ - Name: node.Name, - StrVal: nilOrClone(node.StrVal).(*Literal), - From: nilOrClone(node.From), - To: nilOrClone(node.To), - } -} - -// Clone implements the Expr interface -func (node *ConvertUsingExpr) Clone() Expr { - if node == nil { - return nil - } - return &ConvertUsingExpr{ - Expr: nilOrClone(node.Expr), - Type: node.Type, - } -} - -// Clone implements the Expr interface -func (node *MatchExpr) Clone() Expr { - if node == nil { - return nil - } - panic("MatchExpr cloning not supported") -} - -// Clone implements the Expr interface -func (node *GroupConcatExpr) Clone() Expr { - if node == nil { - return nil - } - panic("GroupConcatExpr cloning not supported") -} - -// Clone implements the Expr interface -func (node *Default) Clone() Expr { - if node == nil { - return nil - } - return &Default{ColName: node.ColName} -} diff --git a/go/vt/sqlparser/ast_rewriting.go b/go/vt/sqlparser/ast_rewriting.go index 2494a39527f..6b075f731ab 100644 --- a/go/vt/sqlparser/ast_rewriting.go +++ b/go/vt/sqlparser/ast_rewriting.go @@ -35,7 +35,10 @@ type RewriteASTResult struct { // PrepareAST will normalize the query func PrepareAST(in Statement, bindVars map[string]*querypb.BindVariable, prefix string, parameterize bool, keyspace string) (*RewriteASTResult, error) { if parameterize { - Normalize(in, bindVars, prefix) + err := Normalize(in, bindVars, prefix) + if err != nil { + return nil, err + } } return RewriteAST(in, keyspace) } @@ -45,7 +48,11 @@ func RewriteAST(in Statement, keyspace string) (*RewriteASTResult, error) { er := newExpressionRewriter(keyspace) er.shouldRewriteDatabaseFunc = shouldRewriteDatabaseFunc(in) setRewriter := &setNormalizer{} - out, ok := Rewrite(in, er.rewrite, setRewriter.rewriteSetComingUp).(Statement) + result, err := Rewrite(in, er.rewrite, setRewriter.rewriteSetComingUp) + if err != nil { + return nil, err + } + out, ok := result.(Statement) if !ok { return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "statement rewriting returned a non statement: %s", String(out)) } @@ -114,7 +121,10 @@ const ( func (er *expressionRewriter) rewriteAliasedExpr(node *AliasedExpr) (*BindVarNeeds, error) { inner := newExpressionRewriter(er.keyspace) inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc - tmp := Rewrite(node.Expr, inner.rewrite, nil) + tmp, err := Rewrite(node.Expr, inner.rewrite, nil) + if err != nil { + return nil, err + } newExpr, ok := tmp.(Expr) if !ok { return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to rewrite AST. function expected to return Expr returned a %s", String(tmp)) @@ -312,7 +322,11 @@ func (er *expressionRewriter) unnestSubQueries(cursor *Cursor, subquery *Subquer er.bindVars.NoteRewrite() // we need to make sure that the inner expression also gets rewritten, // so we fire off another rewriter traversal here - rewrittenExpr := Rewrite(expr.Expr, er.rewrite, nil) + rewrittenExpr, err := Rewrite(expr.Expr, er.rewrite, nil) + if err != nil { + er.err = err + return + } cursor.Replace(rewrittenExpr) } diff --git a/go/vt/sqlparser/clone.go b/go/vt/sqlparser/clone.go new file mode 100644 index 00000000000..7999702feb7 --- /dev/null +++ b/go/vt/sqlparser/clone.go @@ -0,0 +1,2487 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +// Code generated by ASTHelperGen. DO NOT EDIT. + +package sqlparser + +// CloneAlterOption creates a deep clone of the input. +func CloneAlterOption(in AlterOption) AlterOption { + if in == nil { + return nil + } + switch in := in.(type) { + case *AddColumns: + return CloneRefOfAddColumns(in) + case *AddConstraintDefinition: + return CloneRefOfAddConstraintDefinition(in) + case *AddIndexDefinition: + return CloneRefOfAddIndexDefinition(in) + case AlgorithmValue: + return in + case *AlterCharset: + return CloneRefOfAlterCharset(in) + case *AlterColumn: + return CloneRefOfAlterColumn(in) + case *ChangeColumn: + return CloneRefOfChangeColumn(in) + case *DropColumn: + return CloneRefOfDropColumn(in) + case *DropKey: + return CloneRefOfDropKey(in) + case *Force: + return CloneRefOfForce(in) + case *KeyState: + return CloneRefOfKeyState(in) + case *LockOption: + return CloneRefOfLockOption(in) + case *ModifyColumn: + return CloneRefOfModifyColumn(in) + case *OrderByOption: + return CloneRefOfOrderByOption(in) + case *RenameIndex: + return CloneRefOfRenameIndex(in) + case *RenameTableName: + return CloneRefOfRenameTableName(in) + case TableOptions: + return CloneTableOptions(in) + case *TablespaceOperation: + return CloneRefOfTablespaceOperation(in) + case *Validation: + return CloneRefOfValidation(in) + default: + // this should never happen + return nil + } +} + +// CloneCharacteristic creates a deep clone of the input. +func CloneCharacteristic(in Characteristic) Characteristic { + if in == nil { + return nil + } + switch in := in.(type) { + case AccessMode: + return in + case IsolationLevel: + return in + default: + // this should never happen + return nil + } +} + +// CloneColTuple creates a deep clone of the input. +func CloneColTuple(in ColTuple) ColTuple { + if in == nil { + return nil + } + switch in := in.(type) { + case ListArg: + return CloneListArg(in) + case *Subquery: + return CloneRefOfSubquery(in) + case ValTuple: + return CloneValTuple(in) + default: + // this should never happen + return nil + } +} + +// CloneConstraintInfo creates a deep clone of the input. +func CloneConstraintInfo(in ConstraintInfo) ConstraintInfo { + if in == nil { + return nil + } + switch in := in.(type) { + case *CheckConstraintDefinition: + return CloneRefOfCheckConstraintDefinition(in) + case *ForeignKeyDefinition: + return CloneRefOfForeignKeyDefinition(in) + default: + // this should never happen + return nil + } +} + +// CloneDBDDLStatement creates a deep clone of the input. +func CloneDBDDLStatement(in DBDDLStatement) DBDDLStatement { + if in == nil { + return nil + } + switch in := in.(type) { + case *AlterDatabase: + return CloneRefOfAlterDatabase(in) + case *CreateDatabase: + return CloneRefOfCreateDatabase(in) + case *DropDatabase: + return CloneRefOfDropDatabase(in) + default: + // this should never happen + return nil + } +} + +// CloneDDLStatement creates a deep clone of the input. +func CloneDDLStatement(in DDLStatement) DDLStatement { + if in == nil { + return nil + } + switch in := in.(type) { + case *AlterTable: + return CloneRefOfAlterTable(in) + case *AlterView: + return CloneRefOfAlterView(in) + case *CreateTable: + return CloneRefOfCreateTable(in) + case *CreateView: + return CloneRefOfCreateView(in) + case *DropTable: + return CloneRefOfDropTable(in) + case *DropView: + return CloneRefOfDropView(in) + case *RenameTable: + return CloneRefOfRenameTable(in) + case *TruncateTable: + return CloneRefOfTruncateTable(in) + default: + // this should never happen + return nil + } +} + +// CloneExplain creates a deep clone of the input. +func CloneExplain(in Explain) Explain { + if in == nil { + return nil + } + switch in := in.(type) { + case *ExplainStmt: + return CloneRefOfExplainStmt(in) + case *ExplainTab: + return CloneRefOfExplainTab(in) + default: + // this should never happen + return nil + } +} + +// CloneExpr creates a deep clone of the input. +func CloneExpr(in Expr) Expr { + if in == nil { + return nil + } + switch in := in.(type) { + case *AndExpr: + return CloneRefOfAndExpr(in) + case Argument: + return CloneArgument(in) + case *BinaryExpr: + return CloneRefOfBinaryExpr(in) + case BoolVal: + return in + case *CaseExpr: + return CloneRefOfCaseExpr(in) + case *ColName: + return CloneRefOfColName(in) + case *CollateExpr: + return CloneRefOfCollateExpr(in) + case *ComparisonExpr: + return CloneRefOfComparisonExpr(in) + case *ConvertExpr: + return CloneRefOfConvertExpr(in) + case *ConvertUsingExpr: + return CloneRefOfConvertUsingExpr(in) + case *CurTimeFuncExpr: + return CloneRefOfCurTimeFuncExpr(in) + case *Default: + return CloneRefOfDefault(in) + case *ExistsExpr: + return CloneRefOfExistsExpr(in) + case *FuncExpr: + return CloneRefOfFuncExpr(in) + case *GroupConcatExpr: + return CloneRefOfGroupConcatExpr(in) + case *IntervalExpr: + return CloneRefOfIntervalExpr(in) + case *IsExpr: + return CloneRefOfIsExpr(in) + case ListArg: + return CloneListArg(in) + case *Literal: + return CloneRefOfLiteral(in) + case *MatchExpr: + return CloneRefOfMatchExpr(in) + case *NotExpr: + return CloneRefOfNotExpr(in) + case *NullVal: + return CloneRefOfNullVal(in) + case *OrExpr: + return CloneRefOfOrExpr(in) + case *RangeCond: + return CloneRefOfRangeCond(in) + case *Subquery: + return CloneRefOfSubquery(in) + case *SubstrExpr: + return CloneRefOfSubstrExpr(in) + case *TimestampFuncExpr: + return CloneRefOfTimestampFuncExpr(in) + case *UnaryExpr: + return CloneRefOfUnaryExpr(in) + case ValTuple: + return CloneValTuple(in) + case *ValuesFuncExpr: + return CloneRefOfValuesFuncExpr(in) + case *XorExpr: + return CloneRefOfXorExpr(in) + default: + // this should never happen + return nil + } +} + +// CloneInsertRows creates a deep clone of the input. +func CloneInsertRows(in InsertRows) InsertRows { + if in == nil { + return nil + } + switch in := in.(type) { + case *ParenSelect: + return CloneRefOfParenSelect(in) + case *Select: + return CloneRefOfSelect(in) + case *Union: + return CloneRefOfUnion(in) + case Values: + return CloneValues(in) + default: + // this should never happen + return nil + } +} + +// CloneSQLNode creates a deep clone of the input. +func CloneSQLNode(in SQLNode) SQLNode { + if in == nil { + return nil + } + switch in := in.(type) { + case AccessMode: + return in + case *AddColumns: + return CloneRefOfAddColumns(in) + case *AddConstraintDefinition: + return CloneRefOfAddConstraintDefinition(in) + case *AddIndexDefinition: + return CloneRefOfAddIndexDefinition(in) + case AlgorithmValue: + return in + case *AliasedExpr: + return CloneRefOfAliasedExpr(in) + case *AliasedTableExpr: + return CloneRefOfAliasedTableExpr(in) + case *AlterCharset: + return CloneRefOfAlterCharset(in) + case *AlterColumn: + return CloneRefOfAlterColumn(in) + case *AlterDatabase: + return CloneRefOfAlterDatabase(in) + case *AlterTable: + return CloneRefOfAlterTable(in) + case *AlterView: + return CloneRefOfAlterView(in) + case *AlterVschema: + return CloneRefOfAlterVschema(in) + case *AndExpr: + return CloneRefOfAndExpr(in) + case Argument: + return CloneArgument(in) + case *AutoIncSpec: + return CloneRefOfAutoIncSpec(in) + case *Begin: + return CloneRefOfBegin(in) + case *BinaryExpr: + return CloneRefOfBinaryExpr(in) + case BoolVal: + return in + case *CallProc: + return CloneRefOfCallProc(in) + case *CaseExpr: + return CloneRefOfCaseExpr(in) + case *ChangeColumn: + return CloneRefOfChangeColumn(in) + case *CheckConstraintDefinition: + return CloneRefOfCheckConstraintDefinition(in) + case ColIdent: + return CloneColIdent(in) + case *ColName: + return CloneRefOfColName(in) + case *CollateExpr: + return CloneRefOfCollateExpr(in) + case *ColumnDefinition: + return CloneRefOfColumnDefinition(in) + case *ColumnType: + return CloneRefOfColumnType(in) + case Columns: + return CloneColumns(in) + case Comments: + return CloneComments(in) + case *Commit: + return CloneRefOfCommit(in) + case *ComparisonExpr: + return CloneRefOfComparisonExpr(in) + case *ConstraintDefinition: + return CloneRefOfConstraintDefinition(in) + case *ConvertExpr: + return CloneRefOfConvertExpr(in) + case *ConvertType: + return CloneRefOfConvertType(in) + case *ConvertUsingExpr: + return CloneRefOfConvertUsingExpr(in) + case *CreateDatabase: + return CloneRefOfCreateDatabase(in) + case *CreateTable: + return CloneRefOfCreateTable(in) + case *CreateView: + return CloneRefOfCreateView(in) + case *CurTimeFuncExpr: + return CloneRefOfCurTimeFuncExpr(in) + case *Default: + return CloneRefOfDefault(in) + case *Delete: + return CloneRefOfDelete(in) + case *DerivedTable: + return CloneRefOfDerivedTable(in) + case *DropColumn: + return CloneRefOfDropColumn(in) + case *DropDatabase: + return CloneRefOfDropDatabase(in) + case *DropKey: + return CloneRefOfDropKey(in) + case *DropTable: + return CloneRefOfDropTable(in) + case *DropView: + return CloneRefOfDropView(in) + case *ExistsExpr: + return CloneRefOfExistsExpr(in) + case *ExplainStmt: + return CloneRefOfExplainStmt(in) + case *ExplainTab: + return CloneRefOfExplainTab(in) + case Exprs: + return CloneExprs(in) + case *Flush: + return CloneRefOfFlush(in) + case *Force: + return CloneRefOfForce(in) + case *ForeignKeyDefinition: + return CloneRefOfForeignKeyDefinition(in) + case *FuncExpr: + return CloneRefOfFuncExpr(in) + case GroupBy: + return CloneGroupBy(in) + case *GroupConcatExpr: + return CloneRefOfGroupConcatExpr(in) + case *IndexDefinition: + return CloneRefOfIndexDefinition(in) + case *IndexHints: + return CloneRefOfIndexHints(in) + case *IndexInfo: + return CloneRefOfIndexInfo(in) + case *Insert: + return CloneRefOfInsert(in) + case *IntervalExpr: + return CloneRefOfIntervalExpr(in) + case *IsExpr: + return CloneRefOfIsExpr(in) + case IsolationLevel: + return in + case JoinCondition: + return CloneJoinCondition(in) + case *JoinTableExpr: + return CloneRefOfJoinTableExpr(in) + case *KeyState: + return CloneRefOfKeyState(in) + case *Limit: + return CloneRefOfLimit(in) + case ListArg: + return CloneListArg(in) + case *Literal: + return CloneRefOfLiteral(in) + case *Load: + return CloneRefOfLoad(in) + case *LockOption: + return CloneRefOfLockOption(in) + case *LockTables: + return CloneRefOfLockTables(in) + case *MatchExpr: + return CloneRefOfMatchExpr(in) + case *ModifyColumn: + return CloneRefOfModifyColumn(in) + case *Nextval: + return CloneRefOfNextval(in) + case *NotExpr: + return CloneRefOfNotExpr(in) + case *NullVal: + return CloneRefOfNullVal(in) + case OnDup: + return CloneOnDup(in) + case *OptLike: + return CloneRefOfOptLike(in) + case *OrExpr: + return CloneRefOfOrExpr(in) + case *Order: + return CloneRefOfOrder(in) + case OrderBy: + return CloneOrderBy(in) + case *OrderByOption: + return CloneRefOfOrderByOption(in) + case *OtherAdmin: + return CloneRefOfOtherAdmin(in) + case *OtherRead: + return CloneRefOfOtherRead(in) + case *ParenSelect: + return CloneRefOfParenSelect(in) + case *ParenTableExpr: + return CloneRefOfParenTableExpr(in) + case *PartitionDefinition: + return CloneRefOfPartitionDefinition(in) + case *PartitionSpec: + return CloneRefOfPartitionSpec(in) + case Partitions: + return ClonePartitions(in) + case *RangeCond: + return CloneRefOfRangeCond(in) + case ReferenceAction: + return in + case *Release: + return CloneRefOfRelease(in) + case *RenameIndex: + return CloneRefOfRenameIndex(in) + case *RenameTable: + return CloneRefOfRenameTable(in) + case *RenameTableName: + return CloneRefOfRenameTableName(in) + case *Rollback: + return CloneRefOfRollback(in) + case *SRollback: + return CloneRefOfSRollback(in) + case *Savepoint: + return CloneRefOfSavepoint(in) + case *Select: + return CloneRefOfSelect(in) + case SelectExprs: + return CloneSelectExprs(in) + case *SelectInto: + return CloneRefOfSelectInto(in) + case *Set: + return CloneRefOfSet(in) + case *SetExpr: + return CloneRefOfSetExpr(in) + case SetExprs: + return CloneSetExprs(in) + case *SetTransaction: + return CloneRefOfSetTransaction(in) + case *Show: + return CloneRefOfShow(in) + case *ShowBasic: + return CloneRefOfShowBasic(in) + case *ShowCreate: + return CloneRefOfShowCreate(in) + case *ShowFilter: + return CloneRefOfShowFilter(in) + case *ShowLegacy: + return CloneRefOfShowLegacy(in) + case *StarExpr: + return CloneRefOfStarExpr(in) + case *Stream: + return CloneRefOfStream(in) + case *Subquery: + return CloneRefOfSubquery(in) + case *SubstrExpr: + return CloneRefOfSubstrExpr(in) + case TableExprs: + return CloneTableExprs(in) + case TableIdent: + return CloneTableIdent(in) + case TableName: + return CloneTableName(in) + case TableNames: + return CloneTableNames(in) + case TableOptions: + return CloneTableOptions(in) + case *TableSpec: + return CloneRefOfTableSpec(in) + case *TablespaceOperation: + return CloneRefOfTablespaceOperation(in) + case *TimestampFuncExpr: + return CloneRefOfTimestampFuncExpr(in) + case *TruncateTable: + return CloneRefOfTruncateTable(in) + case *UnaryExpr: + return CloneRefOfUnaryExpr(in) + case *Union: + return CloneRefOfUnion(in) + case *UnionSelect: + return CloneRefOfUnionSelect(in) + case *UnlockTables: + return CloneRefOfUnlockTables(in) + case *Update: + return CloneRefOfUpdate(in) + case *UpdateExpr: + return CloneRefOfUpdateExpr(in) + case UpdateExprs: + return CloneUpdateExprs(in) + case *Use: + return CloneRefOfUse(in) + case *VStream: + return CloneRefOfVStream(in) + case ValTuple: + return CloneValTuple(in) + case *Validation: + return CloneRefOfValidation(in) + case Values: + return CloneValues(in) + case *ValuesFuncExpr: + return CloneRefOfValuesFuncExpr(in) + case VindexParam: + return CloneVindexParam(in) + case *VindexSpec: + return CloneRefOfVindexSpec(in) + case *When: + return CloneRefOfWhen(in) + case *Where: + return CloneRefOfWhere(in) + case *XorExpr: + return CloneRefOfXorExpr(in) + default: + // this should never happen + return nil + } +} + +// CloneSelectExpr creates a deep clone of the input. +func CloneSelectExpr(in SelectExpr) SelectExpr { + if in == nil { + return nil + } + switch in := in.(type) { + case *AliasedExpr: + return CloneRefOfAliasedExpr(in) + case *Nextval: + return CloneRefOfNextval(in) + case *StarExpr: + return CloneRefOfStarExpr(in) + default: + // this should never happen + return nil + } +} + +// CloneSelectStatement creates a deep clone of the input. +func CloneSelectStatement(in SelectStatement) SelectStatement { + if in == nil { + return nil + } + switch in := in.(type) { + case *ParenSelect: + return CloneRefOfParenSelect(in) + case *Select: + return CloneRefOfSelect(in) + case *Union: + return CloneRefOfUnion(in) + default: + // this should never happen + return nil + } +} + +// CloneShowInternal creates a deep clone of the input. +func CloneShowInternal(in ShowInternal) ShowInternal { + if in == nil { + return nil + } + switch in := in.(type) { + case *ShowBasic: + return CloneRefOfShowBasic(in) + case *ShowCreate: + return CloneRefOfShowCreate(in) + case *ShowLegacy: + return CloneRefOfShowLegacy(in) + default: + // this should never happen + return nil + } +} + +// CloneSimpleTableExpr creates a deep clone of the input. +func CloneSimpleTableExpr(in SimpleTableExpr) SimpleTableExpr { + if in == nil { + return nil + } + switch in := in.(type) { + case *DerivedTable: + return CloneRefOfDerivedTable(in) + case TableName: + return CloneTableName(in) + default: + // this should never happen + return nil + } +} + +// CloneStatement creates a deep clone of the input. +func CloneStatement(in Statement) Statement { + if in == nil { + return nil + } + switch in := in.(type) { + case *AlterDatabase: + return CloneRefOfAlterDatabase(in) + case *AlterTable: + return CloneRefOfAlterTable(in) + case *AlterView: + return CloneRefOfAlterView(in) + case *AlterVschema: + return CloneRefOfAlterVschema(in) + case *Begin: + return CloneRefOfBegin(in) + case *CallProc: + return CloneRefOfCallProc(in) + case *Commit: + return CloneRefOfCommit(in) + case *CreateDatabase: + return CloneRefOfCreateDatabase(in) + case *CreateTable: + return CloneRefOfCreateTable(in) + case *CreateView: + return CloneRefOfCreateView(in) + case *Delete: + return CloneRefOfDelete(in) + case *DropDatabase: + return CloneRefOfDropDatabase(in) + case *DropTable: + return CloneRefOfDropTable(in) + case *DropView: + return CloneRefOfDropView(in) + case *ExplainStmt: + return CloneRefOfExplainStmt(in) + case *ExplainTab: + return CloneRefOfExplainTab(in) + case *Flush: + return CloneRefOfFlush(in) + case *Insert: + return CloneRefOfInsert(in) + case *Load: + return CloneRefOfLoad(in) + case *LockTables: + return CloneRefOfLockTables(in) + case *OtherAdmin: + return CloneRefOfOtherAdmin(in) + case *OtherRead: + return CloneRefOfOtherRead(in) + case *ParenSelect: + return CloneRefOfParenSelect(in) + case *Release: + return CloneRefOfRelease(in) + case *RenameTable: + return CloneRefOfRenameTable(in) + case *Rollback: + return CloneRefOfRollback(in) + case *SRollback: + return CloneRefOfSRollback(in) + case *Savepoint: + return CloneRefOfSavepoint(in) + case *Select: + return CloneRefOfSelect(in) + case *Set: + return CloneRefOfSet(in) + case *SetTransaction: + return CloneRefOfSetTransaction(in) + case *Show: + return CloneRefOfShow(in) + case *Stream: + return CloneRefOfStream(in) + case *TruncateTable: + return CloneRefOfTruncateTable(in) + case *Union: + return CloneRefOfUnion(in) + case *UnlockTables: + return CloneRefOfUnlockTables(in) + case *Update: + return CloneRefOfUpdate(in) + case *Use: + return CloneRefOfUse(in) + case *VStream: + return CloneRefOfVStream(in) + default: + // this should never happen + return nil + } +} + +// CloneTableExpr creates a deep clone of the input. +func CloneTableExpr(in TableExpr) TableExpr { + if in == nil { + return nil + } + switch in := in.(type) { + case *AliasedTableExpr: + return CloneRefOfAliasedTableExpr(in) + case *JoinTableExpr: + return CloneRefOfJoinTableExpr(in) + case *ParenTableExpr: + return CloneRefOfParenTableExpr(in) + default: + // this should never happen + return nil + } +} + +// CloneRefOfAddColumns creates a deep clone of the input. +func CloneRefOfAddColumns(n *AddColumns) *AddColumns { + if n == nil { + return nil + } + out := *n + out.Columns = CloneSliceOfRefOfColumnDefinition(n.Columns) + out.First = CloneRefOfColName(n.First) + out.After = CloneRefOfColName(n.After) + return &out +} + +// CloneRefOfAddConstraintDefinition creates a deep clone of the input. +func CloneRefOfAddConstraintDefinition(n *AddConstraintDefinition) *AddConstraintDefinition { + if n == nil { + return nil + } + out := *n + out.ConstraintDefinition = CloneRefOfConstraintDefinition(n.ConstraintDefinition) + return &out +} + +// CloneRefOfAddIndexDefinition creates a deep clone of the input. +func CloneRefOfAddIndexDefinition(n *AddIndexDefinition) *AddIndexDefinition { + if n == nil { + return nil + } + out := *n + out.IndexDefinition = CloneRefOfIndexDefinition(n.IndexDefinition) + return &out +} + +// CloneRefOfAlterCharset creates a deep clone of the input. +func CloneRefOfAlterCharset(n *AlterCharset) *AlterCharset { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfAlterColumn creates a deep clone of the input. +func CloneRefOfAlterColumn(n *AlterColumn) *AlterColumn { + if n == nil { + return nil + } + out := *n + out.Column = CloneRefOfColName(n.Column) + out.DefaultVal = CloneExpr(n.DefaultVal) + return &out +} + +// CloneRefOfChangeColumn creates a deep clone of the input. +func CloneRefOfChangeColumn(n *ChangeColumn) *ChangeColumn { + if n == nil { + return nil + } + out := *n + out.OldColumn = CloneRefOfColName(n.OldColumn) + out.NewColDefinition = CloneRefOfColumnDefinition(n.NewColDefinition) + out.First = CloneRefOfColName(n.First) + out.After = CloneRefOfColName(n.After) + return &out +} + +// CloneRefOfDropColumn creates a deep clone of the input. +func CloneRefOfDropColumn(n *DropColumn) *DropColumn { + if n == nil { + return nil + } + out := *n + out.Name = CloneRefOfColName(n.Name) + return &out +} + +// CloneRefOfDropKey creates a deep clone of the input. +func CloneRefOfDropKey(n *DropKey) *DropKey { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfForce creates a deep clone of the input. +func CloneRefOfForce(n *Force) *Force { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfKeyState creates a deep clone of the input. +func CloneRefOfKeyState(n *KeyState) *KeyState { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfLockOption creates a deep clone of the input. +func CloneRefOfLockOption(n *LockOption) *LockOption { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfModifyColumn creates a deep clone of the input. +func CloneRefOfModifyColumn(n *ModifyColumn) *ModifyColumn { + if n == nil { + return nil + } + out := *n + out.NewColDefinition = CloneRefOfColumnDefinition(n.NewColDefinition) + out.First = CloneRefOfColName(n.First) + out.After = CloneRefOfColName(n.After) + return &out +} + +// CloneRefOfOrderByOption creates a deep clone of the input. +func CloneRefOfOrderByOption(n *OrderByOption) *OrderByOption { + if n == nil { + return nil + } + out := *n + out.Cols = CloneColumns(n.Cols) + return &out +} + +// CloneRefOfRenameIndex creates a deep clone of the input. +func CloneRefOfRenameIndex(n *RenameIndex) *RenameIndex { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfRenameTableName creates a deep clone of the input. +func CloneRefOfRenameTableName(n *RenameTableName) *RenameTableName { + if n == nil { + return nil + } + out := *n + out.Table = CloneTableName(n.Table) + return &out +} + +// CloneTableOptions creates a deep clone of the input. +func CloneTableOptions(n TableOptions) TableOptions { + res := make(TableOptions, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfTableOption(x)) + } + return res +} + +// CloneRefOfTablespaceOperation creates a deep clone of the input. +func CloneRefOfTablespaceOperation(n *TablespaceOperation) *TablespaceOperation { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfValidation creates a deep clone of the input. +func CloneRefOfValidation(n *Validation) *Validation { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneListArg creates a deep clone of the input. +func CloneListArg(n ListArg) ListArg { + res := make(ListArg, 0, len(n)) + copy(res, n) + return res +} + +// CloneRefOfSubquery creates a deep clone of the input. +func CloneRefOfSubquery(n *Subquery) *Subquery { + if n == nil { + return nil + } + out := *n + out.Select = CloneSelectStatement(n.Select) + return &out +} + +// CloneValTuple creates a deep clone of the input. +func CloneValTuple(n ValTuple) ValTuple { + res := make(ValTuple, 0, len(n)) + for _, x := range n { + res = append(res, CloneExpr(x)) + } + return res +} + +// CloneRefOfCheckConstraintDefinition creates a deep clone of the input. +func CloneRefOfCheckConstraintDefinition(n *CheckConstraintDefinition) *CheckConstraintDefinition { + if n == nil { + return nil + } + out := *n + out.Expr = CloneExpr(n.Expr) + return &out +} + +// CloneRefOfForeignKeyDefinition creates a deep clone of the input. +func CloneRefOfForeignKeyDefinition(n *ForeignKeyDefinition) *ForeignKeyDefinition { + if n == nil { + return nil + } + out := *n + out.Source = CloneColumns(n.Source) + out.ReferencedTable = CloneTableName(n.ReferencedTable) + out.ReferencedColumns = CloneColumns(n.ReferencedColumns) + return &out +} + +// CloneRefOfAlterDatabase creates a deep clone of the input. +func CloneRefOfAlterDatabase(n *AlterDatabase) *AlterDatabase { + if n == nil { + return nil + } + out := *n + out.AlterOptions = CloneSliceOfCollateAndCharset(n.AlterOptions) + return &out +} + +// CloneRefOfCreateDatabase creates a deep clone of the input. +func CloneRefOfCreateDatabase(n *CreateDatabase) *CreateDatabase { + if n == nil { + return nil + } + out := *n + out.CreateOptions = CloneSliceOfCollateAndCharset(n.CreateOptions) + return &out +} + +// CloneRefOfDropDatabase creates a deep clone of the input. +func CloneRefOfDropDatabase(n *DropDatabase) *DropDatabase { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfAlterTable creates a deep clone of the input. +func CloneRefOfAlterTable(n *AlterTable) *AlterTable { + if n == nil { + return nil + } + out := *n + out.Table = CloneTableName(n.Table) + out.AlterOptions = CloneSliceOfAlterOption(n.AlterOptions) + out.PartitionSpec = CloneRefOfPartitionSpec(n.PartitionSpec) + return &out +} + +// CloneRefOfAlterView creates a deep clone of the input. +func CloneRefOfAlterView(n *AlterView) *AlterView { + if n == nil { + return nil + } + out := *n + out.ViewName = CloneTableName(n.ViewName) + out.Columns = CloneColumns(n.Columns) + out.Select = CloneSelectStatement(n.Select) + return &out +} + +// CloneRefOfCreateTable creates a deep clone of the input. +func CloneRefOfCreateTable(n *CreateTable) *CreateTable { + if n == nil { + return nil + } + out := *n + out.Table = CloneTableName(n.Table) + out.TableSpec = CloneRefOfTableSpec(n.TableSpec) + out.OptLike = CloneRefOfOptLike(n.OptLike) + return &out +} + +// CloneRefOfCreateView creates a deep clone of the input. +func CloneRefOfCreateView(n *CreateView) *CreateView { + if n == nil { + return nil + } + out := *n + out.ViewName = CloneTableName(n.ViewName) + out.Columns = CloneColumns(n.Columns) + out.Select = CloneSelectStatement(n.Select) + return &out +} + +// CloneRefOfDropTable creates a deep clone of the input. +func CloneRefOfDropTable(n *DropTable) *DropTable { + if n == nil { + return nil + } + out := *n + out.FromTables = CloneTableNames(n.FromTables) + return &out +} + +// CloneRefOfDropView creates a deep clone of the input. +func CloneRefOfDropView(n *DropView) *DropView { + if n == nil { + return nil + } + out := *n + out.FromTables = CloneTableNames(n.FromTables) + return &out +} + +// CloneRefOfRenameTable creates a deep clone of the input. +func CloneRefOfRenameTable(n *RenameTable) *RenameTable { + if n == nil { + return nil + } + out := *n + out.TablePairs = CloneSliceOfRefOfRenameTablePair(n.TablePairs) + return &out +} + +// CloneRefOfTruncateTable creates a deep clone of the input. +func CloneRefOfTruncateTable(n *TruncateTable) *TruncateTable { + if n == nil { + return nil + } + out := *n + out.Table = CloneTableName(n.Table) + return &out +} + +// CloneRefOfExplainStmt creates a deep clone of the input. +func CloneRefOfExplainStmt(n *ExplainStmt) *ExplainStmt { + if n == nil { + return nil + } + out := *n + out.Statement = CloneStatement(n.Statement) + return &out +} + +// CloneRefOfExplainTab creates a deep clone of the input. +func CloneRefOfExplainTab(n *ExplainTab) *ExplainTab { + if n == nil { + return nil + } + out := *n + out.Table = CloneTableName(n.Table) + return &out +} + +// CloneRefOfAndExpr creates a deep clone of the input. +func CloneRefOfAndExpr(n *AndExpr) *AndExpr { + if n == nil { + return nil + } + out := *n + out.Left = CloneExpr(n.Left) + out.Right = CloneExpr(n.Right) + return &out +} + +// CloneArgument creates a deep clone of the input. +func CloneArgument(n Argument) Argument { + res := make(Argument, 0, len(n)) + copy(res, n) + return res +} + +// CloneRefOfBinaryExpr creates a deep clone of the input. +func CloneRefOfBinaryExpr(n *BinaryExpr) *BinaryExpr { + if n == nil { + return nil + } + out := *n + out.Left = CloneExpr(n.Left) + out.Right = CloneExpr(n.Right) + return &out +} + +// CloneRefOfCaseExpr creates a deep clone of the input. +func CloneRefOfCaseExpr(n *CaseExpr) *CaseExpr { + if n == nil { + return nil + } + out := *n + out.Expr = CloneExpr(n.Expr) + out.Whens = CloneSliceOfRefOfWhen(n.Whens) + out.Else = CloneExpr(n.Else) + return &out +} + +// CloneRefOfColName creates a deep clone of the input. +func CloneRefOfColName(n *ColName) *ColName { + return n +} + +// CloneRefOfCollateExpr creates a deep clone of the input. +func CloneRefOfCollateExpr(n *CollateExpr) *CollateExpr { + if n == nil { + return nil + } + out := *n + out.Expr = CloneExpr(n.Expr) + return &out +} + +// CloneRefOfComparisonExpr creates a deep clone of the input. +func CloneRefOfComparisonExpr(n *ComparisonExpr) *ComparisonExpr { + if n == nil { + return nil + } + out := *n + out.Left = CloneExpr(n.Left) + out.Right = CloneExpr(n.Right) + out.Escape = CloneExpr(n.Escape) + return &out +} + +// CloneRefOfConvertExpr creates a deep clone of the input. +func CloneRefOfConvertExpr(n *ConvertExpr) *ConvertExpr { + if n == nil { + return nil + } + out := *n + out.Expr = CloneExpr(n.Expr) + out.Type = CloneRefOfConvertType(n.Type) + return &out +} + +// CloneRefOfConvertUsingExpr creates a deep clone of the input. +func CloneRefOfConvertUsingExpr(n *ConvertUsingExpr) *ConvertUsingExpr { + if n == nil { + return nil + } + out := *n + out.Expr = CloneExpr(n.Expr) + return &out +} + +// CloneRefOfCurTimeFuncExpr creates a deep clone of the input. +func CloneRefOfCurTimeFuncExpr(n *CurTimeFuncExpr) *CurTimeFuncExpr { + if n == nil { + return nil + } + out := *n + out.Name = CloneColIdent(n.Name) + out.Fsp = CloneExpr(n.Fsp) + return &out +} + +// CloneRefOfDefault creates a deep clone of the input. +func CloneRefOfDefault(n *Default) *Default { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfExistsExpr creates a deep clone of the input. +func CloneRefOfExistsExpr(n *ExistsExpr) *ExistsExpr { + if n == nil { + return nil + } + out := *n + out.Subquery = CloneRefOfSubquery(n.Subquery) + return &out +} + +// CloneRefOfFuncExpr creates a deep clone of the input. +func CloneRefOfFuncExpr(n *FuncExpr) *FuncExpr { + if n == nil { + return nil + } + out := *n + out.Qualifier = CloneTableIdent(n.Qualifier) + out.Name = CloneColIdent(n.Name) + out.Exprs = CloneSelectExprs(n.Exprs) + return &out +} + +// CloneRefOfGroupConcatExpr creates a deep clone of the input. +func CloneRefOfGroupConcatExpr(n *GroupConcatExpr) *GroupConcatExpr { + if n == nil { + return nil + } + out := *n + out.Exprs = CloneSelectExprs(n.Exprs) + out.OrderBy = CloneOrderBy(n.OrderBy) + out.Limit = CloneRefOfLimit(n.Limit) + return &out +} + +// CloneRefOfIntervalExpr creates a deep clone of the input. +func CloneRefOfIntervalExpr(n *IntervalExpr) *IntervalExpr { + if n == nil { + return nil + } + out := *n + out.Expr = CloneExpr(n.Expr) + return &out +} + +// CloneRefOfIsExpr creates a deep clone of the input. +func CloneRefOfIsExpr(n *IsExpr) *IsExpr { + if n == nil { + return nil + } + out := *n + out.Expr = CloneExpr(n.Expr) + return &out +} + +// CloneRefOfLiteral creates a deep clone of the input. +func CloneRefOfLiteral(n *Literal) *Literal { + if n == nil { + return nil + } + out := *n + out.Val = CloneSliceOfbyte(n.Val) + return &out +} + +// CloneRefOfMatchExpr creates a deep clone of the input. +func CloneRefOfMatchExpr(n *MatchExpr) *MatchExpr { + if n == nil { + return nil + } + out := *n + out.Columns = CloneSelectExprs(n.Columns) + out.Expr = CloneExpr(n.Expr) + return &out +} + +// CloneRefOfNotExpr creates a deep clone of the input. +func CloneRefOfNotExpr(n *NotExpr) *NotExpr { + if n == nil { + return nil + } + out := *n + out.Expr = CloneExpr(n.Expr) + return &out +} + +// CloneRefOfNullVal creates a deep clone of the input. +func CloneRefOfNullVal(n *NullVal) *NullVal { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfOrExpr creates a deep clone of the input. +func CloneRefOfOrExpr(n *OrExpr) *OrExpr { + if n == nil { + return nil + } + out := *n + out.Left = CloneExpr(n.Left) + out.Right = CloneExpr(n.Right) + return &out +} + +// CloneRefOfRangeCond creates a deep clone of the input. +func CloneRefOfRangeCond(n *RangeCond) *RangeCond { + if n == nil { + return nil + } + out := *n + out.Left = CloneExpr(n.Left) + out.From = CloneExpr(n.From) + out.To = CloneExpr(n.To) + return &out +} + +// CloneRefOfSubstrExpr creates a deep clone of the input. +func CloneRefOfSubstrExpr(n *SubstrExpr) *SubstrExpr { + if n == nil { + return nil + } + out := *n + out.Name = CloneRefOfColName(n.Name) + out.StrVal = CloneRefOfLiteral(n.StrVal) + out.From = CloneExpr(n.From) + out.To = CloneExpr(n.To) + return &out +} + +// CloneRefOfTimestampFuncExpr creates a deep clone of the input. +func CloneRefOfTimestampFuncExpr(n *TimestampFuncExpr) *TimestampFuncExpr { + if n == nil { + return nil + } + out := *n + out.Expr1 = CloneExpr(n.Expr1) + out.Expr2 = CloneExpr(n.Expr2) + return &out +} + +// CloneRefOfUnaryExpr creates a deep clone of the input. +func CloneRefOfUnaryExpr(n *UnaryExpr) *UnaryExpr { + if n == nil { + return nil + } + out := *n + out.Expr = CloneExpr(n.Expr) + return &out +} + +// CloneRefOfValuesFuncExpr creates a deep clone of the input. +func CloneRefOfValuesFuncExpr(n *ValuesFuncExpr) *ValuesFuncExpr { + if n == nil { + return nil + } + out := *n + out.Name = CloneRefOfColName(n.Name) + return &out +} + +// CloneRefOfXorExpr creates a deep clone of the input. +func CloneRefOfXorExpr(n *XorExpr) *XorExpr { + if n == nil { + return nil + } + out := *n + out.Left = CloneExpr(n.Left) + out.Right = CloneExpr(n.Right) + return &out +} + +// CloneRefOfParenSelect creates a deep clone of the input. +func CloneRefOfParenSelect(n *ParenSelect) *ParenSelect { + if n == nil { + return nil + } + out := *n + out.Select = CloneSelectStatement(n.Select) + return &out +} + +// CloneRefOfSelect creates a deep clone of the input. +func CloneRefOfSelect(n *Select) *Select { + if n == nil { + return nil + } + out := *n + out.Cache = CloneRefOfbool(n.Cache) + out.Comments = CloneComments(n.Comments) + out.SelectExprs = CloneSelectExprs(n.SelectExprs) + out.From = CloneTableExprs(n.From) + out.Where = CloneRefOfWhere(n.Where) + out.GroupBy = CloneGroupBy(n.GroupBy) + out.Having = CloneRefOfWhere(n.Having) + out.OrderBy = CloneOrderBy(n.OrderBy) + out.Limit = CloneRefOfLimit(n.Limit) + out.Into = CloneRefOfSelectInto(n.Into) + return &out +} + +// CloneRefOfUnion creates a deep clone of the input. +func CloneRefOfUnion(n *Union) *Union { + if n == nil { + return nil + } + out := *n + out.FirstStatement = CloneSelectStatement(n.FirstStatement) + out.UnionSelects = CloneSliceOfRefOfUnionSelect(n.UnionSelects) + out.OrderBy = CloneOrderBy(n.OrderBy) + out.Limit = CloneRefOfLimit(n.Limit) + return &out +} + +// CloneValues creates a deep clone of the input. +func CloneValues(n Values) Values { + res := make(Values, 0, len(n)) + for _, x := range n { + res = append(res, CloneValTuple(x)) + } + return res +} + +// CloneRefOfAliasedExpr creates a deep clone of the input. +func CloneRefOfAliasedExpr(n *AliasedExpr) *AliasedExpr { + if n == nil { + return nil + } + out := *n + out.Expr = CloneExpr(n.Expr) + out.As = CloneColIdent(n.As) + return &out +} + +// CloneRefOfAliasedTableExpr creates a deep clone of the input. +func CloneRefOfAliasedTableExpr(n *AliasedTableExpr) *AliasedTableExpr { + if n == nil { + return nil + } + out := *n + out.Expr = CloneSimpleTableExpr(n.Expr) + out.Partitions = ClonePartitions(n.Partitions) + out.As = CloneTableIdent(n.As) + out.Hints = CloneRefOfIndexHints(n.Hints) + return &out +} + +// CloneRefOfAlterVschema creates a deep clone of the input. +func CloneRefOfAlterVschema(n *AlterVschema) *AlterVschema { + if n == nil { + return nil + } + out := *n + out.Table = CloneTableName(n.Table) + out.VindexSpec = CloneRefOfVindexSpec(n.VindexSpec) + out.VindexCols = CloneSliceOfColIdent(n.VindexCols) + out.AutoIncSpec = CloneRefOfAutoIncSpec(n.AutoIncSpec) + return &out +} + +// CloneRefOfAutoIncSpec creates a deep clone of the input. +func CloneRefOfAutoIncSpec(n *AutoIncSpec) *AutoIncSpec { + if n == nil { + return nil + } + out := *n + out.Column = CloneColIdent(n.Column) + out.Sequence = CloneTableName(n.Sequence) + return &out +} + +// CloneRefOfBegin creates a deep clone of the input. +func CloneRefOfBegin(n *Begin) *Begin { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfCallProc creates a deep clone of the input. +func CloneRefOfCallProc(n *CallProc) *CallProc { + if n == nil { + return nil + } + out := *n + out.Name = CloneTableName(n.Name) + out.Params = CloneExprs(n.Params) + return &out +} + +// CloneColIdent creates a deep clone of the input. +func CloneColIdent(n ColIdent) ColIdent { + return *CloneRefOfColIdent(&n) +} + +// CloneRefOfColumnDefinition creates a deep clone of the input. +func CloneRefOfColumnDefinition(n *ColumnDefinition) *ColumnDefinition { + if n == nil { + return nil + } + out := *n + out.Name = CloneColIdent(n.Name) + out.Type = CloneColumnType(n.Type) + return &out +} + +// CloneRefOfColumnType creates a deep clone of the input. +func CloneRefOfColumnType(n *ColumnType) *ColumnType { + if n == nil { + return nil + } + out := *n + out.Options = CloneRefOfColumnTypeOptions(n.Options) + out.Length = CloneRefOfLiteral(n.Length) + out.Scale = CloneRefOfLiteral(n.Scale) + out.EnumValues = CloneSliceOfstring(n.EnumValues) + return &out +} + +// CloneColumns creates a deep clone of the input. +func CloneColumns(n Columns) Columns { + res := make(Columns, 0, len(n)) + for _, x := range n { + res = append(res, CloneColIdent(x)) + } + return res +} + +// CloneComments creates a deep clone of the input. +func CloneComments(n Comments) Comments { + res := make(Comments, 0, len(n)) + for _, x := range n { + res = append(res, CloneSliceOfbyte(x)) + } + return res +} + +// CloneRefOfCommit creates a deep clone of the input. +func CloneRefOfCommit(n *Commit) *Commit { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfConstraintDefinition creates a deep clone of the input. +func CloneRefOfConstraintDefinition(n *ConstraintDefinition) *ConstraintDefinition { + if n == nil { + return nil + } + out := *n + out.Details = CloneConstraintInfo(n.Details) + return &out +} + +// CloneRefOfConvertType creates a deep clone of the input. +func CloneRefOfConvertType(n *ConvertType) *ConvertType { + if n == nil { + return nil + } + out := *n + out.Length = CloneRefOfLiteral(n.Length) + out.Scale = CloneRefOfLiteral(n.Scale) + return &out +} + +// CloneRefOfDelete creates a deep clone of the input. +func CloneRefOfDelete(n *Delete) *Delete { + if n == nil { + return nil + } + out := *n + out.Comments = CloneComments(n.Comments) + out.Targets = CloneTableNames(n.Targets) + out.TableExprs = CloneTableExprs(n.TableExprs) + out.Partitions = ClonePartitions(n.Partitions) + out.Where = CloneRefOfWhere(n.Where) + out.OrderBy = CloneOrderBy(n.OrderBy) + out.Limit = CloneRefOfLimit(n.Limit) + return &out +} + +// CloneRefOfDerivedTable creates a deep clone of the input. +func CloneRefOfDerivedTable(n *DerivedTable) *DerivedTable { + if n == nil { + return nil + } + out := *n + out.Select = CloneSelectStatement(n.Select) + return &out +} + +// CloneExprs creates a deep clone of the input. +func CloneExprs(n Exprs) Exprs { + res := make(Exprs, 0, len(n)) + for _, x := range n { + res = append(res, CloneExpr(x)) + } + return res +} + +// CloneRefOfFlush creates a deep clone of the input. +func CloneRefOfFlush(n *Flush) *Flush { + if n == nil { + return nil + } + out := *n + out.FlushOptions = CloneSliceOfstring(n.FlushOptions) + out.TableNames = CloneTableNames(n.TableNames) + return &out +} + +// CloneGroupBy creates a deep clone of the input. +func CloneGroupBy(n GroupBy) GroupBy { + res := make(GroupBy, 0, len(n)) + for _, x := range n { + res = append(res, CloneExpr(x)) + } + return res +} + +// CloneRefOfIndexDefinition creates a deep clone of the input. +func CloneRefOfIndexDefinition(n *IndexDefinition) *IndexDefinition { + if n == nil { + return nil + } + out := *n + out.Info = CloneRefOfIndexInfo(n.Info) + out.Columns = CloneSliceOfRefOfIndexColumn(n.Columns) + out.Options = CloneSliceOfRefOfIndexOption(n.Options) + return &out +} + +// CloneRefOfIndexHints creates a deep clone of the input. +func CloneRefOfIndexHints(n *IndexHints) *IndexHints { + if n == nil { + return nil + } + out := *n + out.Indexes = CloneSliceOfColIdent(n.Indexes) + return &out +} + +// CloneRefOfIndexInfo creates a deep clone of the input. +func CloneRefOfIndexInfo(n *IndexInfo) *IndexInfo { + if n == nil { + return nil + } + out := *n + out.Name = CloneColIdent(n.Name) + out.ConstraintName = CloneColIdent(n.ConstraintName) + return &out +} + +// CloneRefOfInsert creates a deep clone of the input. +func CloneRefOfInsert(n *Insert) *Insert { + if n == nil { + return nil + } + out := *n + out.Comments = CloneComments(n.Comments) + out.Table = CloneTableName(n.Table) + out.Partitions = ClonePartitions(n.Partitions) + out.Columns = CloneColumns(n.Columns) + out.Rows = CloneInsertRows(n.Rows) + out.OnDup = CloneOnDup(n.OnDup) + return &out +} + +// CloneJoinCondition creates a deep clone of the input. +func CloneJoinCondition(n JoinCondition) JoinCondition { + return *CloneRefOfJoinCondition(&n) +} + +// CloneRefOfJoinTableExpr creates a deep clone of the input. +func CloneRefOfJoinTableExpr(n *JoinTableExpr) *JoinTableExpr { + if n == nil { + return nil + } + out := *n + out.LeftExpr = CloneTableExpr(n.LeftExpr) + out.RightExpr = CloneTableExpr(n.RightExpr) + out.Condition = CloneJoinCondition(n.Condition) + return &out +} + +// CloneRefOfLimit creates a deep clone of the input. +func CloneRefOfLimit(n *Limit) *Limit { + if n == nil { + return nil + } + out := *n + out.Offset = CloneExpr(n.Offset) + out.Rowcount = CloneExpr(n.Rowcount) + return &out +} + +// CloneRefOfLoad creates a deep clone of the input. +func CloneRefOfLoad(n *Load) *Load { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfLockTables creates a deep clone of the input. +func CloneRefOfLockTables(n *LockTables) *LockTables { + if n == nil { + return nil + } + out := *n + out.Tables = CloneTableAndLockTypes(n.Tables) + return &out +} + +// CloneRefOfNextval creates a deep clone of the input. +func CloneRefOfNextval(n *Nextval) *Nextval { + if n == nil { + return nil + } + out := *n + out.Expr = CloneExpr(n.Expr) + return &out +} + +// CloneOnDup creates a deep clone of the input. +func CloneOnDup(n OnDup) OnDup { + res := make(OnDup, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfUpdateExpr(x)) + } + return res +} + +// CloneRefOfOptLike creates a deep clone of the input. +func CloneRefOfOptLike(n *OptLike) *OptLike { + if n == nil { + return nil + } + out := *n + out.LikeTable = CloneTableName(n.LikeTable) + return &out +} + +// CloneRefOfOrder creates a deep clone of the input. +func CloneRefOfOrder(n *Order) *Order { + if n == nil { + return nil + } + out := *n + out.Expr = CloneExpr(n.Expr) + return &out +} + +// CloneOrderBy creates a deep clone of the input. +func CloneOrderBy(n OrderBy) OrderBy { + res := make(OrderBy, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfOrder(x)) + } + return res +} + +// CloneRefOfOtherAdmin creates a deep clone of the input. +func CloneRefOfOtherAdmin(n *OtherAdmin) *OtherAdmin { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfOtherRead creates a deep clone of the input. +func CloneRefOfOtherRead(n *OtherRead) *OtherRead { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfParenTableExpr creates a deep clone of the input. +func CloneRefOfParenTableExpr(n *ParenTableExpr) *ParenTableExpr { + if n == nil { + return nil + } + out := *n + out.Exprs = CloneTableExprs(n.Exprs) + return &out +} + +// CloneRefOfPartitionDefinition creates a deep clone of the input. +func CloneRefOfPartitionDefinition(n *PartitionDefinition) *PartitionDefinition { + if n == nil { + return nil + } + out := *n + out.Name = CloneColIdent(n.Name) + out.Limit = CloneExpr(n.Limit) + return &out +} + +// CloneRefOfPartitionSpec creates a deep clone of the input. +func CloneRefOfPartitionSpec(n *PartitionSpec) *PartitionSpec { + if n == nil { + return nil + } + out := *n + out.Names = ClonePartitions(n.Names) + out.Number = CloneRefOfLiteral(n.Number) + out.TableName = CloneTableName(n.TableName) + out.Definitions = CloneSliceOfRefOfPartitionDefinition(n.Definitions) + return &out +} + +// ClonePartitions creates a deep clone of the input. +func ClonePartitions(n Partitions) Partitions { + res := make(Partitions, 0, len(n)) + for _, x := range n { + res = append(res, CloneColIdent(x)) + } + return res +} + +// CloneRefOfRelease creates a deep clone of the input. +func CloneRefOfRelease(n *Release) *Release { + if n == nil { + return nil + } + out := *n + out.Name = CloneColIdent(n.Name) + return &out +} + +// CloneRefOfRollback creates a deep clone of the input. +func CloneRefOfRollback(n *Rollback) *Rollback { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfSRollback creates a deep clone of the input. +func CloneRefOfSRollback(n *SRollback) *SRollback { + if n == nil { + return nil + } + out := *n + out.Name = CloneColIdent(n.Name) + return &out +} + +// CloneRefOfSavepoint creates a deep clone of the input. +func CloneRefOfSavepoint(n *Savepoint) *Savepoint { + if n == nil { + return nil + } + out := *n + out.Name = CloneColIdent(n.Name) + return &out +} + +// CloneSelectExprs creates a deep clone of the input. +func CloneSelectExprs(n SelectExprs) SelectExprs { + res := make(SelectExprs, 0, len(n)) + for _, x := range n { + res = append(res, CloneSelectExpr(x)) + } + return res +} + +// CloneRefOfSelectInto creates a deep clone of the input. +func CloneRefOfSelectInto(n *SelectInto) *SelectInto { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfSet creates a deep clone of the input. +func CloneRefOfSet(n *Set) *Set { + if n == nil { + return nil + } + out := *n + out.Comments = CloneComments(n.Comments) + out.Exprs = CloneSetExprs(n.Exprs) + return &out +} + +// CloneRefOfSetExpr creates a deep clone of the input. +func CloneRefOfSetExpr(n *SetExpr) *SetExpr { + if n == nil { + return nil + } + out := *n + out.Name = CloneColIdent(n.Name) + out.Expr = CloneExpr(n.Expr) + return &out +} + +// CloneSetExprs creates a deep clone of the input. +func CloneSetExprs(n SetExprs) SetExprs { + res := make(SetExprs, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfSetExpr(x)) + } + return res +} + +// CloneRefOfSetTransaction creates a deep clone of the input. +func CloneRefOfSetTransaction(n *SetTransaction) *SetTransaction { + if n == nil { + return nil + } + out := *n + out.SQLNode = CloneSQLNode(n.SQLNode) + out.Comments = CloneComments(n.Comments) + out.Characteristics = CloneSliceOfCharacteristic(n.Characteristics) + return &out +} + +// CloneRefOfShow creates a deep clone of the input. +func CloneRefOfShow(n *Show) *Show { + if n == nil { + return nil + } + out := *n + out.Internal = CloneShowInternal(n.Internal) + return &out +} + +// CloneRefOfShowBasic creates a deep clone of the input. +func CloneRefOfShowBasic(n *ShowBasic) *ShowBasic { + if n == nil { + return nil + } + out := *n + out.Tbl = CloneTableName(n.Tbl) + out.Filter = CloneRefOfShowFilter(n.Filter) + return &out +} + +// CloneRefOfShowCreate creates a deep clone of the input. +func CloneRefOfShowCreate(n *ShowCreate) *ShowCreate { + if n == nil { + return nil + } + out := *n + out.Op = CloneTableName(n.Op) + return &out +} + +// CloneRefOfShowFilter creates a deep clone of the input. +func CloneRefOfShowFilter(n *ShowFilter) *ShowFilter { + if n == nil { + return nil + } + out := *n + out.Filter = CloneExpr(n.Filter) + return &out +} + +// CloneRefOfShowLegacy creates a deep clone of the input. +func CloneRefOfShowLegacy(n *ShowLegacy) *ShowLegacy { + if n == nil { + return nil + } + out := *n + out.OnTable = CloneTableName(n.OnTable) + out.Table = CloneTableName(n.Table) + out.ShowTablesOpt = CloneRefOfShowTablesOpt(n.ShowTablesOpt) + out.ShowCollationFilterOpt = CloneExpr(n.ShowCollationFilterOpt) + return &out +} + +// CloneRefOfStarExpr creates a deep clone of the input. +func CloneRefOfStarExpr(n *StarExpr) *StarExpr { + if n == nil { + return nil + } + out := *n + out.TableName = CloneTableName(n.TableName) + return &out +} + +// CloneRefOfStream creates a deep clone of the input. +func CloneRefOfStream(n *Stream) *Stream { + if n == nil { + return nil + } + out := *n + out.Comments = CloneComments(n.Comments) + out.SelectExpr = CloneSelectExpr(n.SelectExpr) + out.Table = CloneTableName(n.Table) + return &out +} + +// CloneTableExprs creates a deep clone of the input. +func CloneTableExprs(n TableExprs) TableExprs { + res := make(TableExprs, 0, len(n)) + for _, x := range n { + res = append(res, CloneTableExpr(x)) + } + return res +} + +// CloneTableIdent creates a deep clone of the input. +func CloneTableIdent(n TableIdent) TableIdent { + return *CloneRefOfTableIdent(&n) +} + +// CloneTableName creates a deep clone of the input. +func CloneTableName(n TableName) TableName { + return *CloneRefOfTableName(&n) +} + +// CloneTableNames creates a deep clone of the input. +func CloneTableNames(n TableNames) TableNames { + res := make(TableNames, 0, len(n)) + for _, x := range n { + res = append(res, CloneTableName(x)) + } + return res +} + +// CloneRefOfTableSpec creates a deep clone of the input. +func CloneRefOfTableSpec(n *TableSpec) *TableSpec { + if n == nil { + return nil + } + out := *n + out.Columns = CloneSliceOfRefOfColumnDefinition(n.Columns) + out.Indexes = CloneSliceOfRefOfIndexDefinition(n.Indexes) + out.Constraints = CloneSliceOfRefOfConstraintDefinition(n.Constraints) + out.Options = CloneTableOptions(n.Options) + return &out +} + +// CloneRefOfUnionSelect creates a deep clone of the input. +func CloneRefOfUnionSelect(n *UnionSelect) *UnionSelect { + if n == nil { + return nil + } + out := *n + out.Statement = CloneSelectStatement(n.Statement) + return &out +} + +// CloneRefOfUnlockTables creates a deep clone of the input. +func CloneRefOfUnlockTables(n *UnlockTables) *UnlockTables { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfUpdate creates a deep clone of the input. +func CloneRefOfUpdate(n *Update) *Update { + if n == nil { + return nil + } + out := *n + out.Comments = CloneComments(n.Comments) + out.TableExprs = CloneTableExprs(n.TableExprs) + out.Exprs = CloneUpdateExprs(n.Exprs) + out.Where = CloneRefOfWhere(n.Where) + out.OrderBy = CloneOrderBy(n.OrderBy) + out.Limit = CloneRefOfLimit(n.Limit) + return &out +} + +// CloneRefOfUpdateExpr creates a deep clone of the input. +func CloneRefOfUpdateExpr(n *UpdateExpr) *UpdateExpr { + if n == nil { + return nil + } + out := *n + out.Name = CloneRefOfColName(n.Name) + out.Expr = CloneExpr(n.Expr) + return &out +} + +// CloneUpdateExprs creates a deep clone of the input. +func CloneUpdateExprs(n UpdateExprs) UpdateExprs { + res := make(UpdateExprs, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfUpdateExpr(x)) + } + return res +} + +// CloneRefOfUse creates a deep clone of the input. +func CloneRefOfUse(n *Use) *Use { + if n == nil { + return nil + } + out := *n + out.DBName = CloneTableIdent(n.DBName) + return &out +} + +// CloneRefOfVStream creates a deep clone of the input. +func CloneRefOfVStream(n *VStream) *VStream { + if n == nil { + return nil + } + out := *n + out.Comments = CloneComments(n.Comments) + out.SelectExpr = CloneSelectExpr(n.SelectExpr) + out.Table = CloneTableName(n.Table) + out.Where = CloneRefOfWhere(n.Where) + out.Limit = CloneRefOfLimit(n.Limit) + return &out +} + +// CloneVindexParam creates a deep clone of the input. +func CloneVindexParam(n VindexParam) VindexParam { + return *CloneRefOfVindexParam(&n) +} + +// CloneRefOfVindexSpec creates a deep clone of the input. +func CloneRefOfVindexSpec(n *VindexSpec) *VindexSpec { + if n == nil { + return nil + } + out := *n + out.Name = CloneColIdent(n.Name) + out.Type = CloneColIdent(n.Type) + out.Params = CloneSliceOfVindexParam(n.Params) + return &out +} + +// CloneRefOfWhen creates a deep clone of the input. +func CloneRefOfWhen(n *When) *When { + if n == nil { + return nil + } + out := *n + out.Cond = CloneExpr(n.Cond) + out.Val = CloneExpr(n.Val) + return &out +} + +// CloneRefOfWhere creates a deep clone of the input. +func CloneRefOfWhere(n *Where) *Where { + if n == nil { + return nil + } + out := *n + out.Expr = CloneExpr(n.Expr) + return &out +} + +// CloneSliceOfRefOfColumnDefinition creates a deep clone of the input. +func CloneSliceOfRefOfColumnDefinition(n []*ColumnDefinition) []*ColumnDefinition { + res := make([]*ColumnDefinition, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfColumnDefinition(x)) + } + return res +} + +// CloneRefOfTableOption creates a deep clone of the input. +func CloneRefOfTableOption(n *TableOption) *TableOption { + if n == nil { + return nil + } + out := *n + out.Value = CloneRefOfLiteral(n.Value) + out.Tables = CloneTableNames(n.Tables) + return &out +} + +// CloneSliceOfCollateAndCharset creates a deep clone of the input. +func CloneSliceOfCollateAndCharset(n []CollateAndCharset) []CollateAndCharset { + res := make([]CollateAndCharset, 0, len(n)) + for _, x := range n { + res = append(res, CloneCollateAndCharset(x)) + } + return res +} + +// CloneSliceOfAlterOption creates a deep clone of the input. +func CloneSliceOfAlterOption(n []AlterOption) []AlterOption { + res := make([]AlterOption, 0, len(n)) + for _, x := range n { + res = append(res, CloneAlterOption(x)) + } + return res +} + +// CloneSliceOfRefOfRenameTablePair creates a deep clone of the input. +func CloneSliceOfRefOfRenameTablePair(n []*RenameTablePair) []*RenameTablePair { + res := make([]*RenameTablePair, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfRenameTablePair(x)) + } + return res +} + +// CloneSliceOfRefOfWhen creates a deep clone of the input. +func CloneSliceOfRefOfWhen(n []*When) []*When { + res := make([]*When, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfWhen(x)) + } + return res +} + +// CloneSliceOfbyte creates a deep clone of the input. +func CloneSliceOfbyte(n []byte) []byte { + res := make([]byte, 0, len(n)) + copy(res, n) + return res +} + +// CloneRefOfbool creates a deep clone of the input. +func CloneRefOfbool(n *bool) *bool { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneSliceOfRefOfUnionSelect creates a deep clone of the input. +func CloneSliceOfRefOfUnionSelect(n []*UnionSelect) []*UnionSelect { + res := make([]*UnionSelect, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfUnionSelect(x)) + } + return res +} + +// CloneSliceOfColIdent creates a deep clone of the input. +func CloneSliceOfColIdent(n []ColIdent) []ColIdent { + res := make([]ColIdent, 0, len(n)) + for _, x := range n { + res = append(res, CloneColIdent(x)) + } + return res +} + +// CloneRefOfColIdent creates a deep clone of the input. +func CloneRefOfColIdent(n *ColIdent) *ColIdent { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneColumnType creates a deep clone of the input. +func CloneColumnType(n ColumnType) ColumnType { + return *CloneRefOfColumnType(&n) +} + +// CloneRefOfColumnTypeOptions creates a deep clone of the input. +func CloneRefOfColumnTypeOptions(n *ColumnTypeOptions) *ColumnTypeOptions { + if n == nil { + return nil + } + out := *n + out.Default = CloneExpr(n.Default) + out.OnUpdate = CloneExpr(n.OnUpdate) + out.Comment = CloneRefOfLiteral(n.Comment) + return &out +} + +// CloneSliceOfstring creates a deep clone of the input. +func CloneSliceOfstring(n []string) []string { + res := make([]string, 0, len(n)) + copy(res, n) + return res +} + +// CloneSliceOfRefOfIndexColumn creates a deep clone of the input. +func CloneSliceOfRefOfIndexColumn(n []*IndexColumn) []*IndexColumn { + res := make([]*IndexColumn, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfIndexColumn(x)) + } + return res +} + +// CloneSliceOfRefOfIndexOption creates a deep clone of the input. +func CloneSliceOfRefOfIndexOption(n []*IndexOption) []*IndexOption { + res := make([]*IndexOption, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfIndexOption(x)) + } + return res +} + +// CloneRefOfJoinCondition creates a deep clone of the input. +func CloneRefOfJoinCondition(n *JoinCondition) *JoinCondition { + if n == nil { + return nil + } + out := *n + out.On = CloneExpr(n.On) + out.Using = CloneColumns(n.Using) + return &out +} + +// CloneTableAndLockTypes creates a deep clone of the input. +func CloneTableAndLockTypes(n TableAndLockTypes) TableAndLockTypes { + res := make(TableAndLockTypes, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfTableAndLockType(x)) + } + return res +} + +// CloneSliceOfRefOfPartitionDefinition creates a deep clone of the input. +func CloneSliceOfRefOfPartitionDefinition(n []*PartitionDefinition) []*PartitionDefinition { + res := make([]*PartitionDefinition, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfPartitionDefinition(x)) + } + return res +} + +// CloneSliceOfCharacteristic creates a deep clone of the input. +func CloneSliceOfCharacteristic(n []Characteristic) []Characteristic { + res := make([]Characteristic, 0, len(n)) + for _, x := range n { + res = append(res, CloneCharacteristic(x)) + } + return res +} + +// CloneRefOfShowTablesOpt creates a deep clone of the input. +func CloneRefOfShowTablesOpt(n *ShowTablesOpt) *ShowTablesOpt { + if n == nil { + return nil + } + out := *n + out.Filter = CloneRefOfShowFilter(n.Filter) + return &out +} + +// CloneRefOfTableIdent creates a deep clone of the input. +func CloneRefOfTableIdent(n *TableIdent) *TableIdent { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfTableName creates a deep clone of the input. +func CloneRefOfTableName(n *TableName) *TableName { + if n == nil { + return nil + } + out := *n + out.Name = CloneTableIdent(n.Name) + out.Qualifier = CloneTableIdent(n.Qualifier) + return &out +} + +// CloneSliceOfRefOfIndexDefinition creates a deep clone of the input. +func CloneSliceOfRefOfIndexDefinition(n []*IndexDefinition) []*IndexDefinition { + res := make([]*IndexDefinition, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfIndexDefinition(x)) + } + return res +} + +// CloneSliceOfRefOfConstraintDefinition creates a deep clone of the input. +func CloneSliceOfRefOfConstraintDefinition(n []*ConstraintDefinition) []*ConstraintDefinition { + res := make([]*ConstraintDefinition, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfConstraintDefinition(x)) + } + return res +} + +// CloneRefOfVindexParam creates a deep clone of the input. +func CloneRefOfVindexParam(n *VindexParam) *VindexParam { + if n == nil { + return nil + } + out := *n + out.Key = CloneColIdent(n.Key) + return &out +} + +// CloneSliceOfVindexParam creates a deep clone of the input. +func CloneSliceOfVindexParam(n []VindexParam) []VindexParam { + res := make([]VindexParam, 0, len(n)) + for _, x := range n { + res = append(res, CloneVindexParam(x)) + } + return res +} + +// CloneCollateAndCharset creates a deep clone of the input. +func CloneCollateAndCharset(n CollateAndCharset) CollateAndCharset { + return *CloneRefOfCollateAndCharset(&n) +} + +// CloneRefOfRenameTablePair creates a deep clone of the input. +func CloneRefOfRenameTablePair(n *RenameTablePair) *RenameTablePair { + if n == nil { + return nil + } + out := *n + out.FromTable = CloneTableName(n.FromTable) + out.ToTable = CloneTableName(n.ToTable) + return &out +} + +// CloneRefOfIndexColumn creates a deep clone of the input. +func CloneRefOfIndexColumn(n *IndexColumn) *IndexColumn { + if n == nil { + return nil + } + out := *n + out.Column = CloneColIdent(n.Column) + out.Length = CloneRefOfLiteral(n.Length) + return &out +} + +// CloneRefOfIndexOption creates a deep clone of the input. +func CloneRefOfIndexOption(n *IndexOption) *IndexOption { + if n == nil { + return nil + } + out := *n + out.Value = CloneRefOfLiteral(n.Value) + return &out +} + +// CloneRefOfTableAndLockType creates a deep clone of the input. +func CloneRefOfTableAndLockType(n *TableAndLockType) *TableAndLockType { + if n == nil { + return nil + } + out := *n + out.Table = CloneTableExpr(n.Table) + return &out +} + +// CloneRefOfCollateAndCharset creates a deep clone of the input. +func CloneRefOfCollateAndCharset(n *CollateAndCharset) *CollateAndCharset { + if n == nil { + return nil + } + out := *n + return &out +} diff --git a/go/vt/sqlparser/normalizer.go b/go/vt/sqlparser/normalizer.go index ea5f7c3ee08..5716d1c91dd 100644 --- a/go/vt/sqlparser/normalizer.go +++ b/go/vt/sqlparser/normalizer.go @@ -31,9 +31,13 @@ import ( // Within Select constructs, bind vars are deduped. This allows // us to identify vindex equality. Otherwise, every value is // treated as distinct. -func Normalize(stmt Statement, bindVars map[string]*querypb.BindVariable, prefix string) { +func Normalize(stmt Statement, bindVars map[string]*querypb.BindVariable, prefix string) error { nz := newNormalizer(stmt, bindVars, prefix) - Rewrite(stmt, nz.WalkStatement, nil) + _, err := Rewrite(stmt, nz.WalkStatement, nil) + if err != nil { + return err + } + return nz.err } type normalizer struct { @@ -42,6 +46,7 @@ type normalizer struct { reserved map[string]struct{} counter int vals map[string]string + err error } func newNormalizer(stmt Statement, bindVars map[string]*querypb.BindVariable, prefix string) *normalizer { @@ -63,7 +68,8 @@ func (nz *normalizer) WalkStatement(cursor *Cursor) bool { case *Set, *Show, *Begin, *Commit, *Rollback, *Savepoint, *SetTransaction, DDLStatement, *SRollback, *Release, *OtherAdmin, *OtherRead: return false case *Select: - Rewrite(node, nz.WalkSelect, nil) + _, err := Rewrite(node, nz.WalkSelect, nil) + nz.err = err // Don't continue return false case *Literal: @@ -77,7 +83,7 @@ func (nz *normalizer) WalkStatement(cursor *Cursor) bool { case *ConvertType: // we should not rewrite the type description return false } - return true + return nz.err == nil // only continue if we haven't found any errors } // WalkSelect normalizes the AST in Select mode. @@ -98,7 +104,7 @@ func (nz *normalizer) WalkSelect(cursor *Cursor) bool { // we should not rewrite the type description return false } - return true + return nz.err == nil // only continue if we haven't found any errors } func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) { diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index c28d3c61ba5..7a40b6cab4a 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -21,6 +21,8 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/require" + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -229,7 +231,8 @@ func TestNormalize(t *testing.T) { continue } bv := make(map[string]*querypb.BindVariable) - Normalize(stmt, bv, prefix) + require.NoError(t, + Normalize(stmt, bv, prefix)) outstmt := String(stmt) if outstmt != tc.outstmt { t.Errorf("Query:\n%s:\n%s, want\n%s", tc.in, outstmt, tc.outstmt) @@ -271,6 +274,7 @@ func BenchmarkNormalize(b *testing.B) { b.Fatal(err) } for i := 0; i < b.N; i++ { - Normalize(ast, map[string]*querypb.BindVariable{}, "") + require.NoError(b, + Normalize(ast, map[string]*querypb.BindVariable{}, "")) } } diff --git a/go/vt/sqlparser/redact_query.go b/go/vt/sqlparser/redact_query.go index 55b760178f8..c46e1179494 100644 --- a/go/vt/sqlparser/redact_query.go +++ b/go/vt/sqlparser/redact_query.go @@ -29,7 +29,10 @@ func RedactSQLQuery(sql string) (string, error) { } prefix := "redacted" - Normalize(stmt, bv, prefix) + err = Normalize(stmt, bv, prefix) + if err != nil { + return "", err + } return comments.Leading + String(stmt) + comments.Trailing, nil } diff --git a/go/vt/sqlparser/rewriter.go b/go/vt/sqlparser/rewriter.go index 9e407fee6a3..bed5f3f544f 100644 --- a/go/vt/sqlparser/rewriter.go +++ b/go/vt/sqlparser/rewriter.go @@ -1,1716 +1,874 @@ -// Code generated by visitorgen/main/main.go. DO NOT EDIT. +/* +Copyright 2021 The Vitess Authors. -package sqlparser - -//go:generate go run ./visitorgen/main -input=ast.go -output=rewriter.go - -import ( - "reflect" -) - -type replacerFunc func(newNode, parent SQLNode) - -// application carries all the shared data so we can pass it around cheaply. -type application struct { - pre, post ApplyFunc - cursor Cursor -} - -func replaceAddColumnsAfter(newNode, parent SQLNode) { - parent.(*AddColumns).After = newNode.(*ColName) -} - -type replaceAddColumnsColumns int - -func (r *replaceAddColumnsColumns) replace(newNode, container SQLNode) { - container.(*AddColumns).Columns[int(*r)] = newNode.(*ColumnDefinition) -} - -func (r *replaceAddColumnsColumns) inc() { - *r++ -} - -func replaceAddColumnsFirst(newNode, parent SQLNode) { - parent.(*AddColumns).First = newNode.(*ColName) -} - -func replaceAddConstraintDefinitionConstraintDefinition(newNode, parent SQLNode) { - parent.(*AddConstraintDefinition).ConstraintDefinition = newNode.(*ConstraintDefinition) -} - -func replaceAddIndexDefinitionIndexDefinition(newNode, parent SQLNode) { - parent.(*AddIndexDefinition).IndexDefinition = newNode.(*IndexDefinition) -} - -func replaceAliasedExprAs(newNode, parent SQLNode) { - parent.(*AliasedExpr).As = newNode.(ColIdent) -} - -func replaceAliasedExprExpr(newNode, parent SQLNode) { - parent.(*AliasedExpr).Expr = newNode.(Expr) -} - -func replaceAliasedTableExprAs(newNode, parent SQLNode) { - parent.(*AliasedTableExpr).As = newNode.(TableIdent) -} - -func replaceAliasedTableExprExpr(newNode, parent SQLNode) { - parent.(*AliasedTableExpr).Expr = newNode.(SimpleTableExpr) -} - -func replaceAliasedTableExprHints(newNode, parent SQLNode) { - parent.(*AliasedTableExpr).Hints = newNode.(*IndexHints) -} - -func replaceAliasedTableExprPartitions(newNode, parent SQLNode) { - parent.(*AliasedTableExpr).Partitions = newNode.(Partitions) -} - -func replaceAlterColumnColumn(newNode, parent SQLNode) { - parent.(*AlterColumn).Column = newNode.(*ColName) -} - -func replaceAlterColumnDefaultVal(newNode, parent SQLNode) { - parent.(*AlterColumn).DefaultVal = newNode.(Expr) -} - -type replaceAlterTableAlterOptions int - -func (r *replaceAlterTableAlterOptions) replace(newNode, container SQLNode) { - container.(*AlterTable).AlterOptions[int(*r)] = newNode.(AlterOption) -} - -func (r *replaceAlterTableAlterOptions) inc() { - *r++ -} - -func replaceAlterTablePartitionSpec(newNode, parent SQLNode) { - parent.(*AlterTable).PartitionSpec = newNode.(*PartitionSpec) -} - -func replaceAlterTableTable(newNode, parent SQLNode) { - parent.(*AlterTable).Table = newNode.(TableName) -} - -func replaceAlterViewColumns(newNode, parent SQLNode) { - parent.(*AlterView).Columns = newNode.(Columns) -} - -func replaceAlterViewSelect(newNode, parent SQLNode) { - parent.(*AlterView).Select = newNode.(SelectStatement) -} - -func replaceAlterViewViewName(newNode, parent SQLNode) { - parent.(*AlterView).ViewName = newNode.(TableName) -} - -func replaceAlterVschemaAutoIncSpec(newNode, parent SQLNode) { - parent.(*AlterVschema).AutoIncSpec = newNode.(*AutoIncSpec) -} - -func replaceAlterVschemaTable(newNode, parent SQLNode) { - parent.(*AlterVschema).Table = newNode.(TableName) -} - -type replaceAlterVschemaVindexCols int - -func (r *replaceAlterVschemaVindexCols) replace(newNode, container SQLNode) { - container.(*AlterVschema).VindexCols[int(*r)] = newNode.(ColIdent) -} - -func (r *replaceAlterVschemaVindexCols) inc() { - *r++ -} - -func replaceAlterVschemaVindexSpec(newNode, parent SQLNode) { - parent.(*AlterVschema).VindexSpec = newNode.(*VindexSpec) -} - -func replaceAndExprLeft(newNode, parent SQLNode) { - parent.(*AndExpr).Left = newNode.(Expr) -} - -func replaceAndExprRight(newNode, parent SQLNode) { - parent.(*AndExpr).Right = newNode.(Expr) -} - -func replaceAutoIncSpecColumn(newNode, parent SQLNode) { - parent.(*AutoIncSpec).Column = newNode.(ColIdent) -} - -func replaceAutoIncSpecSequence(newNode, parent SQLNode) { - parent.(*AutoIncSpec).Sequence = newNode.(TableName) -} - -func replaceBinaryExprLeft(newNode, parent SQLNode) { - parent.(*BinaryExpr).Left = newNode.(Expr) -} - -func replaceBinaryExprRight(newNode, parent SQLNode) { - parent.(*BinaryExpr).Right = newNode.(Expr) -} - -func replaceCallProcName(newNode, parent SQLNode) { - parent.(*CallProc).Name = newNode.(TableName) -} - -func replaceCallProcParams(newNode, parent SQLNode) { - parent.(*CallProc).Params = newNode.(Exprs) -} - -func replaceCaseExprElse(newNode, parent SQLNode) { - parent.(*CaseExpr).Else = newNode.(Expr) -} - -func replaceCaseExprExpr(newNode, parent SQLNode) { - parent.(*CaseExpr).Expr = newNode.(Expr) -} - -type replaceCaseExprWhens int - -func (r *replaceCaseExprWhens) replace(newNode, container SQLNode) { - container.(*CaseExpr).Whens[int(*r)] = newNode.(*When) -} - -func (r *replaceCaseExprWhens) inc() { - *r++ -} - -func replaceChangeColumnAfter(newNode, parent SQLNode) { - parent.(*ChangeColumn).After = newNode.(*ColName) -} - -func replaceChangeColumnFirst(newNode, parent SQLNode) { - parent.(*ChangeColumn).First = newNode.(*ColName) -} - -func replaceChangeColumnNewColDefinition(newNode, parent SQLNode) { - parent.(*ChangeColumn).NewColDefinition = newNode.(*ColumnDefinition) -} - -func replaceChangeColumnOldColumn(newNode, parent SQLNode) { - parent.(*ChangeColumn).OldColumn = newNode.(*ColName) -} - -func replaceCheckConstraintDefinitionExpr(newNode, parent SQLNode) { - parent.(*CheckConstraintDefinition).Expr = newNode.(Expr) -} - -func replaceColNameName(newNode, parent SQLNode) { - parent.(*ColName).Name = newNode.(ColIdent) -} - -func replaceColNameQualifier(newNode, parent SQLNode) { - parent.(*ColName).Qualifier = newNode.(TableName) -} - -func replaceCollateExprExpr(newNode, parent SQLNode) { - parent.(*CollateExpr).Expr = newNode.(Expr) -} - -func replaceColumnDefinitionName(newNode, parent SQLNode) { - parent.(*ColumnDefinition).Name = newNode.(ColIdent) -} - -func replaceColumnTypeLength(newNode, parent SQLNode) { - parent.(*ColumnType).Length = newNode.(*Literal) -} - -func replaceColumnTypeScale(newNode, parent SQLNode) { - parent.(*ColumnType).Scale = newNode.(*Literal) -} - -type replaceColumnsItems int - -func (r *replaceColumnsItems) replace(newNode, container SQLNode) { - container.(Columns)[int(*r)] = newNode.(ColIdent) -} - -func (r *replaceColumnsItems) inc() { - *r++ -} - -func replaceComparisonExprEscape(newNode, parent SQLNode) { - parent.(*ComparisonExpr).Escape = newNode.(Expr) -} - -func replaceComparisonExprLeft(newNode, parent SQLNode) { - parent.(*ComparisonExpr).Left = newNode.(Expr) -} - -func replaceComparisonExprRight(newNode, parent SQLNode) { - parent.(*ComparisonExpr).Right = newNode.(Expr) -} - -func replaceConstraintDefinitionDetails(newNode, parent SQLNode) { - parent.(*ConstraintDefinition).Details = newNode.(ConstraintInfo) -} - -func replaceConvertExprExpr(newNode, parent SQLNode) { - parent.(*ConvertExpr).Expr = newNode.(Expr) -} - -func replaceConvertExprType(newNode, parent SQLNode) { - parent.(*ConvertExpr).Type = newNode.(*ConvertType) -} - -func replaceConvertTypeLength(newNode, parent SQLNode) { - parent.(*ConvertType).Length = newNode.(*Literal) -} - -func replaceConvertTypeScale(newNode, parent SQLNode) { - parent.(*ConvertType).Scale = newNode.(*Literal) -} - -func replaceConvertUsingExprExpr(newNode, parent SQLNode) { - parent.(*ConvertUsingExpr).Expr = newNode.(Expr) -} - -func replaceCreateTableOptLike(newNode, parent SQLNode) { - parent.(*CreateTable).OptLike = newNode.(*OptLike) -} - -func replaceCreateTableTable(newNode, parent SQLNode) { - parent.(*CreateTable).Table = newNode.(TableName) -} - -func replaceCreateTableTableSpec(newNode, parent SQLNode) { - parent.(*CreateTable).TableSpec = newNode.(*TableSpec) -} - -func replaceCreateViewColumns(newNode, parent SQLNode) { - parent.(*CreateView).Columns = newNode.(Columns) -} - -func replaceCreateViewSelect(newNode, parent SQLNode) { - parent.(*CreateView).Select = newNode.(SelectStatement) -} - -func replaceCreateViewViewName(newNode, parent SQLNode) { - parent.(*CreateView).ViewName = newNode.(TableName) -} - -func replaceCurTimeFuncExprFsp(newNode, parent SQLNode) { - parent.(*CurTimeFuncExpr).Fsp = newNode.(Expr) -} - -func replaceCurTimeFuncExprName(newNode, parent SQLNode) { - parent.(*CurTimeFuncExpr).Name = newNode.(ColIdent) -} - -func replaceDeleteComments(newNode, parent SQLNode) { - parent.(*Delete).Comments = newNode.(Comments) -} - -func replaceDeleteLimit(newNode, parent SQLNode) { - parent.(*Delete).Limit = newNode.(*Limit) -} - -func replaceDeleteOrderBy(newNode, parent SQLNode) { - parent.(*Delete).OrderBy = newNode.(OrderBy) -} - -func replaceDeletePartitions(newNode, parent SQLNode) { - parent.(*Delete).Partitions = newNode.(Partitions) -} - -func replaceDeleteTableExprs(newNode, parent SQLNode) { - parent.(*Delete).TableExprs = newNode.(TableExprs) -} - -func replaceDeleteTargets(newNode, parent SQLNode) { - parent.(*Delete).Targets = newNode.(TableNames) -} - -func replaceDeleteWhere(newNode, parent SQLNode) { - parent.(*Delete).Where = newNode.(*Where) -} - -func replaceDerivedTableSelect(newNode, parent SQLNode) { - parent.(*DerivedTable).Select = newNode.(SelectStatement) -} - -func replaceDropColumnName(newNode, parent SQLNode) { - parent.(*DropColumn).Name = newNode.(*ColName) -} - -func replaceDropTableFromTables(newNode, parent SQLNode) { - parent.(*DropTable).FromTables = newNode.(TableNames) -} - -func replaceDropViewFromTables(newNode, parent SQLNode) { - parent.(*DropView).FromTables = newNode.(TableNames) -} - -func replaceExistsExprSubquery(newNode, parent SQLNode) { - parent.(*ExistsExpr).Subquery = newNode.(*Subquery) -} - -func replaceExplainStmtStatement(newNode, parent SQLNode) { - parent.(*ExplainStmt).Statement = newNode.(Statement) -} - -func replaceExplainTabTable(newNode, parent SQLNode) { - parent.(*ExplainTab).Table = newNode.(TableName) -} - -type replaceExprsItems int - -func (r *replaceExprsItems) replace(newNode, container SQLNode) { - container.(Exprs)[int(*r)] = newNode.(Expr) -} - -func (r *replaceExprsItems) inc() { - *r++ -} - -func replaceFlushTableNames(newNode, parent SQLNode) { - parent.(*Flush).TableNames = newNode.(TableNames) -} - -func replaceForeignKeyDefinitionOnDelete(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).OnDelete = newNode.(ReferenceAction) -} - -func replaceForeignKeyDefinitionOnUpdate(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).OnUpdate = newNode.(ReferenceAction) -} - -func replaceForeignKeyDefinitionReferencedColumns(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).ReferencedColumns = newNode.(Columns) -} - -func replaceForeignKeyDefinitionReferencedTable(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).ReferencedTable = newNode.(TableName) -} - -func replaceForeignKeyDefinitionSource(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).Source = newNode.(Columns) -} - -func replaceFuncExprExprs(newNode, parent SQLNode) { - parent.(*FuncExpr).Exprs = newNode.(SelectExprs) -} - -func replaceFuncExprName(newNode, parent SQLNode) { - parent.(*FuncExpr).Name = newNode.(ColIdent) -} - -func replaceFuncExprQualifier(newNode, parent SQLNode) { - parent.(*FuncExpr).Qualifier = newNode.(TableIdent) -} - -type replaceGroupByItems int - -func (r *replaceGroupByItems) replace(newNode, container SQLNode) { - container.(GroupBy)[int(*r)] = newNode.(Expr) -} - -func (r *replaceGroupByItems) inc() { - *r++ -} - -func replaceGroupConcatExprExprs(newNode, parent SQLNode) { - parent.(*GroupConcatExpr).Exprs = newNode.(SelectExprs) -} - -func replaceGroupConcatExprLimit(newNode, parent SQLNode) { - parent.(*GroupConcatExpr).Limit = newNode.(*Limit) -} - -func replaceGroupConcatExprOrderBy(newNode, parent SQLNode) { - parent.(*GroupConcatExpr).OrderBy = newNode.(OrderBy) -} - -func replaceIndexDefinitionInfo(newNode, parent SQLNode) { - parent.(*IndexDefinition).Info = newNode.(*IndexInfo) -} - -type replaceIndexHintsIndexes int - -func (r *replaceIndexHintsIndexes) replace(newNode, container SQLNode) { - container.(*IndexHints).Indexes[int(*r)] = newNode.(ColIdent) -} - -func (r *replaceIndexHintsIndexes) inc() { - *r++ -} - -func replaceIndexInfoConstraintName(newNode, parent SQLNode) { - parent.(*IndexInfo).ConstraintName = newNode.(ColIdent) -} - -func replaceIndexInfoName(newNode, parent SQLNode) { - parent.(*IndexInfo).Name = newNode.(ColIdent) -} - -func replaceInsertColumns(newNode, parent SQLNode) { - parent.(*Insert).Columns = newNode.(Columns) -} - -func replaceInsertComments(newNode, parent SQLNode) { - parent.(*Insert).Comments = newNode.(Comments) -} - -func replaceInsertOnDup(newNode, parent SQLNode) { - parent.(*Insert).OnDup = newNode.(OnDup) -} - -func replaceInsertPartitions(newNode, parent SQLNode) { - parent.(*Insert).Partitions = newNode.(Partitions) -} - -func replaceInsertRows(newNode, parent SQLNode) { - parent.(*Insert).Rows = newNode.(InsertRows) -} - -func replaceInsertTable(newNode, parent SQLNode) { - parent.(*Insert).Table = newNode.(TableName) -} - -func replaceIntervalExprExpr(newNode, parent SQLNode) { - parent.(*IntervalExpr).Expr = newNode.(Expr) -} - -func replaceIsExprExpr(newNode, parent SQLNode) { - parent.(*IsExpr).Expr = newNode.(Expr) -} - -func replaceJoinConditionOn(newNode, parent SQLNode) { - tmp := parent.(JoinCondition) - tmp.On = newNode.(Expr) -} - -func replaceJoinConditionUsing(newNode, parent SQLNode) { - tmp := parent.(JoinCondition) - tmp.Using = newNode.(Columns) -} - -func replaceJoinTableExprCondition(newNode, parent SQLNode) { - parent.(*JoinTableExpr).Condition = newNode.(JoinCondition) -} - -func replaceJoinTableExprLeftExpr(newNode, parent SQLNode) { - parent.(*JoinTableExpr).LeftExpr = newNode.(TableExpr) -} - -func replaceJoinTableExprRightExpr(newNode, parent SQLNode) { - parent.(*JoinTableExpr).RightExpr = newNode.(TableExpr) -} - -func replaceLimitOffset(newNode, parent SQLNode) { - parent.(*Limit).Offset = newNode.(Expr) -} - -func replaceLimitRowcount(newNode, parent SQLNode) { - parent.(*Limit).Rowcount = newNode.(Expr) -} - -func replaceMatchExprColumns(newNode, parent SQLNode) { - parent.(*MatchExpr).Columns = newNode.(SelectExprs) -} - -func replaceMatchExprExpr(newNode, parent SQLNode) { - parent.(*MatchExpr).Expr = newNode.(Expr) -} - -func replaceModifyColumnAfter(newNode, parent SQLNode) { - parent.(*ModifyColumn).After = newNode.(*ColName) -} - -func replaceModifyColumnFirst(newNode, parent SQLNode) { - parent.(*ModifyColumn).First = newNode.(*ColName) -} - -func replaceModifyColumnNewColDefinition(newNode, parent SQLNode) { - parent.(*ModifyColumn).NewColDefinition = newNode.(*ColumnDefinition) -} - -func replaceNextvalExpr(newNode, parent SQLNode) { - tmp := parent.(Nextval) - tmp.Expr = newNode.(Expr) -} - -func replaceNotExprExpr(newNode, parent SQLNode) { - parent.(*NotExpr).Expr = newNode.(Expr) -} - -type replaceOnDupItems int - -func (r *replaceOnDupItems) replace(newNode, container SQLNode) { - container.(OnDup)[int(*r)] = newNode.(*UpdateExpr) -} - -func (r *replaceOnDupItems) inc() { - *r++ -} - -func replaceOptLikeLikeTable(newNode, parent SQLNode) { - parent.(*OptLike).LikeTable = newNode.(TableName) -} - -func replaceOrExprLeft(newNode, parent SQLNode) { - parent.(*OrExpr).Left = newNode.(Expr) -} - -func replaceOrExprRight(newNode, parent SQLNode) { - parent.(*OrExpr).Right = newNode.(Expr) -} - -func replaceOrderExpr(newNode, parent SQLNode) { - parent.(*Order).Expr = newNode.(Expr) -} - -type replaceOrderByItems int - -func (r *replaceOrderByItems) replace(newNode, container SQLNode) { - container.(OrderBy)[int(*r)] = newNode.(*Order) -} - -func (r *replaceOrderByItems) inc() { - *r++ -} - -func replaceOrderByOptionCols(newNode, parent SQLNode) { - parent.(*OrderByOption).Cols = newNode.(Columns) -} - -func replaceParenSelectSelect(newNode, parent SQLNode) { - parent.(*ParenSelect).Select = newNode.(SelectStatement) -} - -func replaceParenTableExprExprs(newNode, parent SQLNode) { - parent.(*ParenTableExpr).Exprs = newNode.(TableExprs) -} - -func replacePartitionDefinitionLimit(newNode, parent SQLNode) { - parent.(*PartitionDefinition).Limit = newNode.(Expr) -} - -func replacePartitionDefinitionName(newNode, parent SQLNode) { - parent.(*PartitionDefinition).Name = newNode.(ColIdent) -} - -type replacePartitionSpecDefinitions int - -func (r *replacePartitionSpecDefinitions) replace(newNode, container SQLNode) { - container.(*PartitionSpec).Definitions[int(*r)] = newNode.(*PartitionDefinition) -} - -func (r *replacePartitionSpecDefinitions) inc() { - *r++ -} - -func replacePartitionSpecNames(newNode, parent SQLNode) { - parent.(*PartitionSpec).Names = newNode.(Partitions) -} - -func replacePartitionSpecNumber(newNode, parent SQLNode) { - parent.(*PartitionSpec).Number = newNode.(*Literal) -} - -func replacePartitionSpecTableName(newNode, parent SQLNode) { - parent.(*PartitionSpec).TableName = newNode.(TableName) -} - -type replacePartitionsItems int - -func (r *replacePartitionsItems) replace(newNode, container SQLNode) { - container.(Partitions)[int(*r)] = newNode.(ColIdent) -} - -func (r *replacePartitionsItems) inc() { - *r++ -} - -func replaceRangeCondFrom(newNode, parent SQLNode) { - parent.(*RangeCond).From = newNode.(Expr) -} - -func replaceRangeCondLeft(newNode, parent SQLNode) { - parent.(*RangeCond).Left = newNode.(Expr) -} - -func replaceRangeCondTo(newNode, parent SQLNode) { - parent.(*RangeCond).To = newNode.(Expr) -} - -func replaceReleaseName(newNode, parent SQLNode) { - parent.(*Release).Name = newNode.(ColIdent) -} - -func replaceRenameTableNameTable(newNode, parent SQLNode) { - parent.(*RenameTableName).Table = newNode.(TableName) -} - -func replaceSRollbackName(newNode, parent SQLNode) { - parent.(*SRollback).Name = newNode.(ColIdent) -} - -func replaceSavepointName(newNode, parent SQLNode) { - parent.(*Savepoint).Name = newNode.(ColIdent) -} - -func replaceSelectComments(newNode, parent SQLNode) { - parent.(*Select).Comments = newNode.(Comments) -} - -func replaceSelectFrom(newNode, parent SQLNode) { - parent.(*Select).From = newNode.(TableExprs) -} - -func replaceSelectGroupBy(newNode, parent SQLNode) { - parent.(*Select).GroupBy = newNode.(GroupBy) -} - -func replaceSelectHaving(newNode, parent SQLNode) { - parent.(*Select).Having = newNode.(*Where) -} - -func replaceSelectInto(newNode, parent SQLNode) { - parent.(*Select).Into = newNode.(*SelectInto) -} - -func replaceSelectLimit(newNode, parent SQLNode) { - parent.(*Select).Limit = newNode.(*Limit) -} - -func replaceSelectOrderBy(newNode, parent SQLNode) { - parent.(*Select).OrderBy = newNode.(OrderBy) -} - -func replaceSelectSelectExprs(newNode, parent SQLNode) { - parent.(*Select).SelectExprs = newNode.(SelectExprs) -} - -func replaceSelectWhere(newNode, parent SQLNode) { - parent.(*Select).Where = newNode.(*Where) -} - -type replaceSelectExprsItems int - -func (r *replaceSelectExprsItems) replace(newNode, container SQLNode) { - container.(SelectExprs)[int(*r)] = newNode.(SelectExpr) -} - -func (r *replaceSelectExprsItems) inc() { - *r++ -} - -func replaceSetComments(newNode, parent SQLNode) { - parent.(*Set).Comments = newNode.(Comments) -} - -func replaceSetExprs(newNode, parent SQLNode) { - parent.(*Set).Exprs = newNode.(SetExprs) -} - -func replaceSetExprExpr(newNode, parent SQLNode) { - parent.(*SetExpr).Expr = newNode.(Expr) -} - -func replaceSetExprName(newNode, parent SQLNode) { - parent.(*SetExpr).Name = newNode.(ColIdent) -} - -type replaceSetExprsItems int - -func (r *replaceSetExprsItems) replace(newNode, container SQLNode) { - container.(SetExprs)[int(*r)] = newNode.(*SetExpr) -} - -func (r *replaceSetExprsItems) inc() { - *r++ -} - -type replaceSetTransactionCharacteristics int - -func (r *replaceSetTransactionCharacteristics) replace(newNode, container SQLNode) { - container.(*SetTransaction).Characteristics[int(*r)] = newNode.(Characteristic) -} - -func (r *replaceSetTransactionCharacteristics) inc() { - *r++ -} - -func replaceSetTransactionComments(newNode, parent SQLNode) { - parent.(*SetTransaction).Comments = newNode.(Comments) -} - -func replaceShowInternal(newNode, parent SQLNode) { - parent.(*Show).Internal = newNode.(ShowInternal) -} - -func replaceShowBasicFilter(newNode, parent SQLNode) { - parent.(*ShowBasic).Filter = newNode.(*ShowFilter) -} - -func replaceShowBasicTbl(newNode, parent SQLNode) { - parent.(*ShowBasic).Tbl = newNode.(TableName) -} - -func replaceShowCreateOp(newNode, parent SQLNode) { - parent.(*ShowCreate).Op = newNode.(TableName) -} - -func replaceShowFilterFilter(newNode, parent SQLNode) { - parent.(*ShowFilter).Filter = newNode.(Expr) -} - -func replaceShowLegacyOnTable(newNode, parent SQLNode) { - parent.(*ShowLegacy).OnTable = newNode.(TableName) -} - -func replaceShowLegacyShowCollationFilterOpt(newNode, parent SQLNode) { - parent.(*ShowLegacy).ShowCollationFilterOpt = newNode.(Expr) -} - -func replaceShowLegacyTable(newNode, parent SQLNode) { - parent.(*ShowLegacy).Table = newNode.(TableName) -} - -func replaceStarExprTableName(newNode, parent SQLNode) { - parent.(*StarExpr).TableName = newNode.(TableName) -} - -func replaceStreamComments(newNode, parent SQLNode) { - parent.(*Stream).Comments = newNode.(Comments) -} - -func replaceStreamSelectExpr(newNode, parent SQLNode) { - parent.(*Stream).SelectExpr = newNode.(SelectExpr) -} - -func replaceStreamTable(newNode, parent SQLNode) { - parent.(*Stream).Table = newNode.(TableName) -} - -func replaceSubquerySelect(newNode, parent SQLNode) { - parent.(*Subquery).Select = newNode.(SelectStatement) -} - -func replaceSubstrExprFrom(newNode, parent SQLNode) { - parent.(*SubstrExpr).From = newNode.(Expr) -} - -func replaceSubstrExprName(newNode, parent SQLNode) { - parent.(*SubstrExpr).Name = newNode.(*ColName) -} - -func replaceSubstrExprStrVal(newNode, parent SQLNode) { - parent.(*SubstrExpr).StrVal = newNode.(*Literal) -} - -func replaceSubstrExprTo(newNode, parent SQLNode) { - parent.(*SubstrExpr).To = newNode.(Expr) -} - -type replaceTableExprsItems int - -func (r *replaceTableExprsItems) replace(newNode, container SQLNode) { - container.(TableExprs)[int(*r)] = newNode.(TableExpr) -} - -func (r *replaceTableExprsItems) inc() { - *r++ -} - -func replaceTableNameName(newNode, parent SQLNode) { - tmp := parent.(TableName) - tmp.Name = newNode.(TableIdent) -} - -func replaceTableNameQualifier(newNode, parent SQLNode) { - tmp := parent.(TableName) - tmp.Qualifier = newNode.(TableIdent) -} - -type replaceTableNamesItems int - -func (r *replaceTableNamesItems) replace(newNode, container SQLNode) { - container.(TableNames)[int(*r)] = newNode.(TableName) -} - -func (r *replaceTableNamesItems) inc() { - *r++ -} - -type replaceTableSpecColumns int - -func (r *replaceTableSpecColumns) replace(newNode, container SQLNode) { - container.(*TableSpec).Columns[int(*r)] = newNode.(*ColumnDefinition) -} - -func (r *replaceTableSpecColumns) inc() { - *r++ -} - -type replaceTableSpecConstraints int - -func (r *replaceTableSpecConstraints) replace(newNode, container SQLNode) { - container.(*TableSpec).Constraints[int(*r)] = newNode.(*ConstraintDefinition) -} - -func (r *replaceTableSpecConstraints) inc() { - *r++ -} - -type replaceTableSpecIndexes int - -func (r *replaceTableSpecIndexes) replace(newNode, container SQLNode) { - container.(*TableSpec).Indexes[int(*r)] = newNode.(*IndexDefinition) -} - -func (r *replaceTableSpecIndexes) inc() { - *r++ -} - -func replaceTableSpecOptions(newNode, parent SQLNode) { - parent.(*TableSpec).Options = newNode.(TableOptions) -} - -func replaceTimestampFuncExprExpr1(newNode, parent SQLNode) { - parent.(*TimestampFuncExpr).Expr1 = newNode.(Expr) -} - -func replaceTimestampFuncExprExpr2(newNode, parent SQLNode) { - parent.(*TimestampFuncExpr).Expr2 = newNode.(Expr) -} - -func replaceTruncateTableTable(newNode, parent SQLNode) { - parent.(*TruncateTable).Table = newNode.(TableName) -} - -func replaceUnaryExprExpr(newNode, parent SQLNode) { - parent.(*UnaryExpr).Expr = newNode.(Expr) -} - -func replaceUnionFirstStatement(newNode, parent SQLNode) { - parent.(*Union).FirstStatement = newNode.(SelectStatement) -} - -func replaceUnionLimit(newNode, parent SQLNode) { - parent.(*Union).Limit = newNode.(*Limit) -} - -func replaceUnionOrderBy(newNode, parent SQLNode) { - parent.(*Union).OrderBy = newNode.(OrderBy) -} - -type replaceUnionUnionSelects int - -func (r *replaceUnionUnionSelects) replace(newNode, container SQLNode) { - container.(*Union).UnionSelects[int(*r)] = newNode.(*UnionSelect) -} - -func (r *replaceUnionUnionSelects) inc() { - *r++ -} - -func replaceUnionSelectStatement(newNode, parent SQLNode) { - parent.(*UnionSelect).Statement = newNode.(SelectStatement) -} - -func replaceUpdateComments(newNode, parent SQLNode) { - parent.(*Update).Comments = newNode.(Comments) -} - -func replaceUpdateExprs(newNode, parent SQLNode) { - parent.(*Update).Exprs = newNode.(UpdateExprs) -} - -func replaceUpdateLimit(newNode, parent SQLNode) { - parent.(*Update).Limit = newNode.(*Limit) -} - -func replaceUpdateOrderBy(newNode, parent SQLNode) { - parent.(*Update).OrderBy = newNode.(OrderBy) -} - -func replaceUpdateTableExprs(newNode, parent SQLNode) { - parent.(*Update).TableExprs = newNode.(TableExprs) -} - -func replaceUpdateWhere(newNode, parent SQLNode) { - parent.(*Update).Where = newNode.(*Where) -} - -func replaceUpdateExprExpr(newNode, parent SQLNode) { - parent.(*UpdateExpr).Expr = newNode.(Expr) -} - -func replaceUpdateExprName(newNode, parent SQLNode) { - parent.(*UpdateExpr).Name = newNode.(*ColName) -} - -type replaceUpdateExprsItems int - -func (r *replaceUpdateExprsItems) replace(newNode, container SQLNode) { - container.(UpdateExprs)[int(*r)] = newNode.(*UpdateExpr) -} - -func (r *replaceUpdateExprsItems) inc() { - *r++ -} - -func replaceUseDBName(newNode, parent SQLNode) { - parent.(*Use).DBName = newNode.(TableIdent) -} +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at -func replaceVStreamComments(newNode, parent SQLNode) { - parent.(*VStream).Comments = newNode.(Comments) -} - -func replaceVStreamLimit(newNode, parent SQLNode) { - parent.(*VStream).Limit = newNode.(*Limit) -} - -func replaceVStreamSelectExpr(newNode, parent SQLNode) { - parent.(*VStream).SelectExpr = newNode.(SelectExpr) -} - -func replaceVStreamTable(newNode, parent SQLNode) { - parent.(*VStream).Table = newNode.(TableName) -} - -func replaceVStreamWhere(newNode, parent SQLNode) { - parent.(*VStream).Where = newNode.(*Where) -} - -type replaceValTupleItems int - -func (r *replaceValTupleItems) replace(newNode, container SQLNode) { - container.(ValTuple)[int(*r)] = newNode.(Expr) -} - -func (r *replaceValTupleItems) inc() { - *r++ -} - -type replaceValuesItems int - -func (r *replaceValuesItems) replace(newNode, container SQLNode) { - container.(Values)[int(*r)] = newNode.(ValTuple) -} - -func (r *replaceValuesItems) inc() { - *r++ -} - -func replaceValuesFuncExprName(newNode, parent SQLNode) { - parent.(*ValuesFuncExpr).Name = newNode.(*ColName) -} - -func replaceVindexParamKey(newNode, parent SQLNode) { - tmp := parent.(VindexParam) - tmp.Key = newNode.(ColIdent) -} - -func replaceVindexSpecName(newNode, parent SQLNode) { - parent.(*VindexSpec).Name = newNode.(ColIdent) -} - -type replaceVindexSpecParams int - -func (r *replaceVindexSpecParams) replace(newNode, container SQLNode) { - container.(*VindexSpec).Params[int(*r)] = newNode.(VindexParam) -} - -func (r *replaceVindexSpecParams) inc() { - *r++ -} - -func replaceVindexSpecType(newNode, parent SQLNode) { - parent.(*VindexSpec).Type = newNode.(ColIdent) -} + http://www.apache.org/licenses/LICENSE-2.0 -func replaceWhenCond(newNode, parent SQLNode) { - parent.(*When).Cond = newNode.(Expr) -} - -func replaceWhenVal(newNode, parent SQLNode) { - parent.(*When).Val = newNode.(Expr) -} - -func replaceWhereExpr(newNode, parent SQLNode) { - parent.(*Where).Expr = newNode.(Expr) -} - -func replaceXorExprLeft(newNode, parent SQLNode) { - parent.(*XorExpr).Left = newNode.(Expr) -} +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +// Code generated by ASTHelperGen. DO NOT EDIT. -func replaceXorExprRight(newNode, parent SQLNode) { - parent.(*XorExpr).Right = newNode.(Expr) -} +package sqlparser -// apply is where the visiting happens. Here is where we keep the big switch-case that will be used -// to do the actual visiting of SQLNodes func (a *application) apply(parent, node SQLNode, replacer replacerFunc) { if node == nil || isNilValue(node) { return } - - // avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead saved := a.cursor a.cursor.replacer = replacer a.cursor.node = node a.cursor.parent = parent - if a.pre != nil && !a.pre(&a.cursor) { a.cursor = saved return } - - // walk children - // (the order of the cases is alphabetical) switch n := node.(type) { - case nil: - case AccessMode: - case *AddColumns: - a.apply(node, n.After, replaceAddColumnsAfter) - replacerColumns := replaceAddColumnsColumns(0) - replacerColumnsB := &replacerColumns - for _, item := range n.Columns { - a.apply(node, item, replacerColumnsB.replace) - replacerColumnsB.inc() + for x, el := range n.Columns { + a.apply(node, el, func(newNode, container SQLNode) { + container.(*AddColumns).Columns[x] = newNode.(*ColumnDefinition) + }) } - a.apply(node, n.First, replaceAddColumnsFirst) - + a.apply(node, n.First, func(newNode, parent SQLNode) { + parent.(*AddColumns).First = newNode.(*ColName) + }) + a.apply(node, n.After, func(newNode, parent SQLNode) { + parent.(*AddColumns).After = newNode.(*ColName) + }) case *AddConstraintDefinition: - a.apply(node, n.ConstraintDefinition, replaceAddConstraintDefinitionConstraintDefinition) - + a.apply(node, n.ConstraintDefinition, func(newNode, parent SQLNode) { + parent.(*AddConstraintDefinition).ConstraintDefinition = newNode.(*ConstraintDefinition) + }) case *AddIndexDefinition: - a.apply(node, n.IndexDefinition, replaceAddIndexDefinitionIndexDefinition) - - case AlgorithmValue: - + a.apply(node, n.IndexDefinition, func(newNode, parent SQLNode) { + parent.(*AddIndexDefinition).IndexDefinition = newNode.(*IndexDefinition) + }) case *AliasedExpr: - a.apply(node, n.As, replaceAliasedExprAs) - a.apply(node, n.Expr, replaceAliasedExprExpr) - + a.apply(node, n.Expr, func(newNode, parent SQLNode) { + parent.(*AliasedExpr).Expr = newNode.(Expr) + }) + a.apply(node, n.As, func(newNode, parent SQLNode) { + parent.(*AliasedExpr).As = newNode.(ColIdent) + }) case *AliasedTableExpr: - a.apply(node, n.As, replaceAliasedTableExprAs) - a.apply(node, n.Expr, replaceAliasedTableExprExpr) - a.apply(node, n.Hints, replaceAliasedTableExprHints) - a.apply(node, n.Partitions, replaceAliasedTableExprPartitions) - + a.apply(node, n.Expr, func(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).Expr = newNode.(SimpleTableExpr) + }) + a.apply(node, n.Partitions, func(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).Partitions = newNode.(Partitions) + }) + a.apply(node, n.As, func(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).As = newNode.(TableIdent) + }) + a.apply(node, n.Hints, func(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).Hints = newNode.(*IndexHints) + }) case *AlterCharset: - case *AlterColumn: - a.apply(node, n.Column, replaceAlterColumnColumn) - a.apply(node, n.DefaultVal, replaceAlterColumnDefaultVal) - + a.apply(node, n.Column, func(newNode, parent SQLNode) { + parent.(*AlterColumn).Column = newNode.(*ColName) + }) + a.apply(node, n.DefaultVal, func(newNode, parent SQLNode) { + parent.(*AlterColumn).DefaultVal = newNode.(Expr) + }) case *AlterDatabase: - case *AlterTable: - replacerAlterOptions := replaceAlterTableAlterOptions(0) - replacerAlterOptionsB := &replacerAlterOptions - for _, item := range n.AlterOptions { - a.apply(node, item, replacerAlterOptionsB.replace) - replacerAlterOptionsB.inc() + a.apply(node, n.Table, func(newNode, parent SQLNode) { + parent.(*AlterTable).Table = newNode.(TableName) + }) + for x, el := range n.AlterOptions { + a.apply(node, el, func(newNode, container SQLNode) { + container.(*AlterTable).AlterOptions[x] = newNode.(AlterOption) + }) } - a.apply(node, n.PartitionSpec, replaceAlterTablePartitionSpec) - a.apply(node, n.Table, replaceAlterTableTable) - + a.apply(node, n.PartitionSpec, func(newNode, parent SQLNode) { + parent.(*AlterTable).PartitionSpec = newNode.(*PartitionSpec) + }) case *AlterView: - a.apply(node, n.Columns, replaceAlterViewColumns) - a.apply(node, n.Select, replaceAlterViewSelect) - a.apply(node, n.ViewName, replaceAlterViewViewName) - + a.apply(node, n.ViewName, func(newNode, parent SQLNode) { + parent.(*AlterView).ViewName = newNode.(TableName) + }) + a.apply(node, n.Columns, func(newNode, parent SQLNode) { + parent.(*AlterView).Columns = newNode.(Columns) + }) + a.apply(node, n.Select, func(newNode, parent SQLNode) { + parent.(*AlterView).Select = newNode.(SelectStatement) + }) case *AlterVschema: - a.apply(node, n.AutoIncSpec, replaceAlterVschemaAutoIncSpec) - a.apply(node, n.Table, replaceAlterVschemaTable) - replacerVindexCols := replaceAlterVschemaVindexCols(0) - replacerVindexColsB := &replacerVindexCols - for _, item := range n.VindexCols { - a.apply(node, item, replacerVindexColsB.replace) - replacerVindexColsB.inc() + a.apply(node, n.Table, func(newNode, parent SQLNode) { + parent.(*AlterVschema).Table = newNode.(TableName) + }) + a.apply(node, n.VindexSpec, func(newNode, parent SQLNode) { + parent.(*AlterVschema).VindexSpec = newNode.(*VindexSpec) + }) + for x, el := range n.VindexCols { + a.apply(node, el, func(newNode, container SQLNode) { + container.(*AlterVschema).VindexCols[x] = newNode.(ColIdent) + }) } - a.apply(node, n.VindexSpec, replaceAlterVschemaVindexSpec) - + a.apply(node, n.AutoIncSpec, func(newNode, parent SQLNode) { + parent.(*AlterVschema).AutoIncSpec = newNode.(*AutoIncSpec) + }) case *AndExpr: - a.apply(node, n.Left, replaceAndExprLeft) - a.apply(node, n.Right, replaceAndExprRight) - + a.apply(node, n.Left, func(newNode, parent SQLNode) { + parent.(*AndExpr).Left = newNode.(Expr) + }) + a.apply(node, n.Right, func(newNode, parent SQLNode) { + parent.(*AndExpr).Right = newNode.(Expr) + }) case Argument: - case *AutoIncSpec: - a.apply(node, n.Column, replaceAutoIncSpecColumn) - a.apply(node, n.Sequence, replaceAutoIncSpecSequence) - + a.apply(node, n.Column, func(newNode, parent SQLNode) { + parent.(*AutoIncSpec).Column = newNode.(ColIdent) + }) + a.apply(node, n.Sequence, func(newNode, parent SQLNode) { + parent.(*AutoIncSpec).Sequence = newNode.(TableName) + }) case *Begin: - case *BinaryExpr: - a.apply(node, n.Left, replaceBinaryExprLeft) - a.apply(node, n.Right, replaceBinaryExprRight) - - case BoolVal: - + a.apply(node, n.Left, func(newNode, parent SQLNode) { + parent.(*BinaryExpr).Left = newNode.(Expr) + }) + a.apply(node, n.Right, func(newNode, parent SQLNode) { + parent.(*BinaryExpr).Right = newNode.(Expr) + }) case *CallProc: - a.apply(node, n.Name, replaceCallProcName) - a.apply(node, n.Params, replaceCallProcParams) - + a.apply(node, n.Name, func(newNode, parent SQLNode) { + parent.(*CallProc).Name = newNode.(TableName) + }) + a.apply(node, n.Params, func(newNode, parent SQLNode) { + parent.(*CallProc).Params = newNode.(Exprs) + }) case *CaseExpr: - a.apply(node, n.Else, replaceCaseExprElse) - a.apply(node, n.Expr, replaceCaseExprExpr) - replacerWhens := replaceCaseExprWhens(0) - replacerWhensB := &replacerWhens - for _, item := range n.Whens { - a.apply(node, item, replacerWhensB.replace) - replacerWhensB.inc() + a.apply(node, n.Expr, func(newNode, parent SQLNode) { + parent.(*CaseExpr).Expr = newNode.(Expr) + }) + for x, el := range n.Whens { + a.apply(node, el, func(newNode, container SQLNode) { + container.(*CaseExpr).Whens[x] = newNode.(*When) + }) } - + a.apply(node, n.Else, func(newNode, parent SQLNode) { + parent.(*CaseExpr).Else = newNode.(Expr) + }) case *ChangeColumn: - a.apply(node, n.After, replaceChangeColumnAfter) - a.apply(node, n.First, replaceChangeColumnFirst) - a.apply(node, n.NewColDefinition, replaceChangeColumnNewColDefinition) - a.apply(node, n.OldColumn, replaceChangeColumnOldColumn) - + a.apply(node, n.OldColumn, func(newNode, parent SQLNode) { + parent.(*ChangeColumn).OldColumn = newNode.(*ColName) + }) + a.apply(node, n.NewColDefinition, func(newNode, parent SQLNode) { + parent.(*ChangeColumn).NewColDefinition = newNode.(*ColumnDefinition) + }) + a.apply(node, n.First, func(newNode, parent SQLNode) { + parent.(*ChangeColumn).First = newNode.(*ColName) + }) + a.apply(node, n.After, func(newNode, parent SQLNode) { + parent.(*ChangeColumn).After = newNode.(*ColName) + }) case *CheckConstraintDefinition: - a.apply(node, n.Expr, replaceCheckConstraintDefinitionExpr) - + a.apply(node, n.Expr, func(newNode, parent SQLNode) { + parent.(*CheckConstraintDefinition).Expr = newNode.(Expr) + }) case ColIdent: - case *ColName: - a.apply(node, n.Name, replaceColNameName) - a.apply(node, n.Qualifier, replaceColNameQualifier) - + a.apply(node, n.Name, func(newNode, parent SQLNode) { + parent.(*ColName).Name = newNode.(ColIdent) + }) + a.apply(node, n.Qualifier, func(newNode, parent SQLNode) { + parent.(*ColName).Qualifier = newNode.(TableName) + }) case *CollateExpr: - a.apply(node, n.Expr, replaceCollateExprExpr) - + a.apply(node, n.Expr, func(newNode, parent SQLNode) { + parent.(*CollateExpr).Expr = newNode.(Expr) + }) case *ColumnDefinition: - a.apply(node, n.Name, replaceColumnDefinitionName) - + a.apply(node, n.Name, func(newNode, parent SQLNode) { + parent.(*ColumnDefinition).Name = newNode.(ColIdent) + }) case *ColumnType: - a.apply(node, n.Length, replaceColumnTypeLength) - a.apply(node, n.Scale, replaceColumnTypeScale) - + a.apply(node, n.Length, func(newNode, parent SQLNode) { + parent.(*ColumnType).Length = newNode.(*Literal) + }) + a.apply(node, n.Scale, func(newNode, parent SQLNode) { + parent.(*ColumnType).Scale = newNode.(*Literal) + }) case Columns: - replacer := replaceColumnsItems(0) - replacerRef := &replacer - for _, item := range n { - a.apply(node, item, replacerRef.replace) - replacerRef.inc() + for x, el := range n { + a.apply(node, el, func(newNode, container SQLNode) { + container.(Columns)[x] = newNode.(ColIdent) + }) } - case Comments: - case *Commit: - case *ComparisonExpr: - a.apply(node, n.Escape, replaceComparisonExprEscape) - a.apply(node, n.Left, replaceComparisonExprLeft) - a.apply(node, n.Right, replaceComparisonExprRight) - + a.apply(node, n.Left, func(newNode, parent SQLNode) { + parent.(*ComparisonExpr).Left = newNode.(Expr) + }) + a.apply(node, n.Right, func(newNode, parent SQLNode) { + parent.(*ComparisonExpr).Right = newNode.(Expr) + }) + a.apply(node, n.Escape, func(newNode, parent SQLNode) { + parent.(*ComparisonExpr).Escape = newNode.(Expr) + }) case *ConstraintDefinition: - a.apply(node, n.Details, replaceConstraintDefinitionDetails) - + a.apply(node, n.Details, func(newNode, parent SQLNode) { + parent.(*ConstraintDefinition).Details = newNode.(ConstraintInfo) + }) case *ConvertExpr: - a.apply(node, n.Expr, replaceConvertExprExpr) - a.apply(node, n.Type, replaceConvertExprType) - + a.apply(node, n.Expr, func(newNode, parent SQLNode) { + parent.(*ConvertExpr).Expr = newNode.(Expr) + }) + a.apply(node, n.Type, func(newNode, parent SQLNode) { + parent.(*ConvertExpr).Type = newNode.(*ConvertType) + }) case *ConvertType: - a.apply(node, n.Length, replaceConvertTypeLength) - a.apply(node, n.Scale, replaceConvertTypeScale) - + a.apply(node, n.Length, func(newNode, parent SQLNode) { + parent.(*ConvertType).Length = newNode.(*Literal) + }) + a.apply(node, n.Scale, func(newNode, parent SQLNode) { + parent.(*ConvertType).Scale = newNode.(*Literal) + }) case *ConvertUsingExpr: - a.apply(node, n.Expr, replaceConvertUsingExprExpr) - + a.apply(node, n.Expr, func(newNode, parent SQLNode) { + parent.(*ConvertUsingExpr).Expr = newNode.(Expr) + }) case *CreateDatabase: - case *CreateTable: - a.apply(node, n.OptLike, replaceCreateTableOptLike) - a.apply(node, n.Table, replaceCreateTableTable) - a.apply(node, n.TableSpec, replaceCreateTableTableSpec) - + a.apply(node, n.Table, func(newNode, parent SQLNode) { + parent.(*CreateTable).Table = newNode.(TableName) + }) + a.apply(node, n.TableSpec, func(newNode, parent SQLNode) { + parent.(*CreateTable).TableSpec = newNode.(*TableSpec) + }) + a.apply(node, n.OptLike, func(newNode, parent SQLNode) { + parent.(*CreateTable).OptLike = newNode.(*OptLike) + }) case *CreateView: - a.apply(node, n.Columns, replaceCreateViewColumns) - a.apply(node, n.Select, replaceCreateViewSelect) - a.apply(node, n.ViewName, replaceCreateViewViewName) - + a.apply(node, n.ViewName, func(newNode, parent SQLNode) { + parent.(*CreateView).ViewName = newNode.(TableName) + }) + a.apply(node, n.Columns, func(newNode, parent SQLNode) { + parent.(*CreateView).Columns = newNode.(Columns) + }) + a.apply(node, n.Select, func(newNode, parent SQLNode) { + parent.(*CreateView).Select = newNode.(SelectStatement) + }) case *CurTimeFuncExpr: - a.apply(node, n.Fsp, replaceCurTimeFuncExprFsp) - a.apply(node, n.Name, replaceCurTimeFuncExprName) - + a.apply(node, n.Name, func(newNode, parent SQLNode) { + parent.(*CurTimeFuncExpr).Name = newNode.(ColIdent) + }) + a.apply(node, n.Fsp, func(newNode, parent SQLNode) { + parent.(*CurTimeFuncExpr).Fsp = newNode.(Expr) + }) case *Default: - case *Delete: - a.apply(node, n.Comments, replaceDeleteComments) - a.apply(node, n.Limit, replaceDeleteLimit) - a.apply(node, n.OrderBy, replaceDeleteOrderBy) - a.apply(node, n.Partitions, replaceDeletePartitions) - a.apply(node, n.TableExprs, replaceDeleteTableExprs) - a.apply(node, n.Targets, replaceDeleteTargets) - a.apply(node, n.Where, replaceDeleteWhere) - + a.apply(node, n.Comments, func(newNode, parent SQLNode) { + parent.(*Delete).Comments = newNode.(Comments) + }) + a.apply(node, n.Targets, func(newNode, parent SQLNode) { + parent.(*Delete).Targets = newNode.(TableNames) + }) + a.apply(node, n.TableExprs, func(newNode, parent SQLNode) { + parent.(*Delete).TableExprs = newNode.(TableExprs) + }) + a.apply(node, n.Partitions, func(newNode, parent SQLNode) { + parent.(*Delete).Partitions = newNode.(Partitions) + }) + a.apply(node, n.Where, func(newNode, parent SQLNode) { + parent.(*Delete).Where = newNode.(*Where) + }) + a.apply(node, n.OrderBy, func(newNode, parent SQLNode) { + parent.(*Delete).OrderBy = newNode.(OrderBy) + }) + a.apply(node, n.Limit, func(newNode, parent SQLNode) { + parent.(*Delete).Limit = newNode.(*Limit) + }) case *DerivedTable: - a.apply(node, n.Select, replaceDerivedTableSelect) - + a.apply(node, n.Select, func(newNode, parent SQLNode) { + parent.(*DerivedTable).Select = newNode.(SelectStatement) + }) case *DropColumn: - a.apply(node, n.Name, replaceDropColumnName) - + a.apply(node, n.Name, func(newNode, parent SQLNode) { + parent.(*DropColumn).Name = newNode.(*ColName) + }) case *DropDatabase: - case *DropKey: - case *DropTable: - a.apply(node, n.FromTables, replaceDropTableFromTables) - + a.apply(node, n.FromTables, func(newNode, parent SQLNode) { + parent.(*DropTable).FromTables = newNode.(TableNames) + }) case *DropView: - a.apply(node, n.FromTables, replaceDropViewFromTables) - + a.apply(node, n.FromTables, func(newNode, parent SQLNode) { + parent.(*DropView).FromTables = newNode.(TableNames) + }) case *ExistsExpr: - a.apply(node, n.Subquery, replaceExistsExprSubquery) - + a.apply(node, n.Subquery, func(newNode, parent SQLNode) { + parent.(*ExistsExpr).Subquery = newNode.(*Subquery) + }) case *ExplainStmt: - a.apply(node, n.Statement, replaceExplainStmtStatement) - + a.apply(node, n.Statement, func(newNode, parent SQLNode) { + parent.(*ExplainStmt).Statement = newNode.(Statement) + }) case *ExplainTab: - a.apply(node, n.Table, replaceExplainTabTable) - + a.apply(node, n.Table, func(newNode, parent SQLNode) { + parent.(*ExplainTab).Table = newNode.(TableName) + }) case Exprs: - replacer := replaceExprsItems(0) - replacerRef := &replacer - for _, item := range n { - a.apply(node, item, replacerRef.replace) - replacerRef.inc() + for x, el := range n { + a.apply(node, el, func(newNode, container SQLNode) { + container.(Exprs)[x] = newNode.(Expr) + }) } - case *Flush: - a.apply(node, n.TableNames, replaceFlushTableNames) - + a.apply(node, n.TableNames, func(newNode, parent SQLNode) { + parent.(*Flush).TableNames = newNode.(TableNames) + }) case *Force: - case *ForeignKeyDefinition: - a.apply(node, n.OnDelete, replaceForeignKeyDefinitionOnDelete) - a.apply(node, n.OnUpdate, replaceForeignKeyDefinitionOnUpdate) - a.apply(node, n.ReferencedColumns, replaceForeignKeyDefinitionReferencedColumns) - a.apply(node, n.ReferencedTable, replaceForeignKeyDefinitionReferencedTable) - a.apply(node, n.Source, replaceForeignKeyDefinitionSource) - + a.apply(node, n.Source, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).Source = newNode.(Columns) + }) + a.apply(node, n.ReferencedTable, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).ReferencedTable = newNode.(TableName) + }) + a.apply(node, n.ReferencedColumns, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).ReferencedColumns = newNode.(Columns) + }) + a.apply(node, n.OnDelete, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).OnDelete = newNode.(ReferenceAction) + }) + a.apply(node, n.OnUpdate, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).OnUpdate = newNode.(ReferenceAction) + }) case *FuncExpr: - a.apply(node, n.Exprs, replaceFuncExprExprs) - a.apply(node, n.Name, replaceFuncExprName) - a.apply(node, n.Qualifier, replaceFuncExprQualifier) - + a.apply(node, n.Qualifier, func(newNode, parent SQLNode) { + parent.(*FuncExpr).Qualifier = newNode.(TableIdent) + }) + a.apply(node, n.Name, func(newNode, parent SQLNode) { + parent.(*FuncExpr).Name = newNode.(ColIdent) + }) + a.apply(node, n.Exprs, func(newNode, parent SQLNode) { + parent.(*FuncExpr).Exprs = newNode.(SelectExprs) + }) case GroupBy: - replacer := replaceGroupByItems(0) - replacerRef := &replacer - for _, item := range n { - a.apply(node, item, replacerRef.replace) - replacerRef.inc() + for x, el := range n { + a.apply(node, el, func(newNode, container SQLNode) { + container.(GroupBy)[x] = newNode.(Expr) + }) } - case *GroupConcatExpr: - a.apply(node, n.Exprs, replaceGroupConcatExprExprs) - a.apply(node, n.Limit, replaceGroupConcatExprLimit) - a.apply(node, n.OrderBy, replaceGroupConcatExprOrderBy) - + a.apply(node, n.Exprs, func(newNode, parent SQLNode) { + parent.(*GroupConcatExpr).Exprs = newNode.(SelectExprs) + }) + a.apply(node, n.OrderBy, func(newNode, parent SQLNode) { + parent.(*GroupConcatExpr).OrderBy = newNode.(OrderBy) + }) + a.apply(node, n.Limit, func(newNode, parent SQLNode) { + parent.(*GroupConcatExpr).Limit = newNode.(*Limit) + }) case *IndexDefinition: - a.apply(node, n.Info, replaceIndexDefinitionInfo) - + a.apply(node, n.Info, func(newNode, parent SQLNode) { + parent.(*IndexDefinition).Info = newNode.(*IndexInfo) + }) case *IndexHints: - replacerIndexes := replaceIndexHintsIndexes(0) - replacerIndexesB := &replacerIndexes - for _, item := range n.Indexes { - a.apply(node, item, replacerIndexesB.replace) - replacerIndexesB.inc() + for x, el := range n.Indexes { + a.apply(node, el, func(newNode, container SQLNode) { + container.(*IndexHints).Indexes[x] = newNode.(ColIdent) + }) } - case *IndexInfo: - a.apply(node, n.ConstraintName, replaceIndexInfoConstraintName) - a.apply(node, n.Name, replaceIndexInfoName) - + a.apply(node, n.Name, func(newNode, parent SQLNode) { + parent.(*IndexInfo).Name = newNode.(ColIdent) + }) + a.apply(node, n.ConstraintName, func(newNode, parent SQLNode) { + parent.(*IndexInfo).ConstraintName = newNode.(ColIdent) + }) case *Insert: - a.apply(node, n.Columns, replaceInsertColumns) - a.apply(node, n.Comments, replaceInsertComments) - a.apply(node, n.OnDup, replaceInsertOnDup) - a.apply(node, n.Partitions, replaceInsertPartitions) - a.apply(node, n.Rows, replaceInsertRows) - a.apply(node, n.Table, replaceInsertTable) - + a.apply(node, n.Comments, func(newNode, parent SQLNode) { + parent.(*Insert).Comments = newNode.(Comments) + }) + a.apply(node, n.Table, func(newNode, parent SQLNode) { + parent.(*Insert).Table = newNode.(TableName) + }) + a.apply(node, n.Partitions, func(newNode, parent SQLNode) { + parent.(*Insert).Partitions = newNode.(Partitions) + }) + a.apply(node, n.Columns, func(newNode, parent SQLNode) { + parent.(*Insert).Columns = newNode.(Columns) + }) + a.apply(node, n.Rows, func(newNode, parent SQLNode) { + parent.(*Insert).Rows = newNode.(InsertRows) + }) + a.apply(node, n.OnDup, func(newNode, parent SQLNode) { + parent.(*Insert).OnDup = newNode.(OnDup) + }) case *IntervalExpr: - a.apply(node, n.Expr, replaceIntervalExprExpr) - + a.apply(node, n.Expr, func(newNode, parent SQLNode) { + parent.(*IntervalExpr).Expr = newNode.(Expr) + }) case *IsExpr: - a.apply(node, n.Expr, replaceIsExprExpr) - - case IsolationLevel: - + a.apply(node, n.Expr, func(newNode, parent SQLNode) { + parent.(*IsExpr).Expr = newNode.(Expr) + }) case JoinCondition: - a.apply(node, n.On, replaceJoinConditionOn) - a.apply(node, n.Using, replaceJoinConditionUsing) - + a.apply(node, n.On, replacePanic("JoinCondition On")) + a.apply(node, n.Using, replacePanic("JoinCondition Using")) case *JoinTableExpr: - a.apply(node, n.Condition, replaceJoinTableExprCondition) - a.apply(node, n.LeftExpr, replaceJoinTableExprLeftExpr) - a.apply(node, n.RightExpr, replaceJoinTableExprRightExpr) - + a.apply(node, n.LeftExpr, func(newNode, parent SQLNode) { + parent.(*JoinTableExpr).LeftExpr = newNode.(TableExpr) + }) + a.apply(node, n.RightExpr, func(newNode, parent SQLNode) { + parent.(*JoinTableExpr).RightExpr = newNode.(TableExpr) + }) + a.apply(node, n.Condition, func(newNode, parent SQLNode) { + parent.(*JoinTableExpr).Condition = newNode.(JoinCondition) + }) case *KeyState: - case *Limit: - a.apply(node, n.Offset, replaceLimitOffset) - a.apply(node, n.Rowcount, replaceLimitRowcount) - + a.apply(node, n.Offset, func(newNode, parent SQLNode) { + parent.(*Limit).Offset = newNode.(Expr) + }) + a.apply(node, n.Rowcount, func(newNode, parent SQLNode) { + parent.(*Limit).Rowcount = newNode.(Expr) + }) case ListArg: - case *Literal: - case *Load: - case *LockOption: - case *LockTables: - case *MatchExpr: - a.apply(node, n.Columns, replaceMatchExprColumns) - a.apply(node, n.Expr, replaceMatchExprExpr) - + a.apply(node, n.Columns, func(newNode, parent SQLNode) { + parent.(*MatchExpr).Columns = newNode.(SelectExprs) + }) + a.apply(node, n.Expr, func(newNode, parent SQLNode) { + parent.(*MatchExpr).Expr = newNode.(Expr) + }) case *ModifyColumn: - a.apply(node, n.After, replaceModifyColumnAfter) - a.apply(node, n.First, replaceModifyColumnFirst) - a.apply(node, n.NewColDefinition, replaceModifyColumnNewColDefinition) - - case Nextval: - a.apply(node, n.Expr, replaceNextvalExpr) - + a.apply(node, n.NewColDefinition, func(newNode, parent SQLNode) { + parent.(*ModifyColumn).NewColDefinition = newNode.(*ColumnDefinition) + }) + a.apply(node, n.First, func(newNode, parent SQLNode) { + parent.(*ModifyColumn).First = newNode.(*ColName) + }) + a.apply(node, n.After, func(newNode, parent SQLNode) { + parent.(*ModifyColumn).After = newNode.(*ColName) + }) + case *Nextval: + a.apply(node, n.Expr, func(newNode, parent SQLNode) { + parent.(*Nextval).Expr = newNode.(Expr) + }) case *NotExpr: - a.apply(node, n.Expr, replaceNotExprExpr) - + a.apply(node, n.Expr, func(newNode, parent SQLNode) { + parent.(*NotExpr).Expr = newNode.(Expr) + }) case *NullVal: - case OnDup: - replacer := replaceOnDupItems(0) - replacerRef := &replacer - for _, item := range n { - a.apply(node, item, replacerRef.replace) - replacerRef.inc() + for x, el := range n { + a.apply(node, el, func(newNode, container SQLNode) { + container.(OnDup)[x] = newNode.(*UpdateExpr) + }) } - case *OptLike: - a.apply(node, n.LikeTable, replaceOptLikeLikeTable) - + a.apply(node, n.LikeTable, func(newNode, parent SQLNode) { + parent.(*OptLike).LikeTable = newNode.(TableName) + }) case *OrExpr: - a.apply(node, n.Left, replaceOrExprLeft) - a.apply(node, n.Right, replaceOrExprRight) - + a.apply(node, n.Left, func(newNode, parent SQLNode) { + parent.(*OrExpr).Left = newNode.(Expr) + }) + a.apply(node, n.Right, func(newNode, parent SQLNode) { + parent.(*OrExpr).Right = newNode.(Expr) + }) case *Order: - a.apply(node, n.Expr, replaceOrderExpr) - + a.apply(node, n.Expr, func(newNode, parent SQLNode) { + parent.(*Order).Expr = newNode.(Expr) + }) case OrderBy: - replacer := replaceOrderByItems(0) - replacerRef := &replacer - for _, item := range n { - a.apply(node, item, replacerRef.replace) - replacerRef.inc() + for x, el := range n { + a.apply(node, el, func(newNode, container SQLNode) { + container.(OrderBy)[x] = newNode.(*Order) + }) } - case *OrderByOption: - a.apply(node, n.Cols, replaceOrderByOptionCols) - + a.apply(node, n.Cols, func(newNode, parent SQLNode) { + parent.(*OrderByOption).Cols = newNode.(Columns) + }) case *OtherAdmin: - case *OtherRead: - case *ParenSelect: - a.apply(node, n.Select, replaceParenSelectSelect) - + a.apply(node, n.Select, func(newNode, parent SQLNode) { + parent.(*ParenSelect).Select = newNode.(SelectStatement) + }) case *ParenTableExpr: - a.apply(node, n.Exprs, replaceParenTableExprExprs) - + a.apply(node, n.Exprs, func(newNode, parent SQLNode) { + parent.(*ParenTableExpr).Exprs = newNode.(TableExprs) + }) case *PartitionDefinition: - a.apply(node, n.Limit, replacePartitionDefinitionLimit) - a.apply(node, n.Name, replacePartitionDefinitionName) - + a.apply(node, n.Name, func(newNode, parent SQLNode) { + parent.(*PartitionDefinition).Name = newNode.(ColIdent) + }) + a.apply(node, n.Limit, func(newNode, parent SQLNode) { + parent.(*PartitionDefinition).Limit = newNode.(Expr) + }) case *PartitionSpec: - replacerDefinitions := replacePartitionSpecDefinitions(0) - replacerDefinitionsB := &replacerDefinitions - for _, item := range n.Definitions { - a.apply(node, item, replacerDefinitionsB.replace) - replacerDefinitionsB.inc() + a.apply(node, n.Names, func(newNode, parent SQLNode) { + parent.(*PartitionSpec).Names = newNode.(Partitions) + }) + a.apply(node, n.Number, func(newNode, parent SQLNode) { + parent.(*PartitionSpec).Number = newNode.(*Literal) + }) + a.apply(node, n.TableName, func(newNode, parent SQLNode) { + parent.(*PartitionSpec).TableName = newNode.(TableName) + }) + for x, el := range n.Definitions { + a.apply(node, el, func(newNode, container SQLNode) { + container.(*PartitionSpec).Definitions[x] = newNode.(*PartitionDefinition) + }) } - a.apply(node, n.Names, replacePartitionSpecNames) - a.apply(node, n.Number, replacePartitionSpecNumber) - a.apply(node, n.TableName, replacePartitionSpecTableName) - case Partitions: - replacer := replacePartitionsItems(0) - replacerRef := &replacer - for _, item := range n { - a.apply(node, item, replacerRef.replace) - replacerRef.inc() + for x, el := range n { + a.apply(node, el, func(newNode, container SQLNode) { + container.(Partitions)[x] = newNode.(ColIdent) + }) } - case *RangeCond: - a.apply(node, n.From, replaceRangeCondFrom) - a.apply(node, n.Left, replaceRangeCondLeft) - a.apply(node, n.To, replaceRangeCondTo) - - case ReferenceAction: - + a.apply(node, n.Left, func(newNode, parent SQLNode) { + parent.(*RangeCond).Left = newNode.(Expr) + }) + a.apply(node, n.From, func(newNode, parent SQLNode) { + parent.(*RangeCond).From = newNode.(Expr) + }) + a.apply(node, n.To, func(newNode, parent SQLNode) { + parent.(*RangeCond).To = newNode.(Expr) + }) case *Release: - a.apply(node, n.Name, replaceReleaseName) - + a.apply(node, n.Name, func(newNode, parent SQLNode) { + parent.(*Release).Name = newNode.(ColIdent) + }) case *RenameIndex: - case *RenameTable: - case *RenameTableName: - a.apply(node, n.Table, replaceRenameTableNameTable) - + a.apply(node, n.Table, func(newNode, parent SQLNode) { + parent.(*RenameTableName).Table = newNode.(TableName) + }) case *Rollback: - case *SRollback: - a.apply(node, n.Name, replaceSRollbackName) - + a.apply(node, n.Name, func(newNode, parent SQLNode) { + parent.(*SRollback).Name = newNode.(ColIdent) + }) case *Savepoint: - a.apply(node, n.Name, replaceSavepointName) - + a.apply(node, n.Name, func(newNode, parent SQLNode) { + parent.(*Savepoint).Name = newNode.(ColIdent) + }) case *Select: - a.apply(node, n.Comments, replaceSelectComments) - a.apply(node, n.From, replaceSelectFrom) - a.apply(node, n.GroupBy, replaceSelectGroupBy) - a.apply(node, n.Having, replaceSelectHaving) - a.apply(node, n.Into, replaceSelectInto) - a.apply(node, n.Limit, replaceSelectLimit) - a.apply(node, n.OrderBy, replaceSelectOrderBy) - a.apply(node, n.SelectExprs, replaceSelectSelectExprs) - a.apply(node, n.Where, replaceSelectWhere) - + a.apply(node, n.Comments, func(newNode, parent SQLNode) { + parent.(*Select).Comments = newNode.(Comments) + }) + a.apply(node, n.SelectExprs, func(newNode, parent SQLNode) { + parent.(*Select).SelectExprs = newNode.(SelectExprs) + }) + a.apply(node, n.From, func(newNode, parent SQLNode) { + parent.(*Select).From = newNode.(TableExprs) + }) + a.apply(node, n.Where, func(newNode, parent SQLNode) { + parent.(*Select).Where = newNode.(*Where) + }) + a.apply(node, n.GroupBy, func(newNode, parent SQLNode) { + parent.(*Select).GroupBy = newNode.(GroupBy) + }) + a.apply(node, n.Having, func(newNode, parent SQLNode) { + parent.(*Select).Having = newNode.(*Where) + }) + a.apply(node, n.OrderBy, func(newNode, parent SQLNode) { + parent.(*Select).OrderBy = newNode.(OrderBy) + }) + a.apply(node, n.Limit, func(newNode, parent SQLNode) { + parent.(*Select).Limit = newNode.(*Limit) + }) + a.apply(node, n.Into, func(newNode, parent SQLNode) { + parent.(*Select).Into = newNode.(*SelectInto) + }) case SelectExprs: - replacer := replaceSelectExprsItems(0) - replacerRef := &replacer - for _, item := range n { - a.apply(node, item, replacerRef.replace) - replacerRef.inc() + for x, el := range n { + a.apply(node, el, func(newNode, container SQLNode) { + container.(SelectExprs)[x] = newNode.(SelectExpr) + }) } - case *SelectInto: - case *Set: - a.apply(node, n.Comments, replaceSetComments) - a.apply(node, n.Exprs, replaceSetExprs) - + a.apply(node, n.Comments, func(newNode, parent SQLNode) { + parent.(*Set).Comments = newNode.(Comments) + }) + a.apply(node, n.Exprs, func(newNode, parent SQLNode) { + parent.(*Set).Exprs = newNode.(SetExprs) + }) case *SetExpr: - a.apply(node, n.Expr, replaceSetExprExpr) - a.apply(node, n.Name, replaceSetExprName) - + a.apply(node, n.Name, func(newNode, parent SQLNode) { + parent.(*SetExpr).Name = newNode.(ColIdent) + }) + a.apply(node, n.Expr, func(newNode, parent SQLNode) { + parent.(*SetExpr).Expr = newNode.(Expr) + }) case SetExprs: - replacer := replaceSetExprsItems(0) - replacerRef := &replacer - for _, item := range n { - a.apply(node, item, replacerRef.replace) - replacerRef.inc() + for x, el := range n { + a.apply(node, el, func(newNode, container SQLNode) { + container.(SetExprs)[x] = newNode.(*SetExpr) + }) } - case *SetTransaction: - replacerCharacteristics := replaceSetTransactionCharacteristics(0) - replacerCharacteristicsB := &replacerCharacteristics - for _, item := range n.Characteristics { - a.apply(node, item, replacerCharacteristicsB.replace) - replacerCharacteristicsB.inc() + a.apply(node, n.SQLNode, func(newNode, parent SQLNode) { + parent.(*SetTransaction).SQLNode = newNode.(SQLNode) + }) + a.apply(node, n.Comments, func(newNode, parent SQLNode) { + parent.(*SetTransaction).Comments = newNode.(Comments) + }) + for x, el := range n.Characteristics { + a.apply(node, el, func(newNode, container SQLNode) { + container.(*SetTransaction).Characteristics[x] = newNode.(Characteristic) + }) } - a.apply(node, n.Comments, replaceSetTransactionComments) - case *Show: - a.apply(node, n.Internal, replaceShowInternal) - + a.apply(node, n.Internal, func(newNode, parent SQLNode) { + parent.(*Show).Internal = newNode.(ShowInternal) + }) case *ShowBasic: - a.apply(node, n.Filter, replaceShowBasicFilter) - a.apply(node, n.Tbl, replaceShowBasicTbl) - + a.apply(node, n.Tbl, func(newNode, parent SQLNode) { + parent.(*ShowBasic).Tbl = newNode.(TableName) + }) + a.apply(node, n.Filter, func(newNode, parent SQLNode) { + parent.(*ShowBasic).Filter = newNode.(*ShowFilter) + }) case *ShowCreate: - a.apply(node, n.Op, replaceShowCreateOp) - + a.apply(node, n.Op, func(newNode, parent SQLNode) { + parent.(*ShowCreate).Op = newNode.(TableName) + }) case *ShowFilter: - a.apply(node, n.Filter, replaceShowFilterFilter) - + a.apply(node, n.Filter, func(newNode, parent SQLNode) { + parent.(*ShowFilter).Filter = newNode.(Expr) + }) case *ShowLegacy: - a.apply(node, n.OnTable, replaceShowLegacyOnTable) - a.apply(node, n.ShowCollationFilterOpt, replaceShowLegacyShowCollationFilterOpt) - a.apply(node, n.Table, replaceShowLegacyTable) - + a.apply(node, n.OnTable, func(newNode, parent SQLNode) { + parent.(*ShowLegacy).OnTable = newNode.(TableName) + }) + a.apply(node, n.Table, func(newNode, parent SQLNode) { + parent.(*ShowLegacy).Table = newNode.(TableName) + }) + a.apply(node, n.ShowCollationFilterOpt, func(newNode, parent SQLNode) { + parent.(*ShowLegacy).ShowCollationFilterOpt = newNode.(Expr) + }) case *StarExpr: - a.apply(node, n.TableName, replaceStarExprTableName) - + a.apply(node, n.TableName, func(newNode, parent SQLNode) { + parent.(*StarExpr).TableName = newNode.(TableName) + }) case *Stream: - a.apply(node, n.Comments, replaceStreamComments) - a.apply(node, n.SelectExpr, replaceStreamSelectExpr) - a.apply(node, n.Table, replaceStreamTable) - + a.apply(node, n.Comments, func(newNode, parent SQLNode) { + parent.(*Stream).Comments = newNode.(Comments) + }) + a.apply(node, n.SelectExpr, func(newNode, parent SQLNode) { + parent.(*Stream).SelectExpr = newNode.(SelectExpr) + }) + a.apply(node, n.Table, func(newNode, parent SQLNode) { + parent.(*Stream).Table = newNode.(TableName) + }) case *Subquery: - a.apply(node, n.Select, replaceSubquerySelect) - + a.apply(node, n.Select, func(newNode, parent SQLNode) { + parent.(*Subquery).Select = newNode.(SelectStatement) + }) case *SubstrExpr: - a.apply(node, n.From, replaceSubstrExprFrom) - a.apply(node, n.Name, replaceSubstrExprName) - a.apply(node, n.StrVal, replaceSubstrExprStrVal) - a.apply(node, n.To, replaceSubstrExprTo) - + a.apply(node, n.Name, func(newNode, parent SQLNode) { + parent.(*SubstrExpr).Name = newNode.(*ColName) + }) + a.apply(node, n.StrVal, func(newNode, parent SQLNode) { + parent.(*SubstrExpr).StrVal = newNode.(*Literal) + }) + a.apply(node, n.From, func(newNode, parent SQLNode) { + parent.(*SubstrExpr).From = newNode.(Expr) + }) + a.apply(node, n.To, func(newNode, parent SQLNode) { + parent.(*SubstrExpr).To = newNode.(Expr) + }) case TableExprs: - replacer := replaceTableExprsItems(0) - replacerRef := &replacer - for _, item := range n { - a.apply(node, item, replacerRef.replace) - replacerRef.inc() + for x, el := range n { + a.apply(node, el, func(newNode, container SQLNode) { + container.(TableExprs)[x] = newNode.(TableExpr) + }) } - case TableIdent: - case TableName: - a.apply(node, n.Name, replaceTableNameName) - a.apply(node, n.Qualifier, replaceTableNameQualifier) - + a.apply(node, n.Name, replacePanic("TableName Name")) + a.apply(node, n.Qualifier, replacePanic("TableName Qualifier")) case TableNames: - replacer := replaceTableNamesItems(0) - replacerRef := &replacer - for _, item := range n { - a.apply(node, item, replacerRef.replace) - replacerRef.inc() + for x, el := range n { + a.apply(node, el, func(newNode, container SQLNode) { + container.(TableNames)[x] = newNode.(TableName) + }) } - case TableOptions: - case *TableSpec: - replacerColumns := replaceTableSpecColumns(0) - replacerColumnsB := &replacerColumns - for _, item := range n.Columns { - a.apply(node, item, replacerColumnsB.replace) - replacerColumnsB.inc() + for x, el := range n.Columns { + a.apply(node, el, func(newNode, container SQLNode) { + container.(*TableSpec).Columns[x] = newNode.(*ColumnDefinition) + }) } - replacerConstraints := replaceTableSpecConstraints(0) - replacerConstraintsB := &replacerConstraints - for _, item := range n.Constraints { - a.apply(node, item, replacerConstraintsB.replace) - replacerConstraintsB.inc() + for x, el := range n.Indexes { + a.apply(node, el, func(newNode, container SQLNode) { + container.(*TableSpec).Indexes[x] = newNode.(*IndexDefinition) + }) } - replacerIndexes := replaceTableSpecIndexes(0) - replacerIndexesB := &replacerIndexes - for _, item := range n.Indexes { - a.apply(node, item, replacerIndexesB.replace) - replacerIndexesB.inc() + for x, el := range n.Constraints { + a.apply(node, el, func(newNode, container SQLNode) { + container.(*TableSpec).Constraints[x] = newNode.(*ConstraintDefinition) + }) } - a.apply(node, n.Options, replaceTableSpecOptions) - + a.apply(node, n.Options, func(newNode, parent SQLNode) { + parent.(*TableSpec).Options = newNode.(TableOptions) + }) case *TablespaceOperation: - case *TimestampFuncExpr: - a.apply(node, n.Expr1, replaceTimestampFuncExprExpr1) - a.apply(node, n.Expr2, replaceTimestampFuncExprExpr2) - + a.apply(node, n.Expr1, func(newNode, parent SQLNode) { + parent.(*TimestampFuncExpr).Expr1 = newNode.(Expr) + }) + a.apply(node, n.Expr2, func(newNode, parent SQLNode) { + parent.(*TimestampFuncExpr).Expr2 = newNode.(Expr) + }) case *TruncateTable: - a.apply(node, n.Table, replaceTruncateTableTable) - + a.apply(node, n.Table, func(newNode, parent SQLNode) { + parent.(*TruncateTable).Table = newNode.(TableName) + }) case *UnaryExpr: - a.apply(node, n.Expr, replaceUnaryExprExpr) - + a.apply(node, n.Expr, func(newNode, parent SQLNode) { + parent.(*UnaryExpr).Expr = newNode.(Expr) + }) case *Union: - a.apply(node, n.FirstStatement, replaceUnionFirstStatement) - a.apply(node, n.Limit, replaceUnionLimit) - a.apply(node, n.OrderBy, replaceUnionOrderBy) - replacerUnionSelects := replaceUnionUnionSelects(0) - replacerUnionSelectsB := &replacerUnionSelects - for _, item := range n.UnionSelects { - a.apply(node, item, replacerUnionSelectsB.replace) - replacerUnionSelectsB.inc() + a.apply(node, n.FirstStatement, func(newNode, parent SQLNode) { + parent.(*Union).FirstStatement = newNode.(SelectStatement) + }) + for x, el := range n.UnionSelects { + a.apply(node, el, func(newNode, container SQLNode) { + container.(*Union).UnionSelects[x] = newNode.(*UnionSelect) + }) } - + a.apply(node, n.OrderBy, func(newNode, parent SQLNode) { + parent.(*Union).OrderBy = newNode.(OrderBy) + }) + a.apply(node, n.Limit, func(newNode, parent SQLNode) { + parent.(*Union).Limit = newNode.(*Limit) + }) case *UnionSelect: - a.apply(node, n.Statement, replaceUnionSelectStatement) - + a.apply(node, n.Statement, func(newNode, parent SQLNode) { + parent.(*UnionSelect).Statement = newNode.(SelectStatement) + }) case *UnlockTables: - case *Update: - a.apply(node, n.Comments, replaceUpdateComments) - a.apply(node, n.Exprs, replaceUpdateExprs) - a.apply(node, n.Limit, replaceUpdateLimit) - a.apply(node, n.OrderBy, replaceUpdateOrderBy) - a.apply(node, n.TableExprs, replaceUpdateTableExprs) - a.apply(node, n.Where, replaceUpdateWhere) - + a.apply(node, n.Comments, func(newNode, parent SQLNode) { + parent.(*Update).Comments = newNode.(Comments) + }) + a.apply(node, n.TableExprs, func(newNode, parent SQLNode) { + parent.(*Update).TableExprs = newNode.(TableExprs) + }) + a.apply(node, n.Exprs, func(newNode, parent SQLNode) { + parent.(*Update).Exprs = newNode.(UpdateExprs) + }) + a.apply(node, n.Where, func(newNode, parent SQLNode) { + parent.(*Update).Where = newNode.(*Where) + }) + a.apply(node, n.OrderBy, func(newNode, parent SQLNode) { + parent.(*Update).OrderBy = newNode.(OrderBy) + }) + a.apply(node, n.Limit, func(newNode, parent SQLNode) { + parent.(*Update).Limit = newNode.(*Limit) + }) case *UpdateExpr: - a.apply(node, n.Expr, replaceUpdateExprExpr) - a.apply(node, n.Name, replaceUpdateExprName) - + a.apply(node, n.Name, func(newNode, parent SQLNode) { + parent.(*UpdateExpr).Name = newNode.(*ColName) + }) + a.apply(node, n.Expr, func(newNode, parent SQLNode) { + parent.(*UpdateExpr).Expr = newNode.(Expr) + }) case UpdateExprs: - replacer := replaceUpdateExprsItems(0) - replacerRef := &replacer - for _, item := range n { - a.apply(node, item, replacerRef.replace) - replacerRef.inc() + for x, el := range n { + a.apply(node, el, func(newNode, container SQLNode) { + container.(UpdateExprs)[x] = newNode.(*UpdateExpr) + }) } - case *Use: - a.apply(node, n.DBName, replaceUseDBName) - + a.apply(node, n.DBName, func(newNode, parent SQLNode) { + parent.(*Use).DBName = newNode.(TableIdent) + }) case *VStream: - a.apply(node, n.Comments, replaceVStreamComments) - a.apply(node, n.Limit, replaceVStreamLimit) - a.apply(node, n.SelectExpr, replaceVStreamSelectExpr) - a.apply(node, n.Table, replaceVStreamTable) - a.apply(node, n.Where, replaceVStreamWhere) - + a.apply(node, n.Comments, func(newNode, parent SQLNode) { + parent.(*VStream).Comments = newNode.(Comments) + }) + a.apply(node, n.SelectExpr, func(newNode, parent SQLNode) { + parent.(*VStream).SelectExpr = newNode.(SelectExpr) + }) + a.apply(node, n.Table, func(newNode, parent SQLNode) { + parent.(*VStream).Table = newNode.(TableName) + }) + a.apply(node, n.Where, func(newNode, parent SQLNode) { + parent.(*VStream).Where = newNode.(*Where) + }) + a.apply(node, n.Limit, func(newNode, parent SQLNode) { + parent.(*VStream).Limit = newNode.(*Limit) + }) case ValTuple: - replacer := replaceValTupleItems(0) - replacerRef := &replacer - for _, item := range n { - a.apply(node, item, replacerRef.replace) - replacerRef.inc() + for x, el := range n { + a.apply(node, el, func(newNode, container SQLNode) { + container.(ValTuple)[x] = newNode.(Expr) + }) } - case *Validation: - case Values: - replacer := replaceValuesItems(0) - replacerRef := &replacer - for _, item := range n { - a.apply(node, item, replacerRef.replace) - replacerRef.inc() + for x, el := range n { + a.apply(node, el, func(newNode, container SQLNode) { + container.(Values)[x] = newNode.(ValTuple) + }) } - case *ValuesFuncExpr: - a.apply(node, n.Name, replaceValuesFuncExprName) - + a.apply(node, n.Name, func(newNode, parent SQLNode) { + parent.(*ValuesFuncExpr).Name = newNode.(*ColName) + }) case VindexParam: - a.apply(node, n.Key, replaceVindexParamKey) - + a.apply(node, n.Key, replacePanic("VindexParam Key")) case *VindexSpec: - a.apply(node, n.Name, replaceVindexSpecName) - replacerParams := replaceVindexSpecParams(0) - replacerParamsB := &replacerParams - for _, item := range n.Params { - a.apply(node, item, replacerParamsB.replace) - replacerParamsB.inc() + a.apply(node, n.Name, func(newNode, parent SQLNode) { + parent.(*VindexSpec).Name = newNode.(ColIdent) + }) + a.apply(node, n.Type, func(newNode, parent SQLNode) { + parent.(*VindexSpec).Type = newNode.(ColIdent) + }) + for x, el := range n.Params { + a.apply(node, el, func(newNode, container SQLNode) { + container.(*VindexSpec).Params[x] = newNode.(VindexParam) + }) } - a.apply(node, n.Type, replaceVindexSpecType) - case *When: - a.apply(node, n.Cond, replaceWhenCond) - a.apply(node, n.Val, replaceWhenVal) - + a.apply(node, n.Cond, func(newNode, parent SQLNode) { + parent.(*When).Cond = newNode.(Expr) + }) + a.apply(node, n.Val, func(newNode, parent SQLNode) { + parent.(*When).Val = newNode.(Expr) + }) case *Where: - a.apply(node, n.Expr, replaceWhereExpr) - + a.apply(node, n.Expr, func(newNode, parent SQLNode) { + parent.(*Where).Expr = newNode.(Expr) + }) case *XorExpr: - a.apply(node, n.Left, replaceXorExprLeft) - a.apply(node, n.Right, replaceXorExprRight) - - default: - panic("unknown ast type " + reflect.TypeOf(node).String()) + a.apply(node, n.Left, func(newNode, parent SQLNode) { + parent.(*XorExpr).Left = newNode.(Expr) + }) + a.apply(node, n.Right, func(newNode, parent SQLNode) { + parent.(*XorExpr).Right = newNode.(Expr) + }) } - if a.post != nil && !a.post(&a.cursor) { panic(abort) } - a.cursor = saved } - -func isNilValue(i interface{}) bool { - valueOf := reflect.ValueOf(i) - kind := valueOf.Kind() - isNullable := kind == reflect.Ptr || kind == reflect.Array || kind == reflect.Slice - return isNullable && valueOf.IsNil() -} diff --git a/go/vt/sqlparser/rewriter_api.go b/go/vt/sqlparser/rewriter_api.go index 47c85e0473b..ea25e67b1d6 100644 --- a/go/vt/sqlparser/rewriter_api.go +++ b/go/vt/sqlparser/rewriter_api.go @@ -16,6 +16,11 @@ limitations under the License. package sqlparser +import ( + "reflect" + "runtime" +) + // The rewriter was heavily inspired by https://github.com/golang/tools/blob/master/go/ast/astutil/rewrite.go // Rewrite traverses a syntax tree recursively, starting with root, @@ -34,11 +39,20 @@ package sqlparser // Only fields that refer to AST nodes are considered children; // i.e., fields of basic types (strings, []byte, etc.) are ignored. // -func Rewrite(node SQLNode, pre, post ApplyFunc) (result SQLNode) { +func Rewrite(node SQLNode, pre, post ApplyFunc) (result SQLNode, err error) { parent := &struct{ SQLNode }{node} defer func() { - if r := recover(); r != nil && r != abort { - panic(r) + if r := recover(); r != nil { + switch r := r.(type) { + case abortT: // nothing to do + + case *runtime.TypeAssertionError: + err = r + case *valueTypeFieldCantChangeErr: + err = r + default: + panic(r) + } } result = parent.SQLNode }() @@ -56,7 +70,7 @@ func Rewrite(node SQLNode, pre, post ApplyFunc) (result SQLNode) { a.apply(parent, node, replacer) - return parent.SQLNode + return parent.SQLNode, nil } // An ApplyFunc is invoked by Rewrite for each node n, even if n is nil, @@ -67,7 +81,9 @@ func Rewrite(node SQLNode, pre, post ApplyFunc) (result SQLNode) { // See Rewrite for details. type ApplyFunc func(*Cursor) bool -var abort = new(int) // singleton, to signal termination of Apply +type abortT int + +var abort = abortT(0) // singleton, to signal termination of Apply // A Cursor describes a node encountered during Apply. // Information about the node and its parent is available @@ -90,3 +106,34 @@ func (c *Cursor) Replace(newNode SQLNode) { c.replacer(newNode, c.parent) c.node = newNode } + +type replacerFunc func(newNode, parent SQLNode) + +// application carries all the shared data so we can pass it around cheaply. +type application struct { + pre, post ApplyFunc + cursor Cursor +} + +func isNilValue(i interface{}) bool { + valueOf := reflect.ValueOf(i) + kind := valueOf.Kind() + isNullable := kind == reflect.Ptr || kind == reflect.Array || kind == reflect.Slice + return isNullable && valueOf.IsNil() +} + +// this type is here so we can catch it in the Rewrite method above +type valueTypeFieldCantChangeErr struct { + msg string +} + +// Error implements the error interface +func (e *valueTypeFieldCantChangeErr) Error() string { + return "Tried replacing a field of a value type. This is not supported. " + e.msg +} + +func replacePanic(msg string) func(newNode, parent SQLNode) { + return func(newNode, parent SQLNode) { + panic(&valueTypeFieldCantChangeErr{msg: msg}) + } +} diff --git a/go/vt/sqlparser/rewriter_test.go b/go/vt/sqlparser/rewriter_test.go new file mode 100644 index 00000000000..6131c6c5588 --- /dev/null +++ b/go/vt/sqlparser/rewriter_test.go @@ -0,0 +1,67 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sqlparser + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func BenchmarkVisitLargeExpression(b *testing.B) { + gen := newGenerator(1, 5) + exp := gen.expression() + + depth := 0 + for i := 0; i < b.N; i++ { + _, err := Rewrite(exp, func(cursor *Cursor) bool { + depth++ + return true + }, func(cursor *Cursor) bool { + depth-- + return true + }) + require.NoError(b, err) + } +} + +func TestBadTypeReturnsErrorAndNotPanic(t *testing.T) { + parse, err := Parse("select 42 from dual") + require.NoError(t, err) + _, err = Rewrite(parse, func(cursor *Cursor) bool { + _, ok := cursor.Node().(*Literal) + if ok { + cursor.Replace(&AliasedTableExpr{}) // this is not a valid replacement because of types + } + return true + }, nil) + require.Error(t, err) +} + +func TestChangeValueTypeGivesError(t *testing.T) { + parse, err := Parse("select * from a join b on a.id = b.id") + require.NoError(t, err) + _, err = Rewrite(parse, func(cursor *Cursor) bool { + _, ok := cursor.Node().(*ComparisonExpr) + if ok { + cursor.Replace(&NullVal{}) // this is not a valid replacement because the container is a value type + } + return true + }, nil) + require.Error(t, err) + +} diff --git a/go/vt/sqlparser/sql.go b/go/vt/sqlparser/sql.go index 84e70012044..b18682673a8 100644 --- a/go/vt/sqlparser/sql.go +++ b/go/vt/sqlparser/sql.go @@ -5488,7 +5488,7 @@ yydefault: yyDollar = yyS[yypt-7 : yypt+1] //line sql.y:509 { - yyVAL.selStmt = NewSelect(Comments(yyDollar[2].bytes2), SelectExprs{Nextval{Expr: yyDollar[5].expr}}, []string{yyDollar[3].str} /*options*/, TableExprs{&AliasedTableExpr{Expr: yyDollar[7].tableName}}, nil /*where*/, nil /*groupBy*/, nil /*having*/) + yyVAL.selStmt = NewSelect(Comments(yyDollar[2].bytes2), SelectExprs{&Nextval{Expr: yyDollar[5].expr}}, []string{yyDollar[3].str} /*options*/, TableExprs{&AliasedTableExpr{Expr: yyDollar[7].tableName}}, nil /*where*/, nil /*groupBy*/, nil /*having*/) } case 45: yyDollar = yyS[yypt-4 : yypt+1] diff --git a/go/vt/sqlparser/sql.y b/go/vt/sqlparser/sql.y index c06168310fb..df99cb0dec4 100644 --- a/go/vt/sqlparser/sql.y +++ b/go/vt/sqlparser/sql.y @@ -507,7 +507,7 @@ select_statement: } | SELECT comment_opt cache_opt NEXT num_val for_from table_name { - $$ = NewSelect(Comments($2), SelectExprs{Nextval{Expr: $5}}, []string{$3}/*options*/, TableExprs{&AliasedTableExpr{Expr: $7}}, nil/*where*/, nil/*groupBy*/, nil/*having*/) + $$ = NewSelect(Comments($2), SelectExprs{&Nextval{Expr: $5}}, []string{$3}/*options*/, TableExprs{&AliasedTableExpr{Expr: $7}}, nil/*where*/, nil/*groupBy*/, nil/*having*/) } // simple_select is an unparenthesized select used for subquery. diff --git a/go/vt/sqlparser/utils.go b/go/vt/sqlparser/utils.go index 1de7833a58e..983faaec22b 100644 --- a/go/vt/sqlparser/utils.go +++ b/go/vt/sqlparser/utils.go @@ -40,7 +40,10 @@ func QueryMatchesTemplates(query string, queryTemplates []string) (match bool, e if err != nil { return "", err } - Normalize(stmt, bv, "") + err = Normalize(stmt, bv, "") + if err != nil { + return "", err + } normalized := String(stmt) return normalized, nil } diff --git a/go/vt/sqlparser/visitorgen/ast_walker.go b/go/vt/sqlparser/visitorgen/ast_walker.go deleted file mode 100644 index 822fb6c4c5e..00000000000 --- a/go/vt/sqlparser/visitorgen/ast_walker.go +++ /dev/null @@ -1,130 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package visitorgen - -import ( - "go/ast" - "reflect" -) - -var _ ast.Visitor = (*walker)(nil) - -type walker struct { - result SourceFile -} - -// Walk walks the given AST and translates it to the simplified AST used by the next steps -func Walk(node ast.Node) *SourceFile { - var w walker - ast.Walk(&w, node) - return &w.result -} - -// Visit implements the ast.Visitor interface -func (w *walker) Visit(node ast.Node) ast.Visitor { - switch n := node.(type) { - case *ast.TypeSpec: - switch t2 := n.Type.(type) { - case *ast.InterfaceType: - w.append(&InterfaceDeclaration{ - name: n.Name.Name, - block: "", - }) - case *ast.StructType: - var fields []*Field - for _, f := range t2.Fields.List { - for _, name := range f.Names { - fields = append(fields, &Field{ - name: name.Name, - typ: sastType(f.Type), - }) - } - - } - w.append(&StructDeclaration{ - name: n.Name.Name, - fields: fields, - }) - case *ast.ArrayType: - w.append(&TypeAlias{ - name: n.Name.Name, - typ: &Array{inner: sastType(t2.Elt)}, - }) - case *ast.Ident: - w.append(&TypeAlias{ - name: n.Name.Name, - typ: &TypeString{t2.Name}, - }) - - default: - panic(reflect.TypeOf(t2)) - } - case *ast.FuncDecl: - if len(n.Recv.List) > 1 || len(n.Recv.List[0].Names) > 1 { - panic("don't know what to do!") - } - var f *Field - if len(n.Recv.List) == 1 { - r := n.Recv.List[0] - t := sastType(r.Type) - if len(r.Names) > 1 { - panic("don't know what to do!") - } - if len(r.Names) == 1 { - f = &Field{ - name: r.Names[0].Name, - typ: t, - } - } else { - f = &Field{ - name: "", - typ: t, - } - } - } - - w.append(&FuncDeclaration{ - receiver: f, - name: n.Name.Name, - block: "", - arguments: nil, - }) - } - - return w -} - -func (w *walker) append(line Sast) { - w.result.lines = append(w.result.lines, line) -} - -func sastType(e ast.Expr) Type { - switch n := e.(type) { - case *ast.StarExpr: - return &Ref{sastType(n.X)} - case *ast.Ident: - return &TypeString{n.Name} - case *ast.ArrayType: - return &Array{inner: sastType(n.Elt)} - case *ast.InterfaceType: - return &TypeString{"interface{}"} - case *ast.StructType: - return &TypeString{"struct{}"} - } - - panic(reflect.TypeOf(e)) -} diff --git a/go/vt/sqlparser/visitorgen/ast_walker_test.go b/go/vt/sqlparser/visitorgen/ast_walker_test.go deleted file mode 100644 index a4b01f70835..00000000000 --- a/go/vt/sqlparser/visitorgen/ast_walker_test.go +++ /dev/null @@ -1,239 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package visitorgen - -import ( - "go/parser" - "go/token" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/stretchr/testify/require" -) - -func TestSingleInterface(t *testing.T) { - input := ` -package sqlparser - -type Nodeiface interface { - iNode() -} -` - - fset := token.NewFileSet() - ast, err := parser.ParseFile(fset, "ast.go", input, 0) - require.NoError(t, err) - - result := Walk(ast) - expected := SourceFile{ - lines: []Sast{&InterfaceDeclaration{ - name: "Nodeiface", - block: "", - }}, - } - assert.Equal(t, expected.String(), result.String()) -} - -func TestEmptyStruct(t *testing.T) { - input := ` -package sqlparser - -type Empty struct {} -` - - fset := token.NewFileSet() - ast, err := parser.ParseFile(fset, "ast.go", input, 0) - require.NoError(t, err) - - result := Walk(ast) - expected := SourceFile{ - lines: []Sast{&StructDeclaration{ - name: "Empty", - fields: []*Field{}, - }}, - } - assert.Equal(t, expected.String(), result.String()) -} - -func TestStructWithStringField(t *testing.T) { - input := ` -package sqlparser - -type Struct struct { - field string -} -` - - fset := token.NewFileSet() - ast, err := parser.ParseFile(fset, "ast.go", input, 0) - require.NoError(t, err) - - result := Walk(ast) - expected := SourceFile{ - lines: []Sast{&StructDeclaration{ - name: "Struct", - fields: []*Field{{ - name: "field", - typ: &TypeString{typName: "string"}, - }}, - }}, - } - assert.Equal(t, expected.String(), result.String()) -} - -func TestStructWithDifferentTypes(t *testing.T) { - input := ` -package sqlparser - -type Struct struct { - field string - reference *string - array []string - arrayOfRef []*string -} -` - - fset := token.NewFileSet() - ast, err := parser.ParseFile(fset, "ast.go", input, 0) - require.NoError(t, err) - - result := Walk(ast) - expected := SourceFile{ - lines: []Sast{&StructDeclaration{ - name: "Struct", - fields: []*Field{{ - name: "field", - typ: &TypeString{typName: "string"}, - }, { - name: "reference", - typ: &Ref{&TypeString{typName: "string"}}, - }, { - name: "array", - typ: &Array{&TypeString{typName: "string"}}, - }, { - name: "arrayOfRef", - typ: &Array{&Ref{&TypeString{typName: "string"}}}, - }}, - }}, - } - assert.Equal(t, expected.String(), result.String()) -} - -func TestStructWithTwoStringFieldInOneLine(t *testing.T) { - input := ` -package sqlparser - -type Struct struct { - left, right string -} -` - - fset := token.NewFileSet() - ast, err := parser.ParseFile(fset, "ast.go", input, 0) - require.NoError(t, err) - - result := Walk(ast) - expected := SourceFile{ - lines: []Sast{&StructDeclaration{ - name: "Struct", - fields: []*Field{{ - name: "left", - typ: &TypeString{typName: "string"}, - }, { - name: "right", - typ: &TypeString{typName: "string"}, - }}, - }}, - } - assert.Equal(t, expected.String(), result.String()) -} - -func TestStructWithSingleMethod(t *testing.T) { - input := ` -package sqlparser - -type Empty struct {} - -func (*Empty) method() {} -` - - fset := token.NewFileSet() - ast, err := parser.ParseFile(fset, "ast.go", input, 0) - require.NoError(t, err) - - result := Walk(ast) - expected := SourceFile{ - lines: []Sast{ - &StructDeclaration{ - name: "Empty", - fields: []*Field{}}, - &FuncDeclaration{ - receiver: &Field{ - name: "", - typ: &Ref{&TypeString{"Empty"}}, - }, - name: "method", - block: "", - arguments: []*Field{}, - }, - }, - } - assert.Equal(t, expected.String(), result.String()) -} - -func TestSingleArrayType(t *testing.T) { - input := ` -package sqlparser - -type Strings []string -` - - fset := token.NewFileSet() - ast, err := parser.ParseFile(fset, "ast.go", input, 0) - require.NoError(t, err) - - result := Walk(ast) - expected := SourceFile{ - lines: []Sast{&TypeAlias{ - name: "Strings", - typ: &Array{&TypeString{"string"}}, - }}, - } - assert.Equal(t, expected.String(), result.String()) -} - -func TestSingleTypeAlias(t *testing.T) { - input := ` -package sqlparser - -type String string -` - - fset := token.NewFileSet() - ast, err := parser.ParseFile(fset, "ast.go", input, 0) - require.NoError(t, err) - - result := Walk(ast) - expected := SourceFile{ - lines: []Sast{&TypeAlias{ - name: "String", - typ: &TypeString{"string"}, - }}, - } - assert.Equal(t, expected.String(), result.String()) -} diff --git a/go/vt/sqlparser/visitorgen/main/main.go b/go/vt/sqlparser/visitorgen/main/main.go deleted file mode 100644 index 0d940ea060f..00000000000 --- a/go/vt/sqlparser/visitorgen/main/main.go +++ /dev/null @@ -1,164 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "bytes" - "flag" - "fmt" - "go/parser" - "go/token" - "io/ioutil" - "os" - - "vitess.io/vitess/go/exit" - "vitess.io/vitess/go/vt/log" - - "vitess.io/vitess/go/vt/sqlparser/visitorgen" -) - -var ( - inputFile = flag.String("input", "", "input file to use") - outputFile = flag.String("output", "", "output file") - compare = flag.Bool("compareOnly", false, "instead of writing to the output file, compare if the generated visitor is still valid for this ast.go") -) - -const usage = `Usage of visitorgen: - -go run /path/to/visitorgen/main -input=/path/to/ast.go -output=/path/to/rewriter.go -` - -func main() { - defer exit.Recover() - flag.Usage = printUsage - flag.Parse() - - if *inputFile == "" || *outputFile == "" { - printUsage() - exit.Return(1) - } - - fs := token.NewFileSet() - file, err := parser.ParseFile(fs, *inputFile, nil, parser.DeclarationErrors) - if err != nil { - log.Error(err) - exit.Return(1) - } - - astWalkResult := visitorgen.Walk(file) - vp := visitorgen.Transform(astWalkResult) - vd := visitorgen.ToVisitorPlan(vp) - - replacementMethods := visitorgen.EmitReplacementMethods(vd) - typeSwitch := visitorgen.EmitTypeSwitches(vd) - - b := &bytes.Buffer{} - fmt.Fprint(b, fileHeader) - fmt.Fprintln(b) - fmt.Fprintln(b, replacementMethods) - fmt.Fprint(b, applyHeader) - fmt.Fprintln(b, typeSwitch) - fmt.Fprintln(b, fileFooter) - - if *compare { - currentFile, err := ioutil.ReadFile(*outputFile) - if err != nil { - log.Error(err) - exit.Return(1) - } - if !bytes.Equal(b.Bytes(), currentFile) { - fmt.Println("rewriter needs to be re-generated: go generate " + *outputFile) - exit.Return(1) - } - } else { - err = ioutil.WriteFile(*outputFile, b.Bytes(), 0644) - if err != nil { - log.Error(err) - exit.Return(1) - } - } - -} - -func printUsage() { - os.Stderr.WriteString(usage) - os.Stderr.WriteString("\nOptions:\n") - flag.PrintDefaults() -} - -const fileHeader = `// Code generated by visitorgen/main/main.go. DO NOT EDIT. - -package sqlparser - -//go:generate go run ./visitorgen/main -input=ast.go -output=rewriter.go - -import ( - "reflect" -) - -type replacerFunc func(newNode, parent SQLNode) - -// application carries all the shared data so we can pass it around cheaply. -type application struct { - pre, post ApplyFunc - cursor Cursor -} -` - -const applyHeader = ` -// apply is where the visiting happens. Here is where we keep the big switch-case that will be used -// to do the actual visiting of SQLNodes -func (a *application) apply(parent, node SQLNode, replacer replacerFunc) { - if node == nil || isNilValue(node) { - return - } - - // avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead - saved := a.cursor - a.cursor.replacer = replacer - a.cursor.node = node - a.cursor.parent = parent - - if a.pre != nil && !a.pre(&a.cursor) { - a.cursor = saved - return - } - - // walk children - // (the order of the cases is alphabetical) - switch n := node.(type) { - case nil: - ` - -const fileFooter = ` - default: - panic("unknown ast type " + reflect.TypeOf(node).String()) - } - - if a.post != nil && !a.post(&a.cursor) { - panic(abort) - } - - a.cursor = saved -} - -func isNilValue(i interface{}) bool { - valueOf := reflect.ValueOf(i) - kind := valueOf.Kind() - isNullable := kind == reflect.Ptr || kind == reflect.Array || kind == reflect.Slice - return isNullable && valueOf.IsNil() -}` diff --git a/go/vt/sqlparser/visitorgen/sast.go b/go/vt/sqlparser/visitorgen/sast.go deleted file mode 100644 index e46485e8f5d..00000000000 --- a/go/vt/sqlparser/visitorgen/sast.go +++ /dev/null @@ -1,178 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package visitorgen - -// simplified ast - when reading the golang ast of the ast.go file, we translate the golang ast objects -// to this much simpler format, that contains only the necessary information and no more -type ( - // SourceFile contains all important lines from an ast.go file - SourceFile struct { - lines []Sast - } - - // Sast or simplified AST, is a representation of the ast.go lines we are interested in - Sast interface { - toSastString() string - } - - // InterfaceDeclaration represents a declaration of an interface. This is used to keep track of which types - // need to be handled by the visitor framework - InterfaceDeclaration struct { - name, block string - } - - // TypeAlias is used whenever we see a `type XXX YYY` - XXX is the new name for YYY. - // Note that YYY could be an array or a reference - TypeAlias struct { - name string - typ Type - } - - // FuncDeclaration represents a function declaration. These are tracked to know which types implement interfaces. - FuncDeclaration struct { - receiver *Field - name, block string - arguments []*Field - } - - // StructDeclaration represents a struct. It contains the fields and their types - StructDeclaration struct { - name string - fields []*Field - } - - // Field is a field in a struct - a name with a type tuple - Field struct { - name string - typ Type - } - - // Type represents a type in the golang type system. Used to keep track of type we need to handle, - // and the types of fields. - Type interface { - toTypString() string - rawTypeName() string - } - - // TypeString is a raw type name, such as `string` - TypeString struct { - typName string - } - - // Ref is a reference to something, such as `*string` - Ref struct { - inner Type - } - - // Array is an array of things, such as `[]string` - Array struct { - inner Type - } -) - -var _ Sast = (*InterfaceDeclaration)(nil) -var _ Sast = (*StructDeclaration)(nil) -var _ Sast = (*FuncDeclaration)(nil) -var _ Sast = (*TypeAlias)(nil) - -var _ Type = (*TypeString)(nil) -var _ Type = (*Ref)(nil) -var _ Type = (*Array)(nil) - -// String returns a textual representation of the SourceFile. This is for testing purposed -func (t *SourceFile) String() string { - var result string - for _, l := range t.lines { - result += l.toSastString() - result += "\n" - } - - return result -} - -func (t *Ref) toTypString() string { - return "*" + t.inner.toTypString() -} - -func (t *Array) toTypString() string { - return "[]" + t.inner.toTypString() -} - -func (t *TypeString) toTypString() string { - return t.typName -} - -func (f *FuncDeclaration) toSastString() string { - var receiver string - if f.receiver != nil { - receiver = "(" + f.receiver.String() + ") " - } - var args string - for i, arg := range f.arguments { - if i > 0 { - args += ", " - } - args += arg.String() - } - - return "func " + receiver + f.name + "(" + args + ") {" + blockInNewLines(f.block) + "}" -} - -func (i *InterfaceDeclaration) toSastString() string { - return "type " + i.name + " interface {" + blockInNewLines(i.block) + "}" -} - -func (a *TypeAlias) toSastString() string { - return "type " + a.name + " " + a.typ.toTypString() -} - -func (s *StructDeclaration) toSastString() string { - var block string - for _, f := range s.fields { - block += "\t" + f.String() + "\n" - } - - return "type " + s.name + " struct {" + blockInNewLines(block) + "}" -} - -func blockInNewLines(block string) string { - if block == "" { - return "" - } - return "\n" + block + "\n" -} - -// String returns a string representation of a field -func (f *Field) String() string { - if f.name != "" { - return f.name + " " + f.typ.toTypString() - } - - return f.typ.toTypString() -} - -func (t *TypeString) rawTypeName() string { - return t.typName -} - -func (t *Ref) rawTypeName() string { - return t.inner.rawTypeName() -} - -func (t *Array) rawTypeName() string { - return t.inner.rawTypeName() -} diff --git a/go/vt/sqlparser/visitorgen/struct_producer.go b/go/vt/sqlparser/visitorgen/struct_producer.go deleted file mode 100644 index 1c293f30803..00000000000 --- a/go/vt/sqlparser/visitorgen/struct_producer.go +++ /dev/null @@ -1,253 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package visitorgen - -import ( - "fmt" - "sort" -) - -// VisitorData is the data needed to produce the output file -type ( - // VisitorItem represents something that needs to be added to the rewriter infrastructure - VisitorItem interface { - toFieldItemString() string - typeName() string - asSwitchCase() string - asReplMethod() string - getFieldName() string - } - - // SingleFieldItem is a single field in a struct - SingleFieldItem struct { - StructType, FieldType Type - FieldName string - } - - // ArrayFieldItem is an array field in a struct - ArrayFieldItem struct { - StructType, ItemType Type - FieldName string - } - - // ArrayItem is an array that implements SQLNode - ArrayItem struct { - StructType, ItemType Type - } - - // VisitorPlan represents all the output needed for the rewriter - VisitorPlan struct { - Switches []*SwitchCase // The cases for the big switch statement used to implement the visitor - } - - // SwitchCase is what we need to know to produce all the type switch cases in the visitor. - SwitchCase struct { - Type Type - Fields []VisitorItem - } -) - -var _ VisitorItem = (*SingleFieldItem)(nil) -var _ VisitorItem = (*ArrayItem)(nil) -var _ VisitorItem = (*ArrayFieldItem)(nil) -var _ sort.Interface = (*VisitorPlan)(nil) -var _ sort.Interface = (*SwitchCase)(nil) - -// ToVisitorPlan transforms the source information into a plan for the visitor code that needs to be produced -func ToVisitorPlan(input *SourceInformation) *VisitorPlan { - var output VisitorPlan - - for _, typ := range input.interestingTypes { - switchit := &SwitchCase{Type: typ} - stroct, isStruct := input.structs[typ.rawTypeName()] - if isStruct { - for _, f := range stroct.fields { - switchit.Fields = append(switchit.Fields, trySingleItem(input, f, typ)...) - } - } else { - itemType := input.getItemTypeOfArray(typ) - if itemType != nil && input.isSQLNode(itemType) { - switchit.Fields = append(switchit.Fields, &ArrayItem{ - StructType: typ, - ItemType: itemType, - }) - } - } - sort.Sort(switchit) - output.Switches = append(output.Switches, switchit) - } - sort.Sort(&output) - return &output -} - -func trySingleItem(input *SourceInformation, f *Field, typ Type) []VisitorItem { - if input.isSQLNode(f.typ) { - return []VisitorItem{&SingleFieldItem{ - StructType: typ, - FieldType: f.typ, - FieldName: f.name, - }} - } - - arrType, isArray := f.typ.(*Array) - if isArray && input.isSQLNode(arrType.inner) { - return []VisitorItem{&ArrayFieldItem{ - StructType: typ, - ItemType: arrType.inner, - FieldName: f.name, - }} - } - return []VisitorItem{} -} - -// String returns a string, used for testing -func (v *VisitorPlan) String() string { - var sb builder - for _, s := range v.Switches { - sb.appendF("Type: %v", s.Type.toTypString()) - for _, f := range s.Fields { - sb.appendF("\t%v", f.toFieldItemString()) - } - } - return sb.String() -} - -func (s *SingleFieldItem) toFieldItemString() string { - return fmt.Sprintf("single item: %v of type: %v", s.FieldName, s.FieldType.toTypString()) -} - -func (s *SingleFieldItem) asSwitchCase() string { - return fmt.Sprintf(` a.apply(node, n.%s, %s)`, s.FieldName, s.typeName()) -} - -func (s *SingleFieldItem) asReplMethod() string { - _, isRef := s.StructType.(*Ref) - - if isRef { - return fmt.Sprintf(`func %s(newNode, parent SQLNode) { - parent.(%s).%s = newNode.(%s) -}`, s.typeName(), s.StructType.toTypString(), s.FieldName, s.FieldType.toTypString()) - } - - return fmt.Sprintf(`func %s(newNode, parent SQLNode) { - tmp := parent.(%s) - tmp.%s = newNode.(%s) -}`, s.typeName(), s.StructType.toTypString(), s.FieldName, s.FieldType.toTypString()) - -} - -func (ai *ArrayItem) asReplMethod() string { - name := ai.typeName() - return fmt.Sprintf(`type %s int - -func (r *%s) replace(newNode, container SQLNode) { - container.(%s)[int(*r)] = newNode.(%s) -} - -func (r *%s) inc() { - *r++ -}`, name, name, ai.StructType.toTypString(), ai.ItemType.toTypString(), name) -} - -func (afi *ArrayFieldItem) asReplMethod() string { - name := afi.typeName() - return fmt.Sprintf(`type %s int - -func (r *%s) replace(newNode, container SQLNode) { - container.(%s).%s[int(*r)] = newNode.(%s) -} - -func (r *%s) inc() { - *r++ -}`, name, name, afi.StructType.toTypString(), afi.FieldName, afi.ItemType.toTypString(), name) -} - -func (s *SingleFieldItem) getFieldName() string { - return s.FieldName -} - -func (s *SingleFieldItem) typeName() string { - return "replace" + s.StructType.rawTypeName() + s.FieldName -} - -func (afi *ArrayFieldItem) toFieldItemString() string { - return fmt.Sprintf("array field item: %v.%v contains items of type %v", afi.StructType.toTypString(), afi.FieldName, afi.ItemType.toTypString()) -} - -func (ai *ArrayItem) toFieldItemString() string { - return fmt.Sprintf("array item: %v containing items of type %v", ai.StructType.toTypString(), ai.ItemType.toTypString()) -} - -func (ai *ArrayItem) getFieldName() string { - panic("Should not be called!") -} - -func (afi *ArrayFieldItem) getFieldName() string { - return afi.FieldName -} - -func (ai *ArrayItem) asSwitchCase() string { - return fmt.Sprintf(` replacer := %s(0) - replacerRef := &replacer - for _, item := range n { - a.apply(node, item, replacerRef.replace) - replacerRef.inc() - }`, ai.typeName()) -} - -func (afi *ArrayFieldItem) asSwitchCase() string { - return fmt.Sprintf(` replacer%s := %s(0) - replacer%sB := &replacer%s - for _, item := range n.%s { - a.apply(node, item, replacer%sB.replace) - replacer%sB.inc() - }`, afi.FieldName, afi.typeName(), afi.FieldName, afi.FieldName, afi.FieldName, afi.FieldName, afi.FieldName) -} - -func (ai *ArrayItem) typeName() string { - return "replace" + ai.StructType.rawTypeName() + "Items" -} - -func (afi *ArrayFieldItem) typeName() string { - return "replace" + afi.StructType.rawTypeName() + afi.FieldName -} -func (v *VisitorPlan) Len() int { - return len(v.Switches) -} - -func (v *VisitorPlan) Less(i, j int) bool { - return v.Switches[i].Type.rawTypeName() < v.Switches[j].Type.rawTypeName() -} - -func (v *VisitorPlan) Swap(i, j int) { - temp := v.Switches[i] - v.Switches[i] = v.Switches[j] - v.Switches[j] = temp -} -func (s *SwitchCase) Len() int { - return len(s.Fields) -} - -func (s *SwitchCase) Less(i, j int) bool { - return s.Fields[i].getFieldName() < s.Fields[j].getFieldName() -} - -func (s *SwitchCase) Swap(i, j int) { - temp := s.Fields[i] - s.Fields[i] = s.Fields[j] - s.Fields[j] = temp -} diff --git a/go/vt/sqlparser/visitorgen/struct_producer_test.go b/go/vt/sqlparser/visitorgen/struct_producer_test.go deleted file mode 100644 index 065b532a9eb..00000000000 --- a/go/vt/sqlparser/visitorgen/struct_producer_test.go +++ /dev/null @@ -1,423 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package visitorgen - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestEmptyStructVisitor(t *testing.T) { - /* - type Node interface{} - type Struct struct {} - func (*Struct) iNode() {} - */ - - input := &SourceInformation{ - interestingTypes: map[string]Type{ - "*Struct": &Ref{&TypeString{"Struct"}}, - }, - interfaces: map[string]bool{ - "Node": true, - }, - structs: map[string]*StructDeclaration{ - "Struct": {name: "Struct", fields: []*Field{}}, - }, - typeAliases: map[string]*TypeAlias{}, - } - - result := ToVisitorPlan(input) - - expected := &VisitorPlan{ - Switches: []*SwitchCase{{ - Type: &Ref{&TypeString{"Struct"}}, - Fields: []VisitorItem{}, - }}, - } - - assert.Equal(t, expected.String(), result.String()) -} - -func TestStructWithSqlNodeField(t *testing.T) { - /* - type Node interface{} - type Struct struct { - Field Node - } - func (*Struct) iNode() {} - */ - input := &SourceInformation{ - interestingTypes: map[string]Type{ - "*Struct": &Ref{&TypeString{"Struct"}}, - }, - interfaces: map[string]bool{ - "Node": true, - }, - structs: map[string]*StructDeclaration{ - "Struct": {name: "Struct", fields: []*Field{ - {name: "Field", typ: &TypeString{"Node"}}, - }}, - }, - typeAliases: map[string]*TypeAlias{}, - } - - result := ToVisitorPlan(input) - - expected := &VisitorPlan{ - Switches: []*SwitchCase{{ - Type: &Ref{&TypeString{"Struct"}}, - Fields: []VisitorItem{&SingleFieldItem{ - StructType: &Ref{&TypeString{"Struct"}}, - FieldType: &TypeString{"Node"}, - FieldName: "Field", - }}, - }}, - } - - assert.Equal(t, expected.String(), result.String()) -} - -func TestStructWithStringField2(t *testing.T) { - /* - type Node interface{} - type Struct struct { - Field Node - } - func (*Struct) iNode() {} - */ - - input := &SourceInformation{ - interestingTypes: map[string]Type{ - "*Struct": &Ref{&TypeString{"Struct"}}, - }, - interfaces: map[string]bool{ - "Node": true, - }, - structs: map[string]*StructDeclaration{ - "Struct": {name: "Struct", fields: []*Field{ - {name: "Field", typ: &TypeString{"string"}}, - }}, - }, - typeAliases: map[string]*TypeAlias{}, - } - - result := ToVisitorPlan(input) - - expected := &VisitorPlan{ - Switches: []*SwitchCase{{ - Type: &Ref{&TypeString{"Struct"}}, - Fields: []VisitorItem{}, - }}, - } - - assert.Equal(t, expected.String(), result.String()) -} - -func TestArrayAsSqlNode(t *testing.T) { - /* - type NodeInterface interface { - iNode() - } - - func (*NodeArray) iNode{} - - type NodeArray []NodeInterface - */ - - input := &SourceInformation{ - interfaces: map[string]bool{"NodeInterface": true}, - interestingTypes: map[string]Type{ - "*NodeArray": &Ref{&TypeString{"NodeArray"}}}, - structs: map[string]*StructDeclaration{}, - typeAliases: map[string]*TypeAlias{ - "NodeArray": { - name: "NodeArray", - typ: &Array{&TypeString{"NodeInterface"}}, - }, - }, - } - - result := ToVisitorPlan(input) - - expected := &VisitorPlan{ - Switches: []*SwitchCase{{ - Type: &Ref{&TypeString{"NodeArray"}}, - Fields: []VisitorItem{&ArrayItem{ - StructType: &Ref{&TypeString{"NodeArray"}}, - ItemType: &TypeString{"NodeInterface"}, - }}, - }}, - } - - assert.Equal(t, expected.String(), result.String()) -} - -func TestStructWithStructField(t *testing.T) { - /* - type Node interface{} - type Struct struct { - Field *Struct - } - func (*Struct) iNode() {} - */ - - input := &SourceInformation{ - interestingTypes: map[string]Type{ - "*Struct": &Ref{&TypeString{"Struct"}}}, - structs: map[string]*StructDeclaration{ - "Struct": {name: "Struct", fields: []*Field{ - {name: "Field", typ: &Ref{&TypeString{"Struct"}}}, - }}, - }, - typeAliases: map[string]*TypeAlias{}, - } - - result := ToVisitorPlan(input) - - expected := &VisitorPlan{ - Switches: []*SwitchCase{{ - Type: &Ref{&TypeString{"Struct"}}, - Fields: []VisitorItem{&SingleFieldItem{ - StructType: &Ref{&TypeString{"Struct"}}, - FieldType: &Ref{&TypeString{"Struct"}}, - FieldName: "Field", - }}, - }}, - } - - assert.Equal(t, expected.String(), result.String()) -} - -func TestStructWithArrayOfNodes(t *testing.T) { - /* - type NodeInterface interface {} - type Struct struct { - Items []NodeInterface - } - - func (*Struct) iNode{} - */ - - input := &SourceInformation{ - interfaces: map[string]bool{ - "NodeInterface": true, - }, - interestingTypes: map[string]Type{ - "*Struct": &Ref{&TypeString{"Struct"}}}, - structs: map[string]*StructDeclaration{ - "Struct": {name: "Struct", fields: []*Field{ - {name: "Items", typ: &Array{&TypeString{"NodeInterface"}}}, - }}, - }, - typeAliases: map[string]*TypeAlias{}, - } - - result := ToVisitorPlan(input) - - expected := &VisitorPlan{ - Switches: []*SwitchCase{{ - Type: &Ref{&TypeString{"Struct"}}, - Fields: []VisitorItem{&ArrayFieldItem{ - StructType: &Ref{&TypeString{"Struct"}}, - ItemType: &TypeString{"NodeInterface"}, - FieldName: "Items", - }}, - }}, - } - - assert.Equal(t, expected.String(), result.String()) -} - -func TestStructWithArrayOfStrings(t *testing.T) { - /* - type NodeInterface interface {} - type Struct struct { - Items []string - } - - func (*Struct) iNode{} - */ - - input := &SourceInformation{ - interfaces: map[string]bool{ - "NodeInterface": true, - }, - interestingTypes: map[string]Type{ - "*Struct": &Ref{&TypeString{"Struct"}}}, - structs: map[string]*StructDeclaration{ - "Struct": {name: "Struct", fields: []*Field{ - {name: "Items", typ: &Array{&TypeString{"string"}}}, - }}, - }, - typeAliases: map[string]*TypeAlias{}, - } - - result := ToVisitorPlan(input) - - expected := &VisitorPlan{ - Switches: []*SwitchCase{{ - Type: &Ref{&TypeString{"Struct"}}, - Fields: []VisitorItem{}, - }}, - } - - assert.Equal(t, expected.String(), result.String()) -} - -func TestArrayOfStringsThatImplementSQLNode(t *testing.T) { - /* - type NodeInterface interface {} - type Struct []string - func (Struct) iNode{} - */ - - input := &SourceInformation{ - interfaces: map[string]bool{"NodeInterface": true}, - interestingTypes: map[string]Type{"Struct": &Ref{&TypeString{"Struct"}}}, - structs: map[string]*StructDeclaration{}, - typeAliases: map[string]*TypeAlias{ - "Struct": { - name: "Struct", - typ: &Array{&TypeString{"string"}}, - }, - }, - } - - result := ToVisitorPlan(input) - - expected := &VisitorPlan{ - Switches: []*SwitchCase{{ - Type: &Ref{&TypeString{"Struct"}}, - Fields: []VisitorItem{}, - }}, - } - - assert.Equal(t, expected.String(), result.String()) -} - -func TestSortingOfOutputs(t *testing.T) { - /* - type NodeInterface interface {} - type AStruct struct { - AField NodeInterface - BField NodeInterface - } - type BStruct struct { - CField NodeInterface - } - func (*AStruct) iNode{} - func (*BStruct) iNode{} - */ - - input := &SourceInformation{ - interfaces: map[string]bool{"NodeInterface": true}, - interestingTypes: map[string]Type{ - "AStruct": &Ref{&TypeString{"AStruct"}}, - "BStruct": &Ref{&TypeString{"BStruct"}}, - }, - structs: map[string]*StructDeclaration{ - "AStruct": {name: "AStruct", fields: []*Field{ - {name: "BField", typ: &TypeString{"NodeInterface"}}, - {name: "AField", typ: &TypeString{"NodeInterface"}}, - }}, - "BStruct": {name: "BStruct", fields: []*Field{ - {name: "CField", typ: &TypeString{"NodeInterface"}}, - }}, - }, - typeAliases: map[string]*TypeAlias{}, - } - - result := ToVisitorPlan(input) - - expected := &VisitorPlan{ - Switches: []*SwitchCase{ - {Type: &Ref{&TypeString{"AStruct"}}, - Fields: []VisitorItem{ - &SingleFieldItem{ - StructType: &Ref{&TypeString{"AStruct"}}, - FieldType: &TypeString{"NodeInterface"}, - FieldName: "AField", - }, - &SingleFieldItem{ - StructType: &Ref{&TypeString{"AStruct"}}, - FieldType: &TypeString{"NodeInterface"}, - FieldName: "BField", - }}}, - {Type: &Ref{&TypeString{"BStruct"}}, - Fields: []VisitorItem{ - &SingleFieldItem{ - StructType: &Ref{&TypeString{"BStruct"}}, - FieldType: &TypeString{"NodeInterface"}, - FieldName: "CField", - }}}}, - } - assert.Equal(t, expected.String(), result.String()) -} - -func TestAliasOfAlias(t *testing.T) { - /* - type NodeInterface interface { - iNode() - } - - type NodeArray []NodeInterface - type AliasOfAlias NodeArray - - func (NodeArray) iNode{} - func (AliasOfAlias) iNode{} - */ - - input := &SourceInformation{ - interfaces: map[string]bool{"NodeInterface": true}, - interestingTypes: map[string]Type{ - "NodeArray": &TypeString{"NodeArray"}, - "AliasOfAlias": &TypeString{"AliasOfAlias"}, - }, - structs: map[string]*StructDeclaration{}, - typeAliases: map[string]*TypeAlias{ - "NodeArray": { - name: "NodeArray", - typ: &Array{&TypeString{"NodeInterface"}}, - }, - "AliasOfAlias": { - name: "NodeArray", - typ: &TypeString{"NodeArray"}, - }, - }, - } - - result := ToVisitorPlan(input) - - expected := &VisitorPlan{ - Switches: []*SwitchCase{ - {Type: &TypeString{"AliasOfAlias"}, - Fields: []VisitorItem{&ArrayItem{ - StructType: &TypeString{"AliasOfAlias"}, - ItemType: &TypeString{"NodeInterface"}, - }}, - }, - {Type: &TypeString{"NodeArray"}, - Fields: []VisitorItem{&ArrayItem{ - StructType: &TypeString{"NodeArray"}, - ItemType: &TypeString{"NodeInterface"}, - }}, - }}, - } - assert.Equal(t, expected.String(), result.String()) -} diff --git a/go/vt/sqlparser/visitorgen/transformer.go b/go/vt/sqlparser/visitorgen/transformer.go deleted file mode 100644 index 98129be81b1..00000000000 --- a/go/vt/sqlparser/visitorgen/transformer.go +++ /dev/null @@ -1,95 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package visitorgen - -import "fmt" - -// Transform takes an input file and collects the information into an easier to consume format -func Transform(input *SourceFile) *SourceInformation { - interestingTypes := make(map[string]Type) - interfaces := make(map[string]bool) - structs := make(map[string]*StructDeclaration) - typeAliases := make(map[string]*TypeAlias) - - for _, l := range input.lines { - switch line := l.(type) { - case *FuncDeclaration: - interestingTypes[line.receiver.typ.toTypString()] = line.receiver.typ - case *StructDeclaration: - structs[line.name] = line - case *TypeAlias: - typeAliases[line.name] = line - case *InterfaceDeclaration: - interfaces[line.name] = true - } - } - - return &SourceInformation{ - interfaces: interfaces, - interestingTypes: interestingTypes, - structs: structs, - typeAliases: typeAliases, - } -} - -// SourceInformation contains the information from the ast.go file, but in a format that is easier to consume -type SourceInformation struct { - interestingTypes map[string]Type - interfaces map[string]bool - structs map[string]*StructDeclaration - typeAliases map[string]*TypeAlias -} - -func (v *SourceInformation) String() string { - var types string - for _, k := range v.interestingTypes { - types += k.toTypString() + "\n" - } - var structs string - for _, k := range v.structs { - structs += k.toSastString() + "\n" - } - var typeAliases string - for _, k := range v.typeAliases { - typeAliases += k.toSastString() + "\n" - } - - return fmt.Sprintf("Types to build visitor for:\n%s\nStructs with fields: \n%s\nTypeAliases with type: \n%s\n", types, structs, typeAliases) -} - -// getItemTypeOfArray will return nil if the given type is not pointing to a array type. -// If it is an array type, the type of it's items will be returned -func (v *SourceInformation) getItemTypeOfArray(typ Type) Type { - alias := v.typeAliases[typ.rawTypeName()] - if alias == nil { - return nil - } - arrTyp, isArray := alias.typ.(*Array) - if !isArray { - return v.getItemTypeOfArray(alias.typ) - } - return arrTyp.inner -} - -func (v *SourceInformation) isSQLNode(typ Type) bool { - _, isInteresting := v.interestingTypes[typ.toTypString()] - if isInteresting { - return true - } - _, isInterface := v.interfaces[typ.toTypString()] - return isInterface -} diff --git a/go/vt/sqlparser/visitorgen/transformer_test.go b/go/vt/sqlparser/visitorgen/transformer_test.go deleted file mode 100644 index 4a0849e9e9c..00000000000 --- a/go/vt/sqlparser/visitorgen/transformer_test.go +++ /dev/null @@ -1,110 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package visitorgen - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestSimplestAst(t *testing.T) { - /* - type NodeInterface interface { - iNode() - } - - type NodeStruct struct {} - - func (*NodeStruct) iNode{} - */ - input := &SourceFile{ - lines: []Sast{ - &InterfaceDeclaration{ - name: "NodeInterface", - block: "// an interface lives here"}, - &StructDeclaration{ - name: "NodeStruct", - fields: []*Field{}}, - &FuncDeclaration{ - receiver: &Field{ - name: "", - typ: &Ref{&TypeString{"NodeStruct"}}, - }, - name: "iNode", - block: "", - arguments: []*Field{}}, - }, - } - - expected := &SourceInformation{ - interestingTypes: map[string]Type{ - "*NodeStruct": &Ref{&TypeString{"NodeStruct"}}}, - structs: map[string]*StructDeclaration{ - "NodeStruct": { - name: "NodeStruct", - fields: []*Field{}}}, - } - - assert.Equal(t, expected.String(), Transform(input).String()) -} - -func TestAstWithArray(t *testing.T) { - /* - type NodeInterface interface { - iNode() - } - - func (*NodeArray) iNode{} - - type NodeArray []NodeInterface - */ - input := &SourceFile{ - lines: []Sast{ - &InterfaceDeclaration{ - name: "NodeInterface"}, - &TypeAlias{ - name: "NodeArray", - typ: &Array{&TypeString{"NodeInterface"}}, - }, - &FuncDeclaration{ - receiver: &Field{ - name: "", - typ: &Ref{&TypeString{"NodeArray"}}, - }, - name: "iNode", - block: "", - arguments: []*Field{}}, - }, - } - - expected := &SourceInformation{ - interestingTypes: map[string]Type{ - "*NodeArray": &Ref{&TypeString{"NodeArray"}}}, - structs: map[string]*StructDeclaration{}, - typeAliases: map[string]*TypeAlias{ - "NodeArray": { - name: "NodeArray", - typ: &Array{&TypeString{"NodeInterface"}}, - }, - }, - } - - result := Transform(input) - - assert.Equal(t, expected.String(), result.String()) -} diff --git a/go/vt/sqlparser/visitorgen/visitor_emitter.go b/go/vt/sqlparser/visitorgen/visitor_emitter.go deleted file mode 100644 index 889c05fe7f7..00000000000 --- a/go/vt/sqlparser/visitorgen/visitor_emitter.go +++ /dev/null @@ -1,76 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package visitorgen - -import ( - "fmt" - "strings" -) - -// EmitReplacementMethods is an anti-parser (a.k.a prettifier) - it takes a struct that is much like an AST, -// and produces a string from it. This method will produce the replacement methods that make it possible to -// replace objects in fields or in slices. -func EmitReplacementMethods(vd *VisitorPlan) string { - var sb builder - for _, s := range vd.Switches { - for _, k := range s.Fields { - sb.appendF(k.asReplMethod()) - sb.newLine() - } - } - - return sb.String() -} - -// EmitTypeSwitches is an anti-parser (a.k.a prettifier) - it takes a struct that is much like an AST, -// and produces a string from it. This method will produce the switch cases needed to cover the Vitess AST. -func EmitTypeSwitches(vd *VisitorPlan) string { - var sb builder - for _, s := range vd.Switches { - sb.newLine() - sb.appendF(" case %s:", s.Type.toTypString()) - for _, k := range s.Fields { - sb.appendF(k.asSwitchCase()) - } - } - - return sb.String() -} - -func (b *builder) String() string { - return strings.TrimSpace(b.sb.String()) -} - -type builder struct { - sb strings.Builder -} - -func (b *builder) appendF(format string, data ...interface{}) *builder { - _, err := b.sb.WriteString(fmt.Sprintf(format, data...)) - if err != nil { - panic(err) - } - b.newLine() - return b -} - -func (b *builder) newLine() { - _, err := b.sb.WriteString("\n") - if err != nil { - panic(err) - } -} diff --git a/go/vt/sqlparser/visitorgen/visitor_emitter_test.go b/go/vt/sqlparser/visitorgen/visitor_emitter_test.go deleted file mode 100644 index 94666daa743..00000000000 --- a/go/vt/sqlparser/visitorgen/visitor_emitter_test.go +++ /dev/null @@ -1,92 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package visitorgen - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestSingleItem(t *testing.T) { - sfi := SingleFieldItem{ - StructType: &Ref{&TypeString{"Struct"}}, - FieldType: &TypeString{"string"}, - FieldName: "Field", - } - - expectedReplacer := `func replaceStructField(newNode, parent SQLNode) { - parent.(*Struct).Field = newNode.(string) -}` - - expectedSwitch := ` a.apply(node, n.Field, replaceStructField)` - require.Equal(t, expectedReplacer, sfi.asReplMethod()) - require.Equal(t, expectedSwitch, sfi.asSwitchCase()) -} - -func TestArrayFieldItem(t *testing.T) { - sfi := ArrayFieldItem{ - StructType: &Ref{&TypeString{"Struct"}}, - ItemType: &TypeString{"string"}, - FieldName: "Field", - } - - expectedReplacer := `type replaceStructField int - -func (r *replaceStructField) replace(newNode, container SQLNode) { - container.(*Struct).Field[int(*r)] = newNode.(string) -} - -func (r *replaceStructField) inc() { - *r++ -}` - - expectedSwitch := ` replacerField := replaceStructField(0) - replacerFieldB := &replacerField - for _, item := range n.Field { - a.apply(node, item, replacerFieldB.replace) - replacerFieldB.inc() - }` - require.Equal(t, expectedReplacer, sfi.asReplMethod()) - require.Equal(t, expectedSwitch, sfi.asSwitchCase()) -} - -func TestArrayItem(t *testing.T) { - sfi := ArrayItem{ - StructType: &Ref{&TypeString{"Struct"}}, - ItemType: &TypeString{"string"}, - } - - expectedReplacer := `type replaceStructItems int - -func (r *replaceStructItems) replace(newNode, container SQLNode) { - container.(*Struct)[int(*r)] = newNode.(string) -} - -func (r *replaceStructItems) inc() { - *r++ -}` - - expectedSwitch := ` replacer := replaceStructItems(0) - replacerRef := &replacer - for _, item := range n { - a.apply(node, item, replacerRef.replace) - replacerRef.inc() - }` - require.Equal(t, expectedReplacer, sfi.asReplMethod()) - require.Equal(t, expectedSwitch, sfi.asSwitchCase()) -} diff --git a/go/vt/sqlparser/visitorgen/visitorgen.go b/go/vt/sqlparser/visitorgen/visitorgen.go deleted file mode 100644 index 284f8c4d9be..00000000000 --- a/go/vt/sqlparser/visitorgen/visitorgen.go +++ /dev/null @@ -1,33 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -//Package visitorgen is responsible for taking the ast.go of Vitess and -//and producing visitor infrastructure for it. -// -//This is accomplished in a few steps. -//Step 1: Walk the AST and collect the interesting information into a format that is -// easy to consume for the next step. The output format is a *SourceFile, that -// contains the needed information in a format that is pretty close to the golang ast, -// but simplified -//Step 2: A SourceFile is packaged into a SourceInformation. SourceInformation is still -// concerned with the input ast - it's just an even more distilled and easy to -// consume format for the last step. This step is performed by the code in transformer.go. -//Step 3: Using the SourceInformation, the struct_producer.go code produces the final data structure -// used, a VisitorPlan. This is focused on the output - it contains a list of all fields or -// arrays that need to be handled by the visitor produced. -//Step 4: The VisitorPlan is lastly turned into a string that is written as the output of -// this whole process. -package visitorgen diff --git a/go/vt/vtgate/planbuilder/ddl.go b/go/vt/vtgate/planbuilder/ddl.go index 9ef47d18990..2c3b3675e37 100644 --- a/go/vt/vtgate/planbuilder/ddl.go +++ b/go/vt/vtgate/planbuilder/ddl.go @@ -164,7 +164,7 @@ func buildAlterView(vschema ContextVSchema, ddl *sqlparser.AlterView) (key.Desti if routePlan.Opcode != engine.SelectUnsharded && routePlan.Opcode != engine.SelectEqualUnique && routePlan.Opcode != engine.SelectScatter { return nil, nil, vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, ViewComplex) } - sqlparser.Rewrite(ddl.Select, func(cursor *sqlparser.Cursor) bool { + _, err = sqlparser.Rewrite(ddl.Select, func(cursor *sqlparser.Cursor) bool { switch tableName := cursor.Node().(type) { case sqlparser.TableName: cursor.Replace(sqlparser.TableName{ @@ -173,6 +173,9 @@ func buildAlterView(vschema ContextVSchema, ddl *sqlparser.AlterView) (key.Desti } return true }, nil) + if err != nil { + return nil, nil, err + } return destination, keyspace, nil } @@ -200,7 +203,7 @@ func buildCreateView(vschema ContextVSchema, ddl *sqlparser.CreateView) (key.Des if routePlan.Opcode != engine.SelectUnsharded && routePlan.Opcode != engine.SelectEqualUnique && routePlan.Opcode != engine.SelectScatter { return nil, nil, vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, ViewComplex) } - sqlparser.Rewrite(ddl.Select, func(cursor *sqlparser.Cursor) bool { + _, err = sqlparser.Rewrite(ddl.Select, func(cursor *sqlparser.Cursor) bool { switch tableName := cursor.Node().(type) { case sqlparser.TableName: cursor.Replace(sqlparser.TableName{ @@ -209,6 +212,9 @@ func buildCreateView(vschema ContextVSchema, ddl *sqlparser.CreateView) (key.Des } return true }, nil) + if err != nil { + return nil, nil, err + } return destination, keyspace, nil } diff --git a/go/vt/vtgate/planbuilder/route_planning.go b/go/vt/vtgate/planbuilder/route_planning.go index f9b7b5a3dae..1d4b09d77d8 100644 --- a/go/vt/vtgate/planbuilder/route_planning.go +++ b/go/vt/vtgate/planbuilder/route_planning.go @@ -446,8 +446,8 @@ func pushPredicate2(exprs []sqlparser.Expr, tree joinTree, semTable *semantics.S } func breakPredicateInLHSandRHS(expr sqlparser.Expr, semTable *semantics.SemTable, lhs semantics.TableSet) (columns []*sqlparser.ColName, predicate sqlparser.Expr, err error) { - predicate = expr.Clone() - sqlparser.Rewrite(predicate, nil, func(cursor *sqlparser.Cursor) bool { + predicate = sqlparser.CloneExpr(expr) + _, err = sqlparser.Rewrite(predicate, nil, func(cursor *sqlparser.Cursor) bool { switch node := cursor.Node().(type) { case *sqlparser.ColName: deps := semTable.Dependencies(node) @@ -463,6 +463,9 @@ func breakPredicateInLHSandRHS(expr sqlparser.Expr, semTable *semantics.SemTable } return true }) + if err != nil { + return nil, nil, err + } return } diff --git a/go/vt/vtgate/planbuilder/select.go b/go/vt/vtgate/planbuilder/select.go index 3e6b1f8aa14..09b7c49ced4 100644 --- a/go/vt/vtgate/planbuilder/select.go +++ b/go/vt/vtgate/planbuilder/select.go @@ -531,7 +531,7 @@ func (pb *primitiveBuilder) pushSelectRoutes(selectExprs sqlparser.SelectExprs) } } resultColumns = append(resultColumns, rb.PushAnonymous(node)) - case sqlparser.Nextval: + case *sqlparser.Nextval: rb, ok := pb.plan.(*route) if !ok { // This code is unreachable because the parser doesn't allow joins for next val statements. diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index ebb23e181b0..72c954f7e79 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -172,8 +172,10 @@ func (a *analyzer) bindTable(alias *sqlparser.AliasedTableExpr, expr sqlparser.S } func (a *analyzer) analyze(statement sqlparser.Statement) error { - _ = sqlparser.Rewrite(statement, a.analyzeDown, a.analyzeUp) - + _, err := sqlparser.Rewrite(statement, a.analyzeDown, a.analyzeUp) + if err != nil { + return err + } return a.err } diff --git a/go/vt/vttablet/tabletserver/planbuilder/builder.go b/go/vt/vttablet/tabletserver/planbuilder/builder.go index 747ea2a2659..990ecdc3be8 100644 --- a/go/vt/vttablet/tabletserver/planbuilder/builder.go +++ b/go/vt/vttablet/tabletserver/planbuilder/builder.go @@ -46,7 +46,7 @@ func analyzeSelect(sel *sqlparser.Select, tables map[string]*schema.Table) (plan } // Check if it's a NEXT VALUE statement. - if nextVal, ok := sel.SelectExprs[0].(sqlparser.Nextval); ok { + if nextVal, ok := sel.SelectExprs[0].(*sqlparser.Nextval); ok { if plan.Table == nil || plan.Table.Type != schema.Sequence { return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%s is not a sequence", sqlparser.String(sel.From)) } @@ -134,7 +134,10 @@ func analyzeShow(show *sqlparser.Show, dbName string) (plan *Plan, err error) { // rewrite WHERE clause if it exists // `where Tables_in_Keyspace` => `where Tables_in_DbName` if showInternal.Filter != nil { - showTableRewrite(showInternal, dbName) + err := showTableRewrite(showInternal, dbName) + if err != nil { + return nil, err + } } } return &Plan{ @@ -153,10 +156,10 @@ func analyzeShow(show *sqlparser.Show, dbName string) (plan *Plan, err error) { return &Plan{PlanID: PlanOtherRead}, nil } -func showTableRewrite(show *sqlparser.ShowBasic, dbName string) { +func showTableRewrite(show *sqlparser.ShowBasic, dbName string) error { filter := show.Filter.Filter if filter != nil { - sqlparser.Rewrite(filter, func(cursor *sqlparser.Cursor) bool { + _, err := sqlparser.Rewrite(filter, func(cursor *sqlparser.Cursor) bool { switch n := cursor.Node().(type) { case *sqlparser.ColName: if n.Qualifier.IsEmpty() && strings.HasPrefix(n.Name.Lowered(), "tables_in_") { @@ -165,7 +168,11 @@ func showTableRewrite(show *sqlparser.ShowBasic, dbName string) { } return true }, nil) + if err != nil { + return err + } } + return nil } func analyzeSet(set *sqlparser.Set) (plan *Plan) { diff --git a/go/vt/wrangler/materializer.go b/go/vt/wrangler/materializer.go index 8016e453082..d0e57aa2765 100644 --- a/go/vt/wrangler/materializer.go +++ b/go/vt/wrangler/materializer.go @@ -973,7 +973,10 @@ func stripTableConstraints(ddl string) (string, error) { return true } - noConstraintAST := sqlparser.Rewrite(ast, stripConstraints, nil) + noConstraintAST, err := sqlparser.Rewrite(ast, stripConstraints, nil) + if err != nil { + return "", err + } newDDL := sqlparser.String(noConstraintAST) return newDDL, nil diff --git a/misc/git/hooks/visitorgen b/misc/git/hooks/visitorgen index 65c04d613db..3ac99cb0a07 100755 --- a/misc/git/hooks/visitorgen +++ b/misc/git/hooks/visitorgen @@ -15,4 +15,4 @@ # this script, which should run before committing code, makes sure that the visitor is re-generated when the ast changes -go run ./go/vt/sqlparser/visitorgen/main -compareOnly=true -input=go/vt/sqlparser/ast.go -output=go/vt/sqlparser/rewriter.go \ No newline at end of file +go run ./go/tools/asthelpergen -in ./go/vt/sqlparser -verify=true -iface vitess.io/vitess/go/vt/sqlparser.SQLNode -except "*ColName" \ No newline at end of file