Skip to content
This repository has been archived by the owner on Dec 10, 2024. It is now read-only.

Commit

Permalink
Suggest another solution
Browse files Browse the repository at this point in the history
  • Loading branch information
svanharmelen committed Nov 15, 2021
1 parent b009e13 commit 48f6ed3
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 69 deletions.
25 changes: 0 additions & 25 deletions context.go

This file was deleted.

76 changes: 36 additions & 40 deletions gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,19 @@ const (
headerRateReset = "RateLimit-Reset"
)

// authType represents an authentication type within GitLab.
// AuthType represents an authentication type within GitLab.
//
// GitLab API docs: https://docs.gitlab.com/ce/api/
type authType int
type AuthType int

// List of available authentication types.
//
// GitLab API docs: https://docs.gitlab.com/ce/api/
const (
basicAuth authType = iota
jobToken
oAuthToken
privateToken
BasicAuth AuthType = iota
JobToken
OAuthToken
PrivateToken
)

// A Client manages communication with the GitLab API.
Expand All @@ -84,7 +84,7 @@ type Client struct {
limiter RateLimiter

// Token type used to make authenticated API calls.
authType authType
authType AuthType

// Username and password used for basix authentication.
username, password string
Expand Down Expand Up @@ -210,7 +210,7 @@ func NewClient(token string, options ...ClientOptionFunc) (*Client, error) {
if err != nil {
return nil, err
}
client.authType = privateToken
client.authType = PrivateToken
client.token = token
return client, nil
}
Expand All @@ -223,7 +223,7 @@ func NewBasicAuthClient(username, password string, options ...ClientOptionFunc)
return nil, err
}

client.authType = basicAuth
client.authType = BasicAuth
client.username = username
client.password = password

Expand All @@ -237,7 +237,7 @@ func NewJobClient(token string, options ...ClientOptionFunc) (*Client, error) {
if err != nil {
return nil, err
}
client.authType = jobToken
client.authType = JobToken
client.token = token
return client, nil
}
Expand All @@ -249,7 +249,7 @@ func NewOAuthClient(token string, options ...ClientOptionFunc) (*Client, error)
if err != nil {
return nil, err
}
client.authType = oAuthToken
client.authType = OAuthToken
client.token = token
return client, nil
}
Expand Down Expand Up @@ -606,22 +606,22 @@ const (
// populatePageValues parses the HTTP Link response headers and populates the
// various pagination link values in the Response.
func (r *Response) populatePageValues() {
if totalItems := r.Response.Header.Get(xTotal); totalItems != "" {
if totalItems := r.Header.Get(xTotal); totalItems != "" {
r.TotalItems, _ = strconv.Atoi(totalItems)
}
if totalPages := r.Response.Header.Get(xTotalPages); totalPages != "" {
if totalPages := r.Header.Get(xTotalPages); totalPages != "" {
r.TotalPages, _ = strconv.Atoi(totalPages)
}
if itemsPerPage := r.Response.Header.Get(xPerPage); itemsPerPage != "" {
if itemsPerPage := r.Header.Get(xPerPage); itemsPerPage != "" {
r.ItemsPerPage, _ = strconv.Atoi(itemsPerPage)
}
if currentPage := r.Response.Header.Get(xPage); currentPage != "" {
if currentPage := r.Header.Get(xPage); currentPage != "" {
r.CurrentPage, _ = strconv.Atoi(currentPage)
}
if nextPage := r.Response.Header.Get(xNextPage); nextPage != "" {
if nextPage := r.Header.Get(xNextPage); nextPage != "" {
r.NextPage, _ = strconv.Atoi(nextPage)
}
if previousPage := r.Response.Header.Get(xPrevPage); previousPage != "" {
if previousPage := r.Header.Get(xPrevPage); previousPage != "" {
r.PreviousPage, _ = strconv.Atoi(previousPage)
}
}
Expand All @@ -636,30 +636,20 @@ func (c *Client) Do(req *retryablehttp.Request, v interface{}) (*Response, error
// silently as the limiter will be disabled in case of an error.
c.configureLimiterOnce.Do(func() { c.configureLimiter(req.Context()) })

ctx := req.Context()

// Wait will block until the limiter can obtain a new token.
err := c.limiter.Wait(ctx)
err := c.limiter.Wait(req.Context())
if err != nil {
return nil, err
}

var token string
c.tokenLock.RLock()
token = c.token
c.tokenLock.RUnlock()

authToken := TokenFromContext(ctx)
if authToken != nil {
token = *authToken
}

// Set the correct authentication header. If using basic auth, then check
// if we already have a token and if not first authenticate and get one.
var basicAuthToken string
switch c.authType {
case basicAuth:
basicAuthToken = token
case BasicAuth:
c.tokenLock.RLock()
basicAuthToken = c.token
c.tokenLock.RUnlock()
if basicAuthToken == "" {
// If we don't have a token yet, we first need to request one.
basicAuthToken, err = c.requestOAuthToken(req.Context(), basicAuthToken)
Expand All @@ -668,20 +658,26 @@ func (c *Client) Do(req *retryablehttp.Request, v interface{}) (*Response, error
}
}
req.Header.Set("Authorization", "Bearer "+basicAuthToken)
case jobToken:
req.Header.Set("JOB-TOKEN", token)
case oAuthToken:
req.Header.Set("Authorization", "Bearer "+token)
case privateToken:
req.Header.Set("PRIVATE-TOKEN", token)
case JobToken:
if values := req.Header.Values("JOB-TOKEN"); len(values) == 0 {
req.Header.Set("JOB-TOKEN", c.token)
}
case OAuthToken:
if values := req.Header.Values("Authorization"); len(values) == 0 {
req.Header.Set("Authorization", "Bearer "+c.token)
}
case PrivateToken:
if values := req.Header.Values("PRIVATE-TOKEN"); len(values) == 0 {
req.Header.Set("PRIVATE-TOKEN", c.token)
}
}

resp, err := c.client.Do(req)
if err != nil {
return nil, err
}

if resp.StatusCode == http.StatusUnauthorized && c.authType == basicAuth {
if resp.StatusCode == http.StatusUnauthorized && c.authType == BasicAuth {
resp.Body.Close()
// The token most likely expired, so we need to request a new one and try again.
if _, err := c.requestOAuthToken(req.Context(), basicAuthToken); err != nil {
Expand Down Expand Up @@ -752,7 +748,7 @@ func parseID(id interface{}) (string, error) {

// Helper function to escape a project identifier.
func pathEscape(s string) string {
return strings.Replace(url.PathEscape(s), ".", "%2E", -1)
return strings.ReplaceAll(url.PathEscape(s), ".", "%2E")
}

// An ErrorResponse reports one or more errors caused by an API request.
Expand Down
23 changes: 19 additions & 4 deletions request_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,15 @@ import (
// RequestOptionFunc can be passed to all API requests to customize the API request.
type RequestOptionFunc func(*retryablehttp.Request) error

// WithSudo takes either a username or user ID and sets the SUDO request header
// WithContext runs the request with the provided context
func WithContext(ctx context.Context) RequestOptionFunc {
return func(req *retryablehttp.Request) error {
*req = *req.WithContext(ctx)
return nil
}
}

// WithSudo takes either a username or user ID and sets the SUDO request header.
func WithSudo(uid interface{}) RequestOptionFunc {
return func(req *retryablehttp.Request) error {
user, err := parseID(uid)
Expand All @@ -37,10 +45,17 @@ func WithSudo(uid interface{}) RequestOptionFunc {
}
}

// WithContext runs the request with the provided context
func WithContext(ctx context.Context) RequestOptionFunc {
// WithToken takes a token which is then used when making this one request.
func WithToken(authType AuthType, token string) RequestOptionFunc {
return func(req *retryablehttp.Request) error {
*req = *req.WithContext(ctx)
switch authType {
case JobToken:
req.Header.Set("JOB-TOKEN", token)
case OAuthToken:
req.Header.Set("Authorization", "Bearer "+token)
case PrivateToken:
req.Header.Set("PRIVATE-TOKEN", token)
}
return nil
}
}

0 comments on commit 48f6ed3

Please sign in to comment.