diff --git a/internal/aliasfix/aliasfix.go b/internal/aliasfix/aliasfix.go index a259501362cd..7e511195cb9f 100644 --- a/internal/aliasfix/aliasfix.go +++ b/internal/aliasfix/aliasfix.go @@ -26,6 +26,7 @@ import ( "io" "io/fs" "os" + "path" "path/filepath" "strconv" "strings" @@ -46,12 +47,12 @@ func ProcessPath(path string) error { } if dir.IsDir() { err := filepath.WalkDir(path, func(path string, d fs.DirEntry, err error) error { - if err == nil && !d.IsDir() && strings.HasSuffix(d.Name(), ".go") { - err = processFile(path, nil) - } - if err != nil { + if err != nil || d.IsDir() { return err } + if strings.HasSuffix(d.Name(), ".go") { + return processFile(path, nil) + } return nil }) if err != nil { @@ -73,6 +74,7 @@ func processFile(name string, w io.Writer) (err error) { if err != nil { return err } + var modified bool for _, imp := range f.Imports { var importPath string @@ -81,8 +83,8 @@ func processFile(name string, w io.Writer) (err error) { return err } if pkg, ok := GenprotoPkgMigration[importPath]; ok && pkg.Status == StatusMigrated { - oldNamespace := importPath[strings.LastIndex(importPath, "/")+1:] - newNamespace := pkg.ImportPath[strings.LastIndex(pkg.ImportPath, "/")+1:] + oldNamespace := genprotoNamespace(importPath) + newNamespace := path.Base(pkg.ImportPath) if imp.Name == nil && oldNamespace != newNamespace { // use old namespace for fewer diffs imp.Name = ast.NewIdent(oldNamespace) @@ -99,26 +101,6 @@ func processFile(name string, w io.Writer) (err error) { return nil } - if w == nil { - backup := name + ".bak" - if err = os.Rename(name, backup); err != nil { - return err - } - defer func() { - if err != nil { - os.Rename(backup, name) - } else { - os.Remove(backup) - } - }() - var file *os.File - file, err = os.Create(name) - if err != nil { - return err - } - defer file.Close() - w = file - } var buf bytes.Buffer if err := format.Node(&buf, fset, f); err != nil { return err @@ -127,9 +109,32 @@ func processFile(name string, w io.Writer) (err error) { if err != nil { return err } - if _, err := w.Write(b); err != nil { + + if w != nil { + _, err := w.Write(b) return err } - return nil + backup := name + ".bak" + if err = os.Rename(name, backup); err != nil { + return err + } + defer func() { + if err != nil { + os.Rename(backup, name) + } else { + os.Remove(backup) + } + }() + + return os.WriteFile(name, b, 0644) +} + +func genprotoNamespace(importPath string) string { + suffix := path.Base(importPath) + // if it looks like a version, then use the second from last component. + if len(suffix) >= 2 && suffix[0] == 'v' && '0' <= suffix[1] && suffix[1] <= '1' { + return path.Base(path.Dir(importPath)) + } + return suffix } diff --git a/internal/aliasfix/aliasfix_test.go b/internal/aliasfix/aliasfix_test.go index acbe7bd7f164..0bd01de1c38a 100644 --- a/internal/aliasfix/aliasfix_test.go +++ b/internal/aliasfix/aliasfix_test.go @@ -74,10 +74,14 @@ func TestGolden(t *testing.T) { } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - GenprotoPkgMigration["example.com/old/foo"] = Pkg{ + GenprotoPkgMigration["example.com/old/foo/v1"] = Pkg{ ImportPath: "example.com/new/foopb", Status: tc.status, } + GenprotoPkgMigration["example.com/old/bar/v1/bar"] = Pkg{ + ImportPath: "example.com/new/barpb", + Status: tc.status, + } var w bytes.Buffer if updateGoldens { if err := processFile(filepath.Join("testdata", tc.fileName), nil); err != nil { diff --git a/internal/aliasfix/testdata/golden/input2 b/internal/aliasfix/testdata/golden/input2 index 82c745d251fe..e262dcdcadda 100644 --- a/internal/aliasfix/testdata/golden/input2 +++ b/internal/aliasfix/testdata/golden/input2 @@ -3,7 +3,9 @@ package golden import ( "net" + bar "example.com/new/barpb" foo "example.com/new/foopb" ) func Bar1(foo foo.Baz, addr net.Addr) {} +func Bar2(foo bar.Baz, addr net.Addr) {} diff --git a/internal/aliasfix/testdata/input1 b/internal/aliasfix/testdata/input1 index 90a52b51ef84..85868f7a76e0 100644 --- a/internal/aliasfix/testdata/input1 +++ b/internal/aliasfix/testdata/input1 @@ -1,5 +1,5 @@ package golden -import "example.com/old/foo" +import "example.com/old/foo/v1" func Bar1(foo foo.Baz) {} diff --git a/internal/aliasfix/testdata/input2 b/internal/aliasfix/testdata/input2 index 684233ffce47..bdba2310f05b 100644 --- a/internal/aliasfix/testdata/input2 +++ b/internal/aliasfix/testdata/input2 @@ -3,7 +3,9 @@ package golden import ( "net" - "example.com/old/foo" + "example.com/old/foo/v1" + "example.com/old/bar/v1/bar" ) func Bar1(foo foo.Baz, addr net.Addr) {} +func Bar2(foo bar.Baz, addr net.Addr) {} diff --git a/internal/aliasfix/testdata/input4 b/internal/aliasfix/testdata/input4 index 26734f81ba53..033c5cbcc1bf 100644 --- a/internal/aliasfix/testdata/input4 +++ b/internal/aliasfix/testdata/input4 @@ -1,5 +1,5 @@ package golden -import foopb "example.com/old/foo" +import foopb "example.com/old/foo/v1" func Bar4(baz foopb.Baz) {} diff --git a/internal/aliasfix/testdata/input5 b/internal/aliasfix/testdata/input5 index 95b0d9fad794..f0e915942052 100644 --- a/internal/aliasfix/testdata/input5 +++ b/internal/aliasfix/testdata/input5 @@ -3,7 +3,7 @@ package golden import ( "net" - blah "example.com/old/foo" + blah "example.com/old/foo/v1" ) func Bar2(baz blah.Baz, addr net.Addr) {} diff --git a/internal/aliasfix/testdata/input6 b/internal/aliasfix/testdata/input6 index 95b0d9fad794..f0e915942052 100644 --- a/internal/aliasfix/testdata/input6 +++ b/internal/aliasfix/testdata/input6 @@ -3,7 +3,7 @@ package golden import ( "net" - blah "example.com/old/foo" + blah "example.com/old/foo/v1" ) func Bar2(baz blah.Baz, addr net.Addr) {}