diff --git a/linter/ginkgo_linter.go b/linter/ginkgo_linter.go index 97a2e6d..574fdfa 100644 --- a/linter/ginkgo_linter.go +++ b/linter/ginkgo_linter.go @@ -238,7 +238,7 @@ func checkAssignments(pass *analysis.Pass, list []ast.Stmt) bool { case *ast.AssignStmt: for i, val := range st.Rhs { - if _, isFunc := val.(*ast.FuncLit); !isFunc { + if !is[*ast.FuncLit](val) { if id, isIdent := st.Lhs[i].(*ast.Ident); isIdent && id.Name != "_" { reportNoFix(pass, id.Pos(), useBeforeEachTemplate, id.Name) foundSomething = true @@ -268,7 +268,7 @@ func checkAssignments(pass *analysis.Pass, list []ast.Stmt) bool { func checkAssignmentsValues(pass *analysis.Pass, names []*ast.Ident, values []ast.Expr) bool { foundSomething := false for i, val := range values { - if _, isFunc := val.(*ast.FuncLit); !isFunc { + if !is[*ast.FuncLit](val) { reportNoFix(pass, names[i].Pos(), useBeforeEachTemplate, names[i].Name) foundSomething = true } @@ -894,7 +894,7 @@ func handleEqualComparison(pass *analysis.Pass, matcher *ast.CallExpr, first ast t := pass.TypesInfo.TypeOf(first) if gotypes.IsInterface(t) { handler.ReplaceFunction(matcher, ast.NewIdent(beIdenticalTo)) - } else if _, ok := t.(*gotypes.Pointer); ok { + } else if is[*gotypes.Pointer](t) { handler.ReplaceFunction(matcher, ast.NewIdent(beIdenticalTo)) } else { handler.ReplaceFunction(matcher, ast.NewIdent(equal)) @@ -1120,7 +1120,7 @@ func checkNilError(pass *analysis.Pass, assertionExp *ast.CallExpr, handler gome } var newFuncName string - if _, ok := actualArg.(*ast.CallExpr); ok { + if is[*ast.CallExpr](actualArg) { newFuncName = succeed } else { reverseAssertionFuncLogic(assertionExp) @@ -1462,7 +1462,7 @@ func handleNilComparisonErr(pass *analysis.Pass, exp *ast.CallExpr, nilable ast. newFuncName := beNil isItError := isExprError(pass, nilable) if isItError { - if _, ok := nilable.(*ast.CallExpr); ok { + if is[*ast.CallExpr](nilable) { newFuncName = succeed } else { reverseAssertionFuncLogic(exp) @@ -1577,7 +1577,7 @@ func isComparison(pass *analysis.Pass, actualArg ast.Expr) (ast.Expr, ast.Expr, case *ast.Ident: // check if const info, ok := pass.TypesInfo.Types[realFirst] if ok { - if _, ok := info.Type.(*gotypes.Basic); ok && info.Value != nil { + if is[*gotypes.Basic](info.Type) && info.Value != nil { replace = true } } @@ -1631,8 +1631,7 @@ func isExprError(pass *analysis.Pass, expr ast.Expr) bool { func isPointer(pass *analysis.Pass, expr ast.Expr) bool { t := pass.TypesInfo.TypeOf(expr) - _, ok := t.(*gotypes.Pointer) - return ok + return is[*gotypes.Pointer](t) } func isInterface(pass *analysis.Pass, expr ast.Expr) bool { @@ -1667,3 +1666,8 @@ func checkNoAssertion(pass *analysis.Pass, expr *ast.CallExpr, handler gomegahan reportNoFix(pass, expr.Pos(), missingAssertionMessage, funcName, allowedFunction) } } + +func is[T any](x any) bool { + _, matchType := x.(T) + return matchType +}