diff --git a/server/events/vcs/github_client.go b/server/events/vcs/github_client.go index 31307b56c4..a1b9f19f9c 100644 --- a/server/events/vcs/github_client.go +++ b/server/events/vcs/github_client.go @@ -29,6 +29,7 @@ import ( "github.com/runatlantis/atlantis/server/events/vcs/common" "github.com/runatlantis/atlantis/server/logging" "github.com/shurcooL/githubv4" + "golang.org/x/oauth2" ) // maxCommentLength is the maximum number of chars allowed in a single comment @@ -40,6 +41,7 @@ type GithubClient struct { user string client *github.Client v4MutateClient *graphql.Client + v4QueryClient *githubv4.Client ctx context.Context logger logging.SimpleLogging } @@ -91,6 +93,16 @@ func NewGithubClient(hostname string, credentials GithubCredentials, logger logg transport, graphql.WithHeader("Accept", "application/vnd.github.queen-beryl-preview+json"), ) + token, err := credentials.GetToken() + if err != nil { + return nil, errors.Wrap(err, "Failed to get GitHub token") + } + src := oauth2.StaticTokenSource( + &oauth2.Token{AccessToken: token}, + ) + httpClient := oauth2.NewClient(context.Background(), src) + // Use the client from shurcooL's githubv4 library for queries. + v4QueryClient := githubv4.NewEnterpriseClient(graphqlURL, httpClient) user, err := credentials.GetUser() logger.Debug("GH User: %s", user) @@ -102,6 +114,7 @@ func NewGithubClient(hostname string, credentials GithubCredentials, logger logg user: user, client: client, v4MutateClient: v4MutateClient, + v4QueryClient: v4QueryClient, ctx: context.Background(), logger: logger, }, nil @@ -461,7 +474,7 @@ func (g *GithubClient) GetTeamNamesForUser(repo models.Repo, user models.User) ( var teamNames []string ctx := context.Background() for { - err := g.v4MutateClient.Query(ctx, &q, variables) + err := g.v4QueryClient.Query(ctx, &q, variables) if err != nil { return nil, err } diff --git a/server/events/vcs/github_client_test.go b/server/events/vcs/github_client_test.go index 5c9612ff11..b388cec91d 100644 --- a/server/events/vcs/github_client_test.go +++ b/server/events/vcs/github_client_test.go @@ -981,3 +981,49 @@ func TestGithubClient_Retry404Files(t *testing.T) { Ok(t, err) Equals(t, 3, numCalls) } + +// GetTeamNamesForUser returns a list of team names for a user. +func TestGithubClient_GetTeamNamesForUser(t *testing.T) { + logger := logging.NewNoopLogger(t) + // Mocked GraphQL response for two teams + resp := `{ + "data":{ + "organization": { + "teams":{ + "edges":[ + {"node":{"name":"frontend-developers"}}, + {"node":{"name":"employees"}} + ], + "pageInfo":{ + "endCursor":"Y3Vyc29yOnYyOpHOAFMoLQ==", + "hasNextPage":false + } + } + } + } + }` + testServer := httptest.NewTLSServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.RequestURI { + case "/api/graphql": + w.Write([]byte(resp)) // nolint: errcheck + default: + t.Errorf("got unexpected request at %q", r.RequestURI) + http.Error(w, "not found", http.StatusNotFound) + return + } + })) + testServerURL, err := url.Parse(testServer.URL) + Ok(t, err) + client, err := vcs.NewGithubClient(testServerURL.Host, &vcs.GithubUserCredentials{"user", "pass"}, logger) + Ok(t, err) + defer disableSSLVerification()() + + teams, err := client.GetTeamNamesForUser(models.Repo{ + Owner: "testrepo", + }, models.User{ + Username: "testuser", + }) + Ok(t, err) + Equals(t, []string{"frontend-developers", "employees"}, teams) +}