From 3b58a15a84a4dfe5d6108d9a694b0daeb37a410e Mon Sep 17 00:00:00 2001 From: Richard Gomez <32133502+rgmz@users.noreply.github.com> Date: Fri, 29 Mar 2024 10:29:46 -0400 Subject: [PATCH] Fix GitHub enumeration & rate-limiting logic (#2625) This is a follow-up to #2379. It fixes the following issues: GitHub API calls missing rate-limit handling The fix for Refactor GitHub source #2379 (comment) inadvertently resulting in duplicate API calls --- pkg/sources/github/github.go | 71 ++++++++++++++++++++----------- pkg/sources/github/github_test.go | 53 +++++++++++++++++++---- pkg/sources/github/repo.go | 31 +++++++++++--- 3 files changed, 114 insertions(+), 41 deletions(-) diff --git a/pkg/sources/github/github.go b/pkg/sources/github/github.go index 8cf0665cb700..757668b32f6b 100644 --- a/pkg/sources/github/github.go +++ b/pkg/sources/github/github.go @@ -404,31 +404,59 @@ func (s *Source) enumerate(ctx context.Context, apiEndpoint string) (*github.Cli } s.repos = make([]string, 0, s.filteredRepoCache.Count()) + +RepoLoop: for _, repo := range s.filteredRepoCache.Values() { + repoCtx := context.WithValue(ctx, "repo", repo) + r, ok := repo.(string) if !ok { - ctx.Logger().Error(fmt.Errorf("type assertion failed"), "unexpected value in cache", "repo", repo) + repoCtx.Logger().Error(fmt.Errorf("type assertion failed"), "Unexpected value in cache") continue } - _, urlParts, err := getRepoURLParts(r) - if err != nil { - ctx.Logger().Error(err, "failed to parse repository URL") - continue - } + // Ensure that |s.repoInfoCache| contains an entry for |repo|. + // This compensates for differences in enumeration logic between `--org` and `--repo`. + // See: https://github.com/trufflesecurity/trufflehog/pull/2379#discussion_r1487454788 + if _, ok := s.repoInfoCache.get(r); !ok { + repoCtx.Logger().V(2).Info("Caching repository info") - // Ignore any gists in |s.filteredRepoCache|. - // Repos have three parts (github.com, owner, name), gists have two. - if len(urlParts) == 3 { - // Ensure that individual repos specified in --repo are cached. - // Gists should be cached elsewhere. - // https://github.com/trufflesecurity/trufflehog/pull/2379#discussion_r1487454788 - ghRepo, _, err := s.apiClient.Repositories.Get(ctx, urlParts[1], urlParts[2]) + _, urlParts, err := getRepoURLParts(r) if err != nil { - ctx.Logger().Error(err, "failed to fetch repository") + repoCtx.Logger().Error(err, "Failed to parse repository URL") continue } - s.cacheRepoInfo(ghRepo) + + if strings.EqualFold(urlParts[0], "gist.github.com") { + // Cache gist info. + for { + gistID := extractGistID(urlParts) + gist, _, err := s.apiClient.Gists.Get(repoCtx, gistID) + if s.handleRateLimit(err) { + continue + } + if err != nil { + repoCtx.Logger().Error(err, "Failed to fetch gist") + continue RepoLoop + } + s.cacheGistInfo(gist) + break + } + } else { + // Cache repository info. + for { + ghRepo, _, err := s.apiClient.Repositories.Get(repoCtx, urlParts[1], urlParts[2]) + if s.handleRateLimit(err) { + continue + } + if err != nil { + repoCtx.Logger().Error(err, "Failed to fetch repository") + continue RepoLoop + } + s.cacheRepoInfo(ghRepo) + break + } + } } s.repos = append(s.repos, r) } @@ -902,16 +930,7 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string) error { for _, gist := range gists { s.filteredRepoCache.Set(gist.GetID(), gist.GetGitPullURL()) - - info := repoInfo{ - owner: gist.GetOwner().GetLogin(), - } - if gist.GetPublic() { - info.visibility = source_metadatapb.Visibility_public - } else { - info.visibility = source_metadatapb.Visibility_private - } - s.repoInfoCache.put(gist.GetGitPullURL(), info) + s.cacheGistInfo(gist) } if res == nil || res.NextPage == 0 { @@ -998,7 +1017,7 @@ func (s *Source) addOrgsByUser(ctx context.Context, user string) { logger := s.log.WithValues("user", user) for { orgs, resp, err := s.apiClient.Organizations.List(ctx, "", orgOpts) - if handled := s.handleRateLimit(err); handled { + if s.handleRateLimit(err) { continue } if err != nil { diff --git a/pkg/sources/github/github_test.go b/pkg/sources/github/github_test.go index 9a8c362a5ac2..77fb5332a2c9 100644 --- a/pkg/sources/github/github_test.go +++ b/pkg/sources/github/github_test.go @@ -451,20 +451,17 @@ func BenchmarkEnumerateWithToken(b *testing.B) { func TestEnumerate(t *testing.T) { defer gock.Off() + // Arrange gock.New("https://api.github.com"). Get("/user"). Reply(200). JSON(map[string]string{"login": "super-secret-user"}) + // gock.New("https://api.github.com"). Get("/users/super-secret-user/repos"). Reply(200). - JSON([]map[string]string{{"clone_url": "https://github.com/super-secret-user/super-secret-repo.git", "full_name": "super-secret-user/super-secret-repo"}}) - - gock.New("https://api.github.com"). - Get("/repos/super-secret-user/super-secret-repo"). - Reply(200). - JSON(`{"owner": {"login": "super-secret-user"}, "name": "super-secret-repo", "full_name": "super-secret-user/super-secret-repo", "has_wiki": false, "size": 1}`) + JSON(`[{"name": "super-secret-repo", "full_name": "super-secret-user/super-secret-repo", "owner": {"login": "super-secret-user"}, "clone_url": "https://github.com/super-secret-user/super-secret-repo.git", "has_wiki": false, "size": 1}]`) gock.New("https://api.github.com"). Get("/user/orgs"). @@ -483,12 +480,50 @@ func TestEnumerate(t *testing.T) { }, }) + // Manually cache a repository to ensure that enumerate + // doesn't make duplicate API calls. + // See https://github.com/trufflesecurity/trufflehog/pull/2625 + repo := func() *github.Repository { + var ( + name = "cached-repo" + fullName = "cached-user/cached-repo" + login = "cached-user" + cloneUrl = "https://github.com/cached-user/cached-repo.git" + owner = &github.User{ + Login: &login, + } + hasWiki = false + size = 1234 + ) + return &github.Repository{ + Name: &name, + FullName: &fullName, + Owner: owner, + HasWiki: &hasWiki, + Size: &size, + CloneURL: &cloneUrl, + } + }() + s.cacheRepoInfo(repo) + s.filteredRepoCache.Set(repo.GetFullName(), repo.GetCloneURL()) + + // Act _, err := s.enumerate(context.Background(), "https://api.github.com") + + // Assert assert.Nil(t, err) - assert.Equal(t, 2, s.filteredRepoCache.Count()) - ok := s.filteredRepoCache.Exists("super-secret-user/super-secret-repo") + // Enumeration found all repos. + assert.Equal(t, 3, s.filteredRepoCache.Count()) + assert.True(t, s.filteredRepoCache.Exists("super-secret-user/super-secret-repo")) + assert.True(t, s.filteredRepoCache.Exists("cached-user/cached-repo")) + assert.True(t, s.filteredRepoCache.Exists("2801a2b0523099d0614a951579d99ba9")) + // Enumeration cached all repos. + assert.Equal(t, 3, len(s.repoInfoCache.cache)) + _, ok := s.repoInfoCache.get("https://github.com/super-secret-user/super-secret-repo.git") + assert.True(t, ok) + _, ok = s.repoInfoCache.get("https://github.com/cached-user/cached-repo.git") assert.True(t, ok) - ok = s.filteredRepoCache.Exists("2801a2b0523099d0614a951579d99ba9") + _, ok = s.repoInfoCache.get("https://gist.github.com/2801a2b0523099d0614a951579d99ba9.git") assert.True(t, ok) assert.True(t, gock.IsDone()) } diff --git a/pkg/sources/github/repo.go b/pkg/sources/github/repo.go index cc217a6e8f5a..821e3345a60a 100644 --- a/pkg/sources/github/repo.go +++ b/pkg/sources/github/repo.go @@ -297,6 +297,18 @@ func (s *Source) cacheRepoInfo(r *github.Repository) { s.repoInfoCache.put(r.GetCloneURL(), info) } +func (s *Source) cacheGistInfo(g *github.Gist) { + info := repoInfo{ + owner: g.GetOwner().GetLogin(), + } + if g.GetPublic() { + info.visibility = source_metadatapb.Visibility_public + } else { + info.visibility = source_metadatapb.Visibility_private + } + s.repoInfoCache.put(g.GetGitPullURL(), info) +} + // wikiIsReachable returns true if https://github.com/$org/$repo/wiki is not redirected. // Unfortunately, this isn't 100% accurate. Some repositories have `has_wiki: true` and don't redirect their wiki page, // but still don't have a cloneable wiki. @@ -329,12 +341,19 @@ type commitQuery struct { // getDiffForFileInCommit retrieves the diff for a specified file in a commit. // If the file or its diff is not found, it returns an error. func (s *Source) getDiffForFileInCommit(ctx context.Context, query commitQuery) (string, error) { - commit, _, err := s.apiClient.Repositories.GetCommit(ctx, query.owner, query.repo, query.sha, nil) - if s.handleRateLimit(err) { - return "", fmt.Errorf("error fetching commit %s due to rate limit: %w", query.sha, err) - } - if err != nil { - return "", fmt.Errorf("error fetching commit %s: %w", query.sha, err) + var ( + commit *github.RepositoryCommit + err error + ) + for { + commit, _, err = s.apiClient.Repositories.GetCommit(ctx, query.owner, query.repo, query.sha, nil) + if s.handleRateLimit(err) { + continue + } + if err != nil { + return "", fmt.Errorf("error fetching commit %s: %w", query.sha, err) + } + break } if len(commit.Files) == 0 {