From 48f6ed3781856bae6f04ee97b7162a96fceff383 Mon Sep 17 00:00:00 2001 From: Sander van Harmelen Date: Tue, 2 Nov 2021 12:25:28 +0100 Subject: [PATCH] Suggest another solution --- context.go | 25 --------------- gitlab.go | 76 ++++++++++++++++++++++------------------------ request_options.go | 23 +++++++++++--- 3 files changed, 55 insertions(+), 69 deletions(-) delete mode 100644 context.go diff --git a/context.go b/context.go deleted file mode 100644 index 05573e821..000000000 --- a/context.go +++ /dev/null @@ -1,25 +0,0 @@ -package gitlab - -import "context" - -type key uint8 - -const optKey key = iota - -func WithToken(ctx context.Context, token string) context.Context { - return context.WithValue(ctx, optKey, token) -} - -func TokenFromContext(ctx context.Context) *string { - val := ctx.Value(optKey) - if val == nil { - return nil - } - - opt, ok := val.(string) - if !ok { - return nil - } - - return &opt -} diff --git a/gitlab.go b/gitlab.go index a0dc92054..c749f21d5 100644 --- a/gitlab.go +++ b/gitlab.go @@ -48,19 +48,19 @@ const ( headerRateReset = "RateLimit-Reset" ) -// authType represents an authentication type within GitLab. +// AuthType represents an authentication type within GitLab. // // GitLab API docs: https://docs.gitlab.com/ce/api/ -type authType int +type AuthType int // List of available authentication types. // // GitLab API docs: https://docs.gitlab.com/ce/api/ const ( - basicAuth authType = iota - jobToken - oAuthToken - privateToken + BasicAuth AuthType = iota + JobToken + OAuthToken + PrivateToken ) // A Client manages communication with the GitLab API. @@ -84,7 +84,7 @@ type Client struct { limiter RateLimiter // Token type used to make authenticated API calls. - authType authType + authType AuthType // Username and password used for basix authentication. username, password string @@ -210,7 +210,7 @@ func NewClient(token string, options ...ClientOptionFunc) (*Client, error) { if err != nil { return nil, err } - client.authType = privateToken + client.authType = PrivateToken client.token = token return client, nil } @@ -223,7 +223,7 @@ func NewBasicAuthClient(username, password string, options ...ClientOptionFunc) return nil, err } - client.authType = basicAuth + client.authType = BasicAuth client.username = username client.password = password @@ -237,7 +237,7 @@ func NewJobClient(token string, options ...ClientOptionFunc) (*Client, error) { if err != nil { return nil, err } - client.authType = jobToken + client.authType = JobToken client.token = token return client, nil } @@ -249,7 +249,7 @@ func NewOAuthClient(token string, options ...ClientOptionFunc) (*Client, error) if err != nil { return nil, err } - client.authType = oAuthToken + client.authType = OAuthToken client.token = token return client, nil } @@ -606,22 +606,22 @@ const ( // populatePageValues parses the HTTP Link response headers and populates the // various pagination link values in the Response. func (r *Response) populatePageValues() { - if totalItems := r.Response.Header.Get(xTotal); totalItems != "" { + if totalItems := r.Header.Get(xTotal); totalItems != "" { r.TotalItems, _ = strconv.Atoi(totalItems) } - if totalPages := r.Response.Header.Get(xTotalPages); totalPages != "" { + if totalPages := r.Header.Get(xTotalPages); totalPages != "" { r.TotalPages, _ = strconv.Atoi(totalPages) } - if itemsPerPage := r.Response.Header.Get(xPerPage); itemsPerPage != "" { + if itemsPerPage := r.Header.Get(xPerPage); itemsPerPage != "" { r.ItemsPerPage, _ = strconv.Atoi(itemsPerPage) } - if currentPage := r.Response.Header.Get(xPage); currentPage != "" { + if currentPage := r.Header.Get(xPage); currentPage != "" { r.CurrentPage, _ = strconv.Atoi(currentPage) } - if nextPage := r.Response.Header.Get(xNextPage); nextPage != "" { + if nextPage := r.Header.Get(xNextPage); nextPage != "" { r.NextPage, _ = strconv.Atoi(nextPage) } - if previousPage := r.Response.Header.Get(xPrevPage); previousPage != "" { + if previousPage := r.Header.Get(xPrevPage); previousPage != "" { r.PreviousPage, _ = strconv.Atoi(previousPage) } } @@ -636,30 +636,20 @@ func (c *Client) Do(req *retryablehttp.Request, v interface{}) (*Response, error // silently as the limiter will be disabled in case of an error. c.configureLimiterOnce.Do(func() { c.configureLimiter(req.Context()) }) - ctx := req.Context() - // Wait will block until the limiter can obtain a new token. - err := c.limiter.Wait(ctx) + err := c.limiter.Wait(req.Context()) if err != nil { return nil, err } - var token string - c.tokenLock.RLock() - token = c.token - c.tokenLock.RUnlock() - - authToken := TokenFromContext(ctx) - if authToken != nil { - token = *authToken - } - // Set the correct authentication header. If using basic auth, then check // if we already have a token and if not first authenticate and get one. var basicAuthToken string switch c.authType { - case basicAuth: - basicAuthToken = token + case BasicAuth: + c.tokenLock.RLock() + basicAuthToken = c.token + c.tokenLock.RUnlock() if basicAuthToken == "" { // If we don't have a token yet, we first need to request one. basicAuthToken, err = c.requestOAuthToken(req.Context(), basicAuthToken) @@ -668,12 +658,18 @@ func (c *Client) Do(req *retryablehttp.Request, v interface{}) (*Response, error } } req.Header.Set("Authorization", "Bearer "+basicAuthToken) - case jobToken: - req.Header.Set("JOB-TOKEN", token) - case oAuthToken: - req.Header.Set("Authorization", "Bearer "+token) - case privateToken: - req.Header.Set("PRIVATE-TOKEN", token) + case JobToken: + if values := req.Header.Values("JOB-TOKEN"); len(values) == 0 { + req.Header.Set("JOB-TOKEN", c.token) + } + case OAuthToken: + if values := req.Header.Values("Authorization"); len(values) == 0 { + req.Header.Set("Authorization", "Bearer "+c.token) + } + case PrivateToken: + if values := req.Header.Values("PRIVATE-TOKEN"); len(values) == 0 { + req.Header.Set("PRIVATE-TOKEN", c.token) + } } resp, err := c.client.Do(req) @@ -681,7 +677,7 @@ func (c *Client) Do(req *retryablehttp.Request, v interface{}) (*Response, error return nil, err } - if resp.StatusCode == http.StatusUnauthorized && c.authType == basicAuth { + if resp.StatusCode == http.StatusUnauthorized && c.authType == BasicAuth { resp.Body.Close() // The token most likely expired, so we need to request a new one and try again. if _, err := c.requestOAuthToken(req.Context(), basicAuthToken); err != nil { @@ -752,7 +748,7 @@ func parseID(id interface{}) (string, error) { // Helper function to escape a project identifier. func pathEscape(s string) string { - return strings.Replace(url.PathEscape(s), ".", "%2E", -1) + return strings.ReplaceAll(url.PathEscape(s), ".", "%2E") } // An ErrorResponse reports one or more errors caused by an API request. diff --git a/request_options.go b/request_options.go index b43dd39fe..75ae044a5 100644 --- a/request_options.go +++ b/request_options.go @@ -25,7 +25,15 @@ import ( // RequestOptionFunc can be passed to all API requests to customize the API request. type RequestOptionFunc func(*retryablehttp.Request) error -// WithSudo takes either a username or user ID and sets the SUDO request header +// WithContext runs the request with the provided context +func WithContext(ctx context.Context) RequestOptionFunc { + return func(req *retryablehttp.Request) error { + *req = *req.WithContext(ctx) + return nil + } +} + +// WithSudo takes either a username or user ID and sets the SUDO request header. func WithSudo(uid interface{}) RequestOptionFunc { return func(req *retryablehttp.Request) error { user, err := parseID(uid) @@ -37,10 +45,17 @@ func WithSudo(uid interface{}) RequestOptionFunc { } } -// WithContext runs the request with the provided context -func WithContext(ctx context.Context) RequestOptionFunc { +// WithToken takes a token which is then used when making this one request. +func WithToken(authType AuthType, token string) RequestOptionFunc { return func(req *retryablehttp.Request) error { - *req = *req.WithContext(ctx) + switch authType { + case JobToken: + req.Header.Set("JOB-TOKEN", token) + case OAuthToken: + req.Header.Set("Authorization", "Bearer "+token) + case PrivateToken: + req.Header.Set("PRIVATE-TOKEN", token) + } return nil } }