Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

assert: allow updating expected vars/consts inside functions #244

Merged
merged 2 commits into from
Sep 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ linters-settings:
lll:
line-length: 100
maintidx:
under: 40
under: 35

issues:
exclude-use-default: false
Expand Down
71 changes: 70 additions & 1 deletion assert/assert_ext_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package assert_test

import (
"go/ast"
"go/parser"
"go/token"
"io/ioutil"
Expand Down Expand Up @@ -56,6 +57,48 @@ expected value
expected := "const expectedTwo = `this is the new\nexpected value\n`"
assert.Assert(t, strings.Contains(string(raw), expected), "actual=%v", string(raw))
})

t.Run("var inside function is updated when -update=true", func(t *testing.T) {
patchUpdate(t)
t.Cleanup(func() {
resetVariable(t, "expectedInsideFunc", "")
})

actual := `this is the new
expected value
for var inside function
`
expectedInsideFunc := ``

assert.Equal(t, actual, expectedInsideFunc)

raw, err := ioutil.ReadFile(fileName(t))
assert.NilError(t, err)

expected := "expectedInsideFunc := `this is the new\nexpected value\nfor var inside function\n`"
assert.Assert(t, strings.Contains(string(raw), expected), "actual=%v", string(raw))
})

t.Run("const inside function is updated when -update=true", func(t *testing.T) {
patchUpdate(t)
t.Cleanup(func() {
resetVariable(t, "expectedConstInsideFunc", "")
})

actual := `this is the new
expected value
for const inside function
`
const expectedConstInsideFunc = ``

assert.Equal(t, actual, expectedConstInsideFunc)

raw, err := ioutil.ReadFile(fileName(t))
assert.NilError(t, err)

expected := "const expectedConstInsideFunc = `this is the new\nexpected value\nfor const inside function\n`"
assert.Assert(t, strings.Contains(string(raw), expected), "actual=%v", string(raw))
})
}

// expectedOne is updated by running the tests with -update
Expand Down Expand Up @@ -87,7 +130,33 @@ func resetVariable(t *testing.T, varName string, value string) {
astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments)
assert.NilError(t, err)

err = source.UpdateVariable(filename, fileset, astFile, varName, value)
var ident *ast.Ident
ast.Inspect(astFile, func(n ast.Node) bool {
switch v := n.(type) {
case *ast.AssignStmt:
if len(v.Lhs) == 1 {
if id, ok := v.Lhs[0].(*ast.Ident); ok {
if id.Name == varName {
ident = id
return false
}
}
}

case *ast.ValueSpec:
for _, id := range v.Names {
if id.Name == varName {
ident = id
return false
}
}
}

return true
})
assert.Assert(t, ident != nil, "failed to get ident for %s", varName)

err = source.UpdateVariable(filename, fileset, astFile, ident, value)
assert.NilError(t, err, "failed to reset file")
}

Expand Down
2 changes: 1 addition & 1 deletion assert/cmd/gty-migrate-from-testify/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func convertTestifySingleArgCall(tcall call) ast.Node {
}
}

func convertTestifyAssertion(tcall call, migration migration) ast.Node { //nolint:maintidx
func convertTestifyAssertion(tcall call, migration migration) ast.Node {
imports := migration.importNames

switch tcall.selExpr.Sel.Name {
Expand Down
51 changes: 32 additions & 19 deletions internal/source/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ func UpdateExpectedValue(stackIndex int, x, y interface{}) error {
return ErrNotFound
}

argIndex, varName := getVarNameForExpectedValueArg(expr)
if argIndex < 0 || varName == "" {
argIndex, ident := getIdentForExpectedValueArg(expr)
if argIndex < 0 || ident == nil {
debug("no arguments started with the word 'expected': %v",
debugFormatNode{Node: &ast.CallExpr{Args: expr}})
return ErrNotFound
Expand All @@ -71,7 +71,7 @@ func UpdateExpectedValue(stackIndex int, x, y interface{}) error {
debug("value must be type string, got %T", value)
return ErrNotFound
}
return UpdateVariable(filename, fileset, astFile, varName, strValue)
return UpdateVariable(filename, fileset, astFile, ident, strValue)
}

// UpdateVariable writes to filename the contents of astFile with the value of
Expand All @@ -80,10 +80,10 @@ func UpdateVariable(
filename string,
fileset *token.FileSet,
astFile *ast.File,
varName string,
ident *ast.Ident,
value string,
) error {
obj := astFile.Scope.Objects[varName]
obj := ident.Obj
if obj == nil {
return ErrNotFound
}
Expand All @@ -92,20 +92,33 @@ func UpdateVariable(
return ErrNotFound
}

spec, ok := obj.Decl.(*ast.ValueSpec)
if !ok {
switch decl := obj.Decl.(type) {
case *ast.ValueSpec:
if len(decl.Names) != 1 {
debug("more than one name in ast.ValueSpec")
return ErrNotFound
}

decl.Values[0] = &ast.BasicLit{
Kind: token.STRING,
Value: "`" + value + "`",
}

case *ast.AssignStmt:
if len(decl.Lhs) != 1 {
debug("more than one name in ast.AssignStmt")
return ErrNotFound
}

decl.Rhs[0] = &ast.BasicLit{
Kind: token.STRING,
Value: "`" + value + "`",
}

default:
debug("can only update *ast.ValueSpec, found %T", obj.Decl)
return ErrNotFound
}
if len(spec.Names) != 1 {
debug("more than one name in ast.ValueSpec")
return ErrNotFound
}

spec.Values[0] = &ast.BasicLit{
Kind: token.STRING,
Value: "`" + value + "`",
}

var buf bytes.Buffer
if err := format.Node(&buf, fileset, astFile); err != nil {
Expand All @@ -125,14 +138,14 @@ func UpdateVariable(
return nil
}

func getVarNameForExpectedValueArg(expr []ast.Expr) (int, string) {
func getIdentForExpectedValueArg(expr []ast.Expr) (int, *ast.Ident) {
for i := 1; i < 3; i++ {
switch e := expr[i].(type) {
case *ast.Ident:
if strings.HasPrefix(strings.ToLower(e.Name), "expected") {
return i, e.Name
return i, e
}
}
}
return -1, ""
return -1, nil
}