From 66bf157bf5bcd4cf5e82f17680e12c7fc873a2c1 Mon Sep 17 00:00:00 2001 From: Jay Conrod Date: Wed, 24 Feb 2021 17:55:24 -0500 Subject: [PATCH] modfile: defer fixing versions in retract directives VersionFixers require both a path and a version: if the version is non-canonical (like a branch name), they generally need the path to look up the proper version. This is fine for require, replace, and exclude directives, since the path is specified with each version. For retract directives, the path comes from the module directive, which may appear later in the file. Previously, we just used the empty string, but this breaks reasonable implementations. With this change, we leave retracted versions alone until the file has been completely parsed, then we apply the version fixer to each retract directive. We report an error if retract is used without a module directive. For golang/go#44494 Change-Id: I99b7b8b55941c1fde4ee56161acfe854bcaf948d Reviewed-on: https://go-review.googlesource.com/c/mod/+/296130 Trust: Jay Conrod Run-TryBot: Jay Conrod TryBot-Result: Go Bot Reviewed-by: Bryan C. Mills --- modfile/rule.go | 102 ++++++++++++++++++++++++++++++++++--------- modfile/rule_test.go | 88 +++++++++++++++++++++++++++++++++++++ 2 files changed, 169 insertions(+), 21 deletions(-) diff --git a/modfile/rule.go b/modfile/rule.go index 8fcf96b..f8c9384 100644 --- a/modfile/rule.go +++ b/modfile/rule.go @@ -125,6 +125,12 @@ func (f *File) AddComment(text string) { type VersionFixer func(path, version string) (string, error) +// errDontFix is returned by a VersionFixer to indicate the version should be +// left alone, even if it's not canonical. +var dontFixRetract VersionFixer = func(_, vers string) (string, error) { + return vers, nil +} + // Parse parses the data, reported in errors as being from file, // into a File struct. It applies fix, if non-nil, to canonicalize all module versions found. func Parse(file string, data []byte, fix VersionFixer) (*File, error) { @@ -142,7 +148,7 @@ func ParseLax(file string, data []byte, fix VersionFixer) (*File, error) { return parseToFile(file, data, fix, false) } -func parseToFile(file string, data []byte, fix VersionFixer, strict bool) (*File, error) { +func parseToFile(file string, data []byte, fix VersionFixer, strict bool) (parsed *File, err error) { fs, err := parse(file, data) if err != nil { return nil, err @@ -150,8 +156,18 @@ func parseToFile(file string, data []byte, fix VersionFixer, strict bool) (*File f := &File{ Syntax: fs, } - var errs ErrorList + + // fix versions in retract directives after the file is parsed. + // We need the module path to fix versions, and it might be at the end. + defer func() { + oldLen := len(errs) + f.fixRetract(fix, &errs) + if len(errs) > oldLen { + parsed, err = nil, errs + } + }() + for _, x := range fs.Stmt { switch x := x.(type) { case *Line: @@ -370,7 +386,7 @@ func (f *File) add(errs *ErrorList, block *LineBlock, line *Line, verb string, a case "retract": rationale := parseRetractRationale(block, line) - vi, err := parseVersionInterval(verb, &args, fix) + vi, err := parseVersionInterval(verb, "", &args, dontFixRetract) if err != nil { if strict { wrapError(err) @@ -397,6 +413,47 @@ func (f *File) add(errs *ErrorList, block *LineBlock, line *Line, verb string, a } } +// fixRetract applies fix to each retract directive in f, appending any errors +// to errs. +// +// Most versions are fixed as we parse the file, but for retract directives, +// the relevant module path is the one specified with the module directive, +// and that might appear at the end of the file (or not at all). +func (f *File) fixRetract(fix VersionFixer, errs *ErrorList) { + if fix == nil { + return + } + path := "" + if f.Module != nil { + path = f.Module.Mod.Path + } + var r *Retract + wrapError := func(err error) { + *errs = append(*errs, Error{ + Filename: f.Syntax.Name, + Pos: r.Syntax.Start, + Err: err, + }) + } + + for _, r = range f.Retract { + if path == "" { + wrapError(errors.New("no module directive found, so retract cannot be used")) + return // only print the first one of these + } + + args := r.Syntax.Token + if args[0] == "retract" { + args = args[1:] + } + vi, err := parseVersionInterval("retract", path, &args, fix) + if err != nil { + wrapError(err) + } + r.VersionInterval = vi + } +} + // isIndirect reports whether line has a "// indirect" comment, // meaning it is in go.mod only for its effect on indirect dependencies, // so that it can be dropped entirely once the effective version of the @@ -491,13 +548,13 @@ func AutoQuote(s string) string { return s } -func parseVersionInterval(verb string, args *[]string, fix VersionFixer) (VersionInterval, error) { +func parseVersionInterval(verb string, path string, args *[]string, fix VersionFixer) (VersionInterval, error) { toks := *args if len(toks) == 0 || toks[0] == "(" { return VersionInterval{}, fmt.Errorf("expected '[' or version") } if toks[0] != "[" { - v, err := parseVersion(verb, "", &toks[0], fix) + v, err := parseVersion(verb, path, &toks[0], fix) if err != nil { return VersionInterval{}, err } @@ -509,7 +566,7 @@ func parseVersionInterval(verb string, args *[]string, fix VersionFixer) (Versio if len(toks) == 0 { return VersionInterval{}, fmt.Errorf("expected version after '['") } - low, err := parseVersion(verb, "", &toks[0], fix) + low, err := parseVersion(verb, path, &toks[0], fix) if err != nil { return VersionInterval{}, err } @@ -523,7 +580,7 @@ func parseVersionInterval(verb string, args *[]string, fix VersionFixer) (Versio if len(toks) == 0 { return VersionInterval{}, fmt.Errorf("expected version after ','") } - high, err := parseVersion(verb, "", &toks[0], fix) + high, err := parseVersion(verb, path, &toks[0], fix) if err != nil { return VersionInterval{}, err } @@ -631,8 +688,7 @@ func parseVersion(verb string, path string, s *string, fix VersionFixer) (string } } if fix != nil { - var err error - t, err = fix(path, t) + fixed, err := fix(path, t) if err != nil { if err, ok := err.(*module.ModuleError); ok { return "", &Error{ @@ -643,19 +699,23 @@ func parseVersion(verb string, path string, s *string, fix VersionFixer) (string } return "", err } + t = fixed + } else { + cv := module.CanonicalVersion(t) + if cv == "" { + return "", &Error{ + Verb: verb, + ModPath: path, + Err: &module.InvalidVersionError{ + Version: t, + Err: errors.New("must be of the form v1.2.3"), + }, + } + } + t = cv } - if v := module.CanonicalVersion(t); v != "" { - *s = v - return *s, nil - } - return "", &Error{ - Verb: verb, - ModPath: path, - Err: &module.InvalidVersionError{ - Version: t, - Err: errors.New("must be of the form v1.2.3"), - }, - } + *s = t + return *s, nil } func modulePathMajor(path string) (string, error) { diff --git a/modfile/rule_test.go b/modfile/rule_test.go index 2381ee6..96ef036 100644 --- a/modfile/rule_test.go +++ b/modfile/rule_test.go @@ -7,6 +7,7 @@ package modfile import ( "bytes" "fmt" + "strings" "testing" "golang.org/x/mod/module" @@ -696,6 +697,59 @@ var addExcludeValidateVersionTests = []struct { }, } +var fixVersionTests = []struct { + desc, in, want, wantErr string + fix VersionFixer +}{ + { + desc: `require`, + in: `require example.com/m 1.0.0`, + want: `require example.com/m v1.0.0`, + fix: fixV, + }, + { + desc: `replace`, + in: `replace example.com/m 1.0.0 => example.com/m 1.1.0`, + want: `replace example.com/m v1.0.0 => example.com/m v1.1.0`, + fix: fixV, + }, + { + desc: `exclude`, + in: `exclude example.com/m 1.0.0`, + want: `exclude example.com/m v1.0.0`, + fix: fixV, + }, + { + desc: `retract_single`, + in: `module example.com/m + retract 1.0.0`, + want: `module example.com/m + retract v1.0.0`, + fix: fixV, + }, + { + desc: `retract_interval`, + in: `module example.com/m + retract [1.0.0, 1.1.0]`, + want: `module example.com/m + retract [v1.0.0, v1.1.0]`, + fix: fixV, + }, + { + desc: `retract_nomod`, + in: `retract 1.0.0`, + wantErr: `in:1: no module directive found, so retract cannot be used`, + fix: fixV, + }, +} + +func fixV(path, version string) (string, error) { + if path != "example.com/m" { + return "", fmt.Errorf("module path must be example.com/m") + } + return "v" + version, nil +} + func TestAddRequire(t *testing.T) { for _, tt := range addRequireTests { t.Run(tt.desc, func(t *testing.T) { @@ -877,3 +931,37 @@ func TestAddExcludeValidateVersion(t *testing.T) { }) } } + +func TestFixVersion(t *testing.T) { + for _, tt := range fixVersionTests { + t.Run(tt.desc, func(t *testing.T) { + inFile, err := Parse("in", []byte(tt.in), tt.fix) + if err != nil { + if tt.wantErr == "" { + t.Fatalf("unexpected error: %v", err) + } + if errMsg := err.Error(); !strings.Contains(errMsg, tt.wantErr) { + t.Fatalf("got error %q; want error containing %q", errMsg, tt.wantErr) + } + return + } + got, err := inFile.Format() + if err != nil { + t.Fatal(err) + } + + outFile, err := Parse("out", []byte(tt.want), nil) + if err != nil { + t.Fatal(err) + } + want, err := outFile.Format() + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(got, want) { + t.Fatalf("got:\n%s\nwant:\n%s", got, want) + } + }) + } +}