From 305fd150f4e1532a50205e8a5b298bb2dc699de5 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Wed, 31 Jul 2024 17:55:07 +0900 Subject: [PATCH] feat: Slice Out-of-Bound Check (#40) * basic bound check * handling safe context (i.e., table test looping) --- formatter/fmt.go | 8 ++ formatter/general.go | 1 + formatter/slice_bound.go | 43 ++++++ internal/engine.go | 1 + internal/lints/lint_test.go | 82 ++++++++++++ internal/lints/slice_bound.go | 237 ++++++++++++++++++++++++++++++++++ internal/rule_set.go | 10 ++ internal/types/types.go | 5 +- testdata/bound/bound0.gno | 8 ++ testdata/bound/bound1.gno | 6 + testdata/bound/bound2.gno | 24 ++++ 11 files changed, 423 insertions(+), 2 deletions(-) create mode 100644 formatter/slice_bound.go create mode 100644 internal/lints/slice_bound.go create mode 100644 testdata/bound/bound0.gno create mode 100644 testdata/bound/bound1.gno create mode 100644 testdata/bound/bound2.gno diff --git a/formatter/fmt.go b/formatter/fmt.go index 79015e0..39e563a 100644 --- a/formatter/fmt.go +++ b/formatter/fmt.go @@ -16,6 +16,7 @@ const ( SimplifySliceExpr = "simplify-slice-range" CycloComplexity = "high-cyclomatic-complexity" EmitFormat = "emit-format" + SliceBound = "slice-bounds-check" ) // IssueFormatter is the interface that wraps the Format method. 
@@ -51,6 +52,8 @@ func getFormatter(rule string) IssueFormatter { return &CyclomaticComplexityFormatter{} case EmitFormat: return &EmitFormatFormatter{} + case SliceBound: + return &SliceBoundsCheckFormatter{} default: return &GeneralIssueFormatter{} } @@ -70,6 +73,11 @@ func buildSuggestion(result *strings.Builder, issue tt.Issue, lineStyle, suggest result.WriteString(suggestionStyle.Sprintf("Suggestion:\n")) for i, line := range strings.Split(issue.Suggestion, "\n") { lineNum := fmt.Sprintf("%d", startLine+i) + + if maxLineNumWidth < len(lineNum) { + maxLineNumWidth = len(lineNum) + } + result.WriteString(lineStyle.Sprintf("%s%s | ", padding[:maxLineNumWidth-len(lineNum)], lineNum)) result.WriteString(fmt.Sprintf("%s\n", line)) } diff --git a/formatter/general.go b/formatter/general.go index f8cf121..b84c7aa 100644 --- a/formatter/general.go +++ b/formatter/general.go @@ -13,6 +13,7 @@ const tabWidth = 8 var ( errorStyle = color.New(color.FgRed, color.Bold) + warningStyle = color.New(color.FgHiYellow, color.Bold) ruleStyle = color.New(color.FgYellow, color.Bold) fileStyle = color.New(color.FgCyan, color.Bold) lineStyle = color.New(color.FgBlue, color.Bold) diff --git a/formatter/slice_bound.go b/formatter/slice_bound.go new file mode 100644 index 0000000..2d7aac3 --- /dev/null +++ b/formatter/slice_bound.go @@ -0,0 +1,43 @@ +package formatter + +import ( + "fmt" + "strings" + + "github.com/gnoswap-labs/lint/internal" + tt "github.com/gnoswap-labs/lint/internal/types" +) + +type SliceBoundsCheckFormatter struct{} + +func (f *SliceBoundsCheckFormatter) Format( + issue tt.Issue, + snippet *internal.SourceCode, +) string { + var result strings.Builder + + maxLineNumWidth := calculateMaxLineNumWidth(issue.End.Line) + padding := strings.Repeat(" ", maxLineNumWidth+1) + + startLine := issue.Start.Line + endLine := issue.End.Line + for i := startLine; i <= endLine; i++ { + line := expandTabs(snippet.Lines[i-1]) + result.WriteString(lineStyle.Sprintf("%s%d | ", 
padding[:maxLineNumWidth-len(fmt.Sprintf("%d", i))], i)) + result.WriteString(line + "\n") + } + + result.WriteString(lineStyle.Sprintf("%s| ", padding)) + result.WriteString(messageStyle.Sprintf("%s\n", strings.Repeat("~", calculateMaxLineLength(snippet.Lines, startLine, endLine)))) + result.WriteString(lineStyle.Sprintf("%s| ", padding)) + result.WriteString(messageStyle.Sprintf("%s\n\n", issue.Message)) + + result.WriteString(warningStyle.Sprint("warning: ")) + if issue.Category == "index-access" { + result.WriteString("Index access without bounds checking can lead to runtime panics.\n") + } else if issue.Category == "slice-expression" { + result.WriteString("Slice expressions without proper length checks may cause unexpected behavior.\n\n") + } + + return result.String() +} diff --git a/internal/engine.go b/internal/engine.go index 6c9272a..192a3c5 100644 --- a/internal/engine.go +++ b/internal/engine.go @@ -38,6 +38,7 @@ func (e *Engine) registerDefaultRules() { &SimplifySliceExprRule{}, &UnnecessaryConversionRule{}, &LoopAllocationRule{}, + &SliceBoundCheckRule{}, &EmitFormatRule{}, &DetectCycleRule{}, &GnoSpecificRule{}, diff --git a/internal/lints/lint_test.go b/internal/lints/lint_test.go index 1d3a1b6..b984422 100644 --- a/internal/lints/lint_test.go +++ b/internal/lints/lint_test.go @@ -4,6 +4,7 @@ import ( "fmt" "go/ast" "go/parser" + "go/token" "os" "path/filepath" "runtime" @@ -508,3 +509,84 @@ func TestFormatEmitCall(t *testing.T) { }) } } + +func TestDetectSliceBoundCheck(t *testing.T) { + tests := []struct { + name string + code string + expected int + }{ + { + name: "simple bound check", + code: ` +package main +func main() { + arr := []int{1, 2, 3} + if i < len(arr) { + _ = arr[i] + } +} + `, + expected: 0, + }, + { + name: "missing bound check", + code: ` +package main +func main() { + arr := []int{1, 2, 3} + _ = arr[i] +} + `, + expected: 1, + }, + { + name: "complex condition 2", + code: ` +package main + +type Item struct { + Name string + 
Value int
+}
+
+func main() {
+	sourceItems := []*Item{
+		{"item1", 10},
+		{"item2", 20},
+		{"item3", 30},
+	}
+
+	destinationItems := make([]*Item, 0, len(sourceItems))
+
+	i := 0
+	for _, item := range sourceItems {
+		destinationItems[i] = item
+		i++
+	}
+}
+`,
+			expected: 1,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			fset := token.NewFileSet()
+			node, err := parser.ParseFile(fset, "", tt.code, 0)
+			if err != nil {
+				t.Fatalf("Failed to parse code: %v", err)
+			}
+
+			issues, err := DetectSliceBoundCheck("test.go", node, fset)
+			for i, issue := range issues {
+				t.Logf("Issue %d: %v", i, issue)
+			}
+			assert.NoError(t, err)
+			assert.Equal(
+				t, tt.expected, len(issues),
+				"Number of detected slice bound check issues doesn't match expected",
+			)
+		})
+	}
+}
diff --git a/internal/lints/slice_bound.go b/internal/lints/slice_bound.go
new file mode 100644
index 0000000..93f6034
--- /dev/null
+++ b/internal/lints/slice_bound.go
@@ -0,0 +1,237 @@
+package lints
+
+import (
+	"fmt"
+	"go/ast"
+	"go/token"
+
+	tt "github.com/gnoswap-labs/lint/internal/types"
+)
+
+// DetectSliceBoundCheck walks the file and reports index and slice expressions
+// on plain identifiers that are not recognizably guarded.
+//
+// An expression is reported when either:
+//   - the assignment located by findAssignmentForIdent is a make call whose
+//     second argument is the literal 0, or
+//   - neither isWithinSafeContext nor isWithinBoundsCheck accepts it.
+//
+// Accesses with a constant index are never reported. The returned error is
+// always nil.
+func DetectSliceBoundCheck(filename string, node *ast.File, fset *token.FileSet) ([]tt.Issue, error) {
+	var issues []tt.Issue
+	ast.Inspect(node, func(n ast.Node) bool {
+		switch x := n.(type) {
+		case *ast.IndexExpr, *ast.SliceExpr:
+			ident, ok := getIdentForSliceOrArr(x)
+			if !ok {
+				return true
+			}
+
+			if assignStmt, ok := findAssignmentForIdent(node, ident); ok {
+				if callExpr, ok := assignStmt.Rhs[0].(*ast.CallExpr); ok {
+					if fun, ok := callExpr.Fun.(*ast.Ident); ok && fun.Name == "make" {
+						// A slice created by make with an initial length of 0 is risky to index or slice.
+						if len(callExpr.Args) >= 2 {
+							if lit, ok := callExpr.Args[1].(*ast.BasicLit); ok && lit.Value == "0" {
+								// createIssue yields a zero Issue for constant indexes; never append that placeholder.
+								if issue := createIssue(x, ident, filename, fset); issue.Rule != "" {
+									issues = append(issues, issue)
+								}
+								return true
+							}
+						}
+					}
+				}
+			}
+
+			if !isWithinSafeContext(node, x) && !isWithinBoundsCheck(node, x, ident) {
+				// Same filter as above: a zero Issue has no rule and no position and must not be reported.
+				if issue := createIssue(x, ident, filename, fset); issue.Rule != "" {
+					issues = append(issues, issue)
+				}
+			}
+		}
+		return true
+	})
+
+	return issues, nil
+}
+
+// createIssue builds the issue for an index or slice expression on ident.
+// For an index expression whose index is constant (see isConstantIndex) it
+// returns the zero Issue, which callers must treat as "nothing to report".
+func createIssue(node ast.Node, ident *ast.Ident, filename string, fset *token.FileSet) tt.Issue {
+	var category, message, suggestion, note string
+
+	switch x := node.(type) {
+	case *ast.IndexExpr:
+		if isConstantIndex(x.Index) {
+			return tt.Issue{}
+		}
+		category = "index-access"
+		message = "Potential out of bounds array/slice index access"
+		suggestion = fmt.Sprintf("if i < len(%s) { value := %s[i] }", ident.Name, ident.Name)
+		note = "Always check the length of the array/slice before accessing an index to prevent runtime panics."
+	case *ast.SliceExpr:
+		category = "slice-expression"
+		message = "Potential out of bounds slice expression"
+		suggestion = fmt.Sprintf("%s = append(%s, newElement)", ident.Name, ident.Name)
+		note = "Consider using append() for slices to automatically handle capacity and prevent out of bounds errors."
+	}
+
+	return tt.Issue{
+		Rule:       "slice-bounds-check",
+		Category:   category,
+		Filename:   filename,
+		Start:      fset.Position(node.Pos()),
+		End:        fset.Position(node.End()),
+		Message:    message,
+		Suggestion: suggestion,
+		Note:       note,
+	}
+}
+
+// getIdentForSliceOrArr returns the identifier that is being indexed or sliced.
+// It reports false when node is neither an index nor a slice expression, or
+// when the indexed/sliced operand is not a plain identifier.
+func getIdentForSliceOrArr(node ast.Node) (*ast.Ident, bool) {
+	switch n := node.(type) {
+	case *ast.IndexExpr:
+		if ident, ok := n.X.(*ast.Ident); ok {
+			return ident, true
+		}
+	case *ast.SliceExpr:
+		if ident, ok := n.X.(*ast.Ident); ok {
+			return ident, true
+		}
+	}
+	return nil, false
+}
+
+// isWithinBoundsCheck checks if the node is within an if statement that performs a bounds check.
+func isWithinBoundsCheck(file *ast.File, node ast.Node, ident *ast.Ident) bool {
+	// Accept a guard on any enclosing if statement, not only the outermost one,
+	// so that `if ok { if i < len(s) { _ = s[i] } }` is still recognized as checked.
+	var guarded bool
+
+	ast.Inspect(file, func(n ast.Node) bool {
+		if guarded {
+			return false
+		}
+		if x, ok := n.(*ast.IfStmt); ok && containsNode(x, node) && isIfCapCheck(x, ident) {
+			guarded = true
+			return false
+		}
+		return true
+	})
+
+	return guarded
+}
+
+// isIfCapCheck checks if the if statement condition is a capacity or length check.
+// Only a condition that is itself a binary expression is recognized.
+func isIfCapCheck(ifStmt *ast.IfStmt, ident *ast.Ident) bool {
+	if binaryExpr, ok := ifStmt.Cond.(*ast.BinaryExpr); ok {
+		return binExprHasLenCapCall(binaryExpr, ident)
+	}
+	return false
+}
+
+// binExprHasLenCapCall checks if a binary expression contains a length or capacity call.
+// Both operands are examined independently: an unrelated call on the left-hand
+// side (e.g. `f() < len(s)`) must not hide a len/cap call on the right-hand side.
+func binExprHasLenCapCall(bin *ast.BinaryExpr, ident *ast.Ident) bool {
+	for _, operand := range []ast.Expr{bin.X, bin.Y} {
+		if call, ok := operand.(*ast.CallExpr); ok && isCapOrLenCallWithIdent(call, ident) {
+			return true
+		}
+	}
+	return false
+}
+
+// isCapOrLenCallWithIdent checks if a call expression is a len or cap call with the identifier.
+// The single argument is matched against ident by name.
+func isCapOrLenCallWithIdent(call *ast.CallExpr, ident *ast.Ident) bool {
+	if fun, ok := call.Fun.(*ast.Ident); ok {
+		if (fun.Name == "len" || fun.Name == "cap") && len(call.Args) == 1 {
+			arg, ok := call.Args[0].(*ast.Ident)
+			return ok && arg.Name == ident.Name
+		}
+	}
+	return false
+}
+
+// isWithinSafeContext checks if the given node is within a "safe" context.
+// This function considers the following cases as "safe":
+// 1. Inside a for loop with appropriate bounds checking
+// 2. Inside a range loop where the loop variable is used as an index
+//
+// Behavior:
+// 1. Traverses the entire AST to find the parent statements containing the given node.
+// 2. If a range statement is found, it verifies if the node correctly uses the range variable.
+// 3. If a for statement is found, it checks for appropriate length checks.
+// 4. 
If a safe context is found, it immediately stops traversal and returns true.
+//
+// Notes:
+// - This function may not cover all possible safe usage scenarios.
+// - Complex nested structures or indirect access through function calls may be difficult to analyze accurately.
+func isWithinSafeContext(file *ast.File, node ast.Node) bool {
+	var safeContext bool
+	ast.Inspect(file, func(n ast.Node) bool {
+		if n == node {
+			return false
+		}
+		switch x := n.(type) {
+		case *ast.RangeStmt:
+			if containsNode(x.Body, node) {
+				// Inside a range body the access is only safe when it indexes the
+				// ranged collection with the loop's own key, e.g. `for i := range s { s[i] }`.
+				if indexExpr, ok := node.(*ast.IndexExpr); ok {
+					safeContext = indexesRangedOperand(x, indexExpr)
+				}
+				return false
+			}
+		case *ast.ForStmt:
+			if isForWithLenCheck(x) && containsNode(x.Body, node) {
+				safeContext = true
+				return false
+			}
+		}
+		return true
+	})
+	return safeContext
+}
+
+// indexesRangedOperand reports whether indexExpr has the form s[k], where s is
+// the identifier ranged over by rangeStmt and k is its range key. Operands are
+// matched by identifier name. Comma-ok assertions are used throughout, so a
+// missing key (`for range s`) or a non-identifier operand yields false instead
+// of a panic.
+func indexesRangedOperand(rangeStmt *ast.RangeStmt, indexExpr *ast.IndexExpr) bool {
+	key, ok := rangeStmt.Key.(*ast.Ident)
+	if !ok {
+		return false
+	}
+	ranged, ok := rangeStmt.X.(*ast.Ident)
+	if !ok {
+		return false
+	}
+	target, ok := indexExpr.X.(*ast.Ident)
+	if !ok {
+		return false
+	}
+	index, ok := indexExpr.Index.(*ast.Ident)
+	return ok && target.Name == ranged.Name && index.Name == key.Name
+}
+
+// isForWithLenCheck reports whether the loop condition is a binary expression
+// accepted by isBinaryExprLenCheck.
+func isForWithLenCheck(forStmt *ast.ForStmt) bool {
+	if cond, ok := forStmt.Cond.(*ast.BinaryExpr); ok {
+		if isBinaryExprLenCheck(cond) {
+			return true
+		}
+	}
+	return false
+}
+
+// isConstantIndex reports whether expr is an integer literal or an identifier
+// that the parser resolved to a constant declaration.
+func isConstantIndex(expr ast.Expr) bool {
+	switch x := expr.(type) {
+	case *ast.BasicLit:
+		return x.Kind == token.INT
+	case *ast.Ident:
+		return x.Obj != nil && x.Obj.Kind == ast.Con
+	}
+	return false
+}
+
+// isBinaryExprLenCheck reports whether expr has the form `a < len(...)` or
+// `a <= len(...)`, with the len call as the right-hand operand.
+func isBinaryExprLenCheck(expr *ast.BinaryExpr) bool {
+	// NOTE(review): "<=" admits a == len(x), which is out of range for indexing
+	// (though valid as a slice bound) - confirm that accepting it is intended.
+	if expr.Op == token.LSS || expr.Op == token.LEQ {
+		if call, ok := expr.Y.(*ast.CallExpr); ok {
+			if ident, ok := call.Fun.(*ast.Ident); ok {
+				return ident.Name == "len"
+			}
+		}
+	}
+	return false
+}
+
+func findAssignmentForIdent(file *ast.File, ident *ast.Ident) (*ast.AssignStmt, bool) {
+	var assignStmt *ast.AssignStmt
+	var found bool
+
+	ast.Inspect(file, func(n ast.Node) bool {
+		if assign, ok := n.(*ast.AssignStmt); ok {
+			for _, lhs := range assign.Lhs {
+				if id, ok := lhs.(*ast.Ident); ok && 
id.Name == ident.Name { + assignStmt = assign + found = true + return false + } + } + } + return true + }) + + return assignStmt, found +} \ No newline at end of file diff --git a/internal/rule_set.go b/internal/rule_set.go index c197c6d..3625ccd 100644 --- a/internal/rule_set.go +++ b/internal/rule_set.go @@ -91,6 +91,16 @@ func (r *EmitFormatRule) Name() string { return "emit-format" } +type SliceBoundCheckRule struct{} + +func (r *SliceBoundCheckRule) Check(filename string, node *ast.File, fset *token.FileSet) ([]tt.Issue, error) { + return lints.DetectSliceBoundCheck(filename, node, fset) +} + +func (r *SliceBoundCheckRule) Name() string { + return "slice-bounds-check" +} + // ----------------------------------------------------------------------------- type CyclomaticComplexityRule struct { diff --git a/internal/types/types.go b/internal/types/types.go index befbd55..909b503 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -5,10 +5,11 @@ import "go/token" // Issue represents a lint issue found in the code base. 
type Issue struct { Rule string + Category string Filename string - Start token.Position - End token.Position Message string Suggestion string Note string + Start token.Position + End token.Position } diff --git a/testdata/bound/bound0.gno b/testdata/bound/bound0.gno new file mode 100644 index 0000000..bd5f13e --- /dev/null +++ b/testdata/bound/bound0.gno @@ -0,0 +1,8 @@ +package main + +func main() { + arr := []int{1, 2, 3} + if i < len(arr) { + _ = arr[i] + } +} diff --git a/testdata/bound/bound1.gno b/testdata/bound/bound1.gno new file mode 100644 index 0000000..e705f70 --- /dev/null +++ b/testdata/bound/bound1.gno @@ -0,0 +1,6 @@ +package main + +func main() { + arr := []int{1, 2, 3} + _ = arr[i] +} diff --git a/testdata/bound/bound2.gno b/testdata/bound/bound2.gno new file mode 100644 index 0000000..d715e1b --- /dev/null +++ b/testdata/bound/bound2.gno @@ -0,0 +1,24 @@ +// ref: https://github.com/golangci/golangci-lint/discussions/4150 + +package main + +type Item struct { + Name string + Value int +} + +func main() { + sourceItems := []*Item{ + {"item1", 10}, + {"item2", 20}, + {"item3", 30}, + } + + destinationItems := make([]*Item, 0, len(sourceItems)) + + i := 0 + for _, item := range sourceItems { + destinationItems[i] = item + i++ + } +}