From f616278a8ba07c5a5bb6b43a541f5f081039d3e6 Mon Sep 17 00:00:00 2001 From: Azeem Shaikh Date: Tue, 22 Feb 2022 10:50:01 -0800 Subject: [PATCH] Generalize CheckIfFileExists fn (#1668) Co-authored-by: Azeem Shaikh --- checks/errors.go | 10 +- checks/fileparser/listing.go | 20 ++-- checks/fileparser/listing_test.go | 143 ++++++++++++++++----------- checks/license.go | 30 +++--- checks/raw/dependency_update_tool.go | 13 ++- checks/raw/errors.go | 3 +- checks/raw/security_policy.go | 93 +++++++---------- checks/raw/security_policy_test.go | 39 +++----- 8 files changed, 176 insertions(+), 175 deletions(-) diff --git a/checks/errors.go b/checks/errors.go index 61abf511618..e8cd0d624d0 100644 --- a/checks/errors.go +++ b/checks/errors.go @@ -18,16 +18,12 @@ import ( "errors" ) -//nolint var ( errInternalInvalidDockerFile = errors.New("invalid Dockerfile") - errInternalInvalidYamlFile = errors.New("invalid yaml file") - errInternalFilenameMatch = errors.New("filename match error") - errInternalEmptyFile = errors.New("empty file") errInvalidGitHubWorkflow = errors.New("invalid GitHub workflow") - errInternalNoReviews = errors.New("no reviews found") - errInternalNoCommits = errors.New("no commits found") - errInternalInvalidPermissions = errors.New("invalid permissions") errInternalNameCannotBeEmpty = errors.New("name cannot be empty") errInternalCheckFuncCannotBeNil = errors.New("checkFunc cannot be nil") + // TODO(#1245): these should be moved under `raw` package after migration. + errInvalidArgType = errors.New("invalid arg type") + errInvalidArgLength = errors.New("invalid arg length") ) diff --git a/checks/fileparser/listing.go b/checks/fileparser/listing.go index 0cf7bbf4cb6..e141ea4db23 100644 --- a/checks/fileparser/listing.go +++ b/checks/fileparser/listing.go @@ -168,22 +168,19 @@ func CheckFilesContentV6(shellPathFnPattern string, return nil } -// FileCbV6 is the callback. -// The bool returned indicates whether the FileCbData -// should continue iterating over files or not. -type FileCbV6 func(path string, data FileCbData) (bool, error) +// DoWhileTrueOnFilename takes a filename and optional variadic args and returns +// true if the next filename should continue to be processed. +type DoWhileTrueOnFilename func(path string, args ...interface{}) (bool, error) -// CheckIfFileExistsV6 downloads the tar of the repository and calls the onFile() to check -// for the occurrence. -func CheckIfFileExistsV6(repoClient clients.RepoClient, - onFile FileCbV6, data FileCbData) error { +// OnAllFilesDo iterates through all files returned by `repoClient` and +// calls `onFile` fn on them until `onFile` returns error or a false value. +func OnAllFilesDo(repoClient clients.RepoClient, onFile DoWhileTrueOnFilename, args ...interface{}) error { matchedFiles, err := repoClient.ListFiles(func(string) (bool, error) { return true, nil }) if err != nil { - // nolint: wrapcheck - return err + return fmt.Errorf("error during ListFiles: %w", err) } for _, filename := range matchedFiles { - continueIter, err := onFile(filename, data) + continueIter, err := onFile(filename, args...) if err != nil { return err } @@ -192,7 +189,6 @@ func CheckIfFileExistsV6(repoClient clients.RepoClient, break } } - return nil } diff --git a/checks/fileparser/listing_test.go b/checks/fileparser/listing_test.go index 5496763da93..28867103d55 100644 --- a/checks/fileparser/listing_test.go +++ b/checks/fileparser/listing_test.go @@ -24,6 +24,12 @@ import ( mockrepo "github.com/ossf/scorecard/v4/clients/mockclients" ) +var ( + errInvalidArgType = errors.New("invalid arg type") + errInvalidArgLength = errors.New("invalid arg length") + errTest = errors.New("test") +) + func TestIsTemplateFile(t *testing.T) { t.Parallel() @@ -572,57 +578,90 @@ func TestCheckFilesContent(t *testing.T) { } } -// TestCheckFilesContentV6 tests the CheckFilesContentV6 function. -func TestCheckIfFileExistsV6(t *testing.T) { +// TestOnAllFilesDo tests the OnAllFilesDo function. +// nolint:gocognit +func TestOnAllFilesDo(t *testing.T) { t.Parallel() - //nolint - type args struct { - cbReturn bool - cbwantErr bool - listFilesReturnError error + + type testArgsFn func(args ...interface{}) bool + validateCountIs := func(count int) testArgsFn { + return func(args ...interface{}) bool { + if len(args) == 0 { + return false + } + val, ok := args[0].(*int) + if !ok { + return false + } + return val != nil && *val == count + } } - //nolint + + incrementCount := func(path string, args ...interface{}) (bool, error) { + if len(args) < 1 { + return false, errInvalidArgLength + } + val, ok := args[0].(*int) + if !ok || val == nil { + return false, errInvalidArgType + } + (*val)++ + if len(args) > 1 { + maxVal, ok := args[1].(int) + if !ok { + return false, errInvalidArgType + } + if *val >= maxVal { + return false, nil + } + } + return true, nil + } + alwaysFail := func(path string, args ...interface{}) (bool, error) { + return false, errTest + } + // nolint tests := []struct { - name string - args args - wantErr bool + name string + onFile DoWhileTrueOnFilename + onFileArgs []interface{} + listFiles []string + errListFiles error + err error + testArgs testArgsFn }{ { - name: "cb true and no error", - args: args{ - cbReturn: true, - cbwantErr: false, - listFilesReturnError: nil, - }, - wantErr: false, + name: "error during ListFiles", + errListFiles: errTest, + err: errTest, + onFile: alwaysFail, }, { - name: "cb false and no error", - args: args{ - cbReturn: false, - cbwantErr: false, - listFilesReturnError: nil, - }, - wantErr: false, + name: "empty ListFiles", + listFiles: []string{}, + onFile: incrementCount, + onFileArgs: []interface{}{new(int)}, + testArgs: validateCountIs(0), }, { - name: "cb wantErr and error", - args: args{ - cbReturn: true, - cbwantErr: true, - listFilesReturnError: nil, - }, - wantErr: true, + name: "onFile true and no error", + listFiles: []string{"foo", "bar"}, + onFile: incrementCount, + onFileArgs: []interface{}{new(int)}, + testArgs: validateCountIs(2), }, { - name: "listFilesReturnError and error", - args: args{ - cbReturn: true, - cbwantErr: true, - //nolint - listFilesReturnError: errors.New("test error"), - }, - wantErr: true, + name: "onFile false and no error", + listFiles: []string{"foo", "bar"}, + onFile: incrementCount, + onFileArgs: []interface{}{new(int), 1 /*maxVal*/}, + testArgs: validateCountIs(1), + }, + { + name: "onFile has error", + listFiles: []string{"foo", "bar"}, + onFile: alwaysFail, + err: errTest, }, } for _, tt := range tests { @@ -630,23 +669,17 @@ func TestCheckIfFileExistsV6(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - x := func(path string, data FileCbData) (bool, error) { - if tt.args.cbwantErr { - //nolint - return false, errors.New("test error") - } - return tt.args.cbReturn, nil - } - ctrl := gomock.NewController(t) - mockRepo := mockrepo.NewMockRepoClient(ctrl) - mockRepo.EXPECT().ListFiles(gomock.Any()).Return([]string{"foo"}, nil).AnyTimes() + mockRepoClient := mockrepo.NewMockRepoClient(ctrl) + mockRepoClient.EXPECT().ListFiles(gomock.Any()). + Return(tt.listFiles, tt.errListFiles).AnyTimes() - err := CheckIfFileExistsV6(mockRepo, x, x) - - if (err != nil) != tt.wantErr { - t.Errorf("CheckIfFileExistsV6() error = %v, wantErr %v for %v", err, tt.wantErr, tt.name) - return + err := OnAllFilesDo(mockRepoClient, tt.onFile, tt.onFileArgs...) + if !errors.Is(err, tt.err) { + t.Errorf("OnAllFilesDo() expected error = %v, got %v", tt.err, err) + } + if tt.testArgs != nil && !tt.testArgs(tt.onFileArgs...) { + t.Error("testArgs validation failed") } }) } diff --git a/checks/license.go b/checks/license.go index ff5ee715f84..0bd19ebde3d 100644 --- a/checks/license.go +++ b/checks/license.go @@ -15,6 +15,7 @@ package checks import ( + "fmt" "regexp" "strings" @@ -105,17 +106,7 @@ func testLicenseCheck(name string) bool { func LicenseCheck(c *checker.CheckRequest) checker.CheckResult { var s string - onFile := func(name string, data fileparser.FileCbData) (bool, error) { - if checkLicense(name) { - if strData, ok := data.(*string); ok && strData != nil { - *strData = name - } - return false, nil - } - return true, nil - } - - err := fileparser.CheckIfFileExistsV6(c.RepoClient, onFile, &s) + err := fileparser.OnAllFilesDo(c.RepoClient, isLicenseFile, &s) if err != nil { return checker.CreateRuntimeErrorResult(CheckLicense, err) } @@ -130,6 +121,23 @@ func LicenseCheck(c *checker.CheckRequest) checker.CheckResult { return checker.CreateMinScoreResult(CheckLicense, "license file not detected") } +var isLicenseFile fileparser.DoWhileTrueOnFilename = func(name string, args ...interface{}) (bool, error) { + if len(args) != 1 { + return false, fmt.Errorf("isLicenseFile requires exactly one argument: %w", errInvalidArgLength) + } + s, ok := args[0].(*string) + if !ok { + return false, fmt.Errorf("isLicenseFile requires argument of type: *string: %w", errInvalidArgType) + } + if checkLicense(name) { + if s != nil { + *s = name + } + return false, nil + } + return true, nil +} + // CheckLicense to check whether the name parameter fulfill license file criteria. func checkLicense(name string) bool { for _, check := range regexChecks { diff --git a/checks/raw/dependency_update_tool.go b/checks/raw/dependency_update_tool.go index ba653d32717..a4bb2168f1d 100644 --- a/checks/raw/dependency_update_tool.go +++ b/checks/raw/dependency_update_tool.go @@ -26,7 +26,7 @@ import ( // DependencyUpdateTool is the exported name for Depdendency-Update-Tool. func DependencyUpdateTool(c clients.RepoClient) (checker.DependencyUpdateToolData, error) { var tools []checker.Tool - err := fileparser.CheckIfFileExistsV6(c, checkDependencyFileExists, &tools) + err := fileparser.OnAllFilesDo(c, checkDependencyFileExists, &tools) if err != nil { return checker.DependencyUpdateToolData{}, fmt.Errorf("%w", err) } @@ -35,11 +35,14 @@ func DependencyUpdateTool(c clients.RepoClient) (checker.DependencyUpdateToolDat return checker.DependencyUpdateToolData{Tools: tools}, nil } -func checkDependencyFileExists(name string, data fileparser.FileCbData) (bool, error) { - ptools, ok := data.(*[]checker.Tool) +var checkDependencyFileExists fileparser.DoWhileTrueOnFilename = func(name string, args ...interface{}) (bool, error) { + if len(args) != 1 { + return false, fmt.Errorf("checkDependencyFileExists requires exactly one argument: %w", errInvalidArgLength) + } + ptools, ok := args[0].(*[]checker.Tool) if !ok { - // This never happens. - panic("invalid type") + return false, fmt.Errorf( + "checkDependencyFileExists requires an argument of type: *[]checker.Tool: %w", errInvalidArgType) } switch strings.ToLower(name) { diff --git a/checks/raw/errors.go b/checks/raw/errors.go index 1da1236c0de..6fec85936d0 100644 --- a/checks/raw/errors.go +++ b/checks/raw/errors.go @@ -18,8 +18,9 @@ import ( "errors" ) -//nolint var ( errInternalCommitishNil = errors.New("commitish is nil") errInternalBranchNotFound = errors.New("branch not found") + errInvalidArgType = errors.New("invalid arg type") + errInvalidArgLength = errors.New("invalid arg length") ) diff --git a/checks/raw/security_policy.go b/checks/raw/security_policy.go index c956aef853e..4c0b73a1d98 100644 --- a/checks/raw/security_policy.go +++ b/checks/raw/security_policy.go @@ -16,6 +16,7 @@ package raw import ( "errors" + "fmt" "strings" "github.com/ossf/scorecard/v4/checker" @@ -30,39 +31,8 @@ import ( func SecurityPolicy(c *checker.CheckRequest) (checker.SecurityPolicyData, error) { // TODO: not supported for local clients. - // Check repository for repository-specific policy. - // https://docs.github.com/en/github/building-a-strong-community/creating-a-default-community-health-file. - onFile := func(name string, data fileparser.FileCbData) (bool, error) { - pfiles, ok := data.(*[]checker.File) - if !ok { - // This never happens. - panic("invalid type") - } - if strings.EqualFold(name, "security.md") || - strings.EqualFold(name, ".github/security.md") || - strings.EqualFold(name, "docs/security.md") || - strings.EqualFold(name, "security.adoc") || - strings.EqualFold(name, ".github/security.adoc") || - strings.EqualFold(name, "docs/security.adoc") { - *pfiles = append(*pfiles, checker.File{ - Path: name, - Type: checker.FileTypeSource, - Offset: checker.OffsetDefault, - }) - return false, nil - } else if isSecurityRstFound(name) { - *pfiles = append(*pfiles, checker.File{ - Path: name, - Type: checker.FileTypeSource, - Offset: checker.OffsetDefault, - }) - return false, nil - } - return true, nil - } - files := make([]checker.File, 0) - err := fileparser.CheckIfFileExistsV6(c.RepoClient, onFile, &files) + err := fileparser.OnAllFilesDo(c.RepoClient, isSecurityPolicyFile, &files) if err != nil { return checker.SecurityPolicyData{}, err } @@ -79,28 +49,8 @@ func SecurityPolicy(c *checker.CheckRequest) (checker.SecurityPolicyData, error) switch { case err == nil: defer dotGitHubClient.Close() - onFile = func(name string, data fileparser.FileCbData) (bool, error) { - pfiles, ok := data.(*[]checker.File) - if !ok { - // This never happens. - panic("invalid type") - } - if strings.EqualFold(name, "security.md") || - strings.EqualFold(name, ".github/security.md") || - strings.EqualFold(name, "docs/security.md") || - strings.EqualFold(name, "security.adoc") || - strings.EqualFold(name, ".github/security.adoc") || - strings.EqualFold(name, "docs/security.adoc") { - *pfiles = append(*pfiles, checker.File{ - Path: name, - Type: checker.FileTypeURL, - Offset: checker.OffsetDefault, - }) - return false, nil - } - return true, nil - } - err = fileparser.CheckIfFileExistsV6(dotGitHubClient, onFile, &files) + + err = fileparser.OnAllFilesDo(dotGitHubClient, isSecurityPolicyFile, &files) if err != nil { return checker.SecurityPolicyData{}, err } @@ -115,11 +65,34 @@ func SecurityPolicy(c *checker.CheckRequest) (checker.SecurityPolicyData, error) return checker.SecurityPolicyData{Files: files}, nil } -func isSecurityRstFound(name string) bool { - if strings.EqualFold(name, "doc/security.rst") { - return true - } else if strings.EqualFold(name, "docs/security.rst") { - return true +// Check repository for repository-specific policy. +// https://docs.github.com/en/github/building-a-strong-community/creating-a-default-community-health-file. +var isSecurityPolicyFile fileparser.DoWhileTrueOnFilename = func(name string, args ...interface{}) (bool, error) { + if len(args) != 1 { + return false, fmt.Errorf("isSecurityPolicyFile requires exactly one argument: %w", errInvalidArgLength) + } + pfiles, ok := args[0].(*[]checker.File) + if !ok { + return false, fmt.Errorf("isSecurityPolicyFile expects arg of type: *[]checker.File: %w", errInvalidArgType) + } + if isSecurityPolicyFilename(name) { + *pfiles = append(*pfiles, checker.File{ + Path: name, + Type: checker.FileTypeSource, + Offset: checker.OffsetDefault, + }) + return false, nil } - return false + return true, nil +} + +func isSecurityPolicyFilename(name string) bool { + return strings.EqualFold(name, "security.md") || + strings.EqualFold(name, ".github/security.md") || + strings.EqualFold(name, "docs/security.md") || + strings.EqualFold(name, "security.adoc") || + strings.EqualFold(name, ".github/security.adoc") || + strings.EqualFold(name, "docs/security.adoc") || + strings.EqualFold(name, "doc/security.rst") || + strings.EqualFold(name, "docs/security.rst") } diff --git a/checks/raw/security_policy_test.go b/checks/raw/security_policy_test.go index 95aa84d3399..75c52834763 100644 --- a/checks/raw/security_policy_test.go +++ b/checks/raw/security_policy_test.go @@ -24,44 +24,35 @@ import ( scut "github.com/ossf/scorecard/v4/utests" ) -func Test_isSecurityRstFound(t *testing.T) { +func Test_isSecurityPolicyFilename(t *testing.T) { t.Parallel() - type args struct { - name string - } tests := []struct { - name string - args args - want bool + name string + filename string + expected bool }{ { - name: "test1", - args: args{ - name: "test1", - }, - want: false, + name: "test1", + filename: "test1", + expected: false, }, { - name: "docs/security.rst", - args: args{ - name: "docs/security.rst", - }, - want: true, + name: "docs/security.rst", + filename: "docs/security.rst", + expected: true, }, { - name: "doc/security.rst", - args: args{ - name: "doc/security.rst", - }, - want: true, + name: "doc/security.rst", + filename: "doc/security.rst", + expected: true, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - if got := isSecurityRstFound(tt.args.name); got != tt.want { - t.Errorf("isSecurityRstFound() = %v, want %v for %v", got, tt.want, tt.name) + if got := isSecurityPolicyFilename(tt.filename); got != tt.expected { + t.Errorf("isSecurityPolicyFilename() = %v, want %v for %v", got, tt.expected, tt.name) } }) }