diff --git a/provider/github-app-token/github-app-token.go b/provider/github-app-token/github-app-token.go index d24b8ab0..ebd9ac13 100644 --- a/provider/github-app-token/github-app-token.go +++ b/provider/github-app-token/github-app-token.go @@ -16,6 +16,7 @@ import ( type githubClient interface { CreateStatus(ctx context.Context, token, owner, repo, ref string, status *github.CreateStatusRequest) (*github.CreateStatusResponse, error) + ValidateAPIURL(url string) error } const ( @@ -88,6 +89,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (h *Handler) handle(ctx context.Context, req *requestBody) (*responseBody, error) { + if err := h.github.ValidateAPIURL(req.APIURL); err != nil { + return nil, err + } if err := h.validateGitHubToken(ctx, req); err != nil { return nil, err } diff --git a/provider/github-app-token/github-app-token_test.go b/provider/github-app-token/github-app-token_test.go index a0988885..69812c7c 100644 --- a/provider/github-app-token/github-app-token_test.go +++ b/provider/github-app-token/github-app-token_test.go @@ -10,13 +10,18 @@ import ( ) type githubClientMock struct { - CreateStatusFunc func(ctx context.Context, token, owner, repo, ref string, status *github.CreateStatusRequest) (*github.CreateStatusResponse, error) + CreateStatusFunc func(ctx context.Context, token, owner, repo, ref string, status *github.CreateStatusRequest) (*github.CreateStatusResponse, error) + ValidateAPIURLFunc func(url string) error } func (c *githubClientMock) CreateStatus(ctx context.Context, token, owner, repo, ref string, status *github.CreateStatusRequest) (*github.CreateStatusResponse, error) { return c.CreateStatusFunc(ctx, token, owner, repo, ref, status) } +func (c *githubClientMock) ValidateAPIURL(url string) error { + return c.ValidateAPIURLFunc(url) +} + func TestValidateGitHubToken(t *testing.T) { h := &Handler{ github: &githubClientMock{ @@ -47,6 +52,9 @@ func TestValidateGitHubToken(t *testing.T) { }, }, nil }, + ValidateAPIURLFunc: func(url string) error { + return nil + }, }, } err := h.validateGitHubToken(context.Background(), &requestBody{ @@ -96,6 +104,9 @@ func TestValidateGitHubToken_InvalidCreator(t *testing.T) { }, }, nil }, + ValidateAPIURLFunc: func(url string) error { + return nil + }, }, } err := h.validateGitHubToken(context.Background(), &requestBody{ diff --git a/provider/github-app-token/github/github.go b/provider/github-app-token/github/github.go index 9d7cddf8..94e3be93 100644 --- a/provider/github-app-token/github/github.go +++ b/provider/github-app-token/github/github.go @@ -1,6 +1,7 @@ package github import ( + "errors" "fmt" "net" "net/http" @@ -10,7 +11,7 @@ import ( ) const ( - githubUserAgent = "actions-aws-assume-role/1.0" + githubUserAgent = "actions-github-token/1.0" defaultAPIBaseURL = "https://api.github.com" ) @@ -45,6 +46,24 @@ func NewClient(httpClient *http.Client) *Client { } } +func (c *Client) ValidateAPIURL(url string) error { + u, err := canonicalURL(url) + if err != nil { + return err + } + if u != c.baseURL { + if c.baseURL == defaultAPIBaseURL { + return errors.New( + "it looks that you use GitHub Enterprise Server, " + + "but the credential provider doesn't support it. " + + "I recommend you to build your own credential provider", + ) + } + return errors.New("your api server is not verified by the credential provider") + } + return nil +} + type UnexpectedStatusCodeError struct { StatusCode int }