Skip to content

Commit

Permalink
fix: force access token renewal on 401
Browse files Browse the repository at this point in the history
  • Loading branch information
devgianlu committed May 3, 2024
1 parent e122f2e commit ef364db
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 12 deletions.
2 changes: 1 addition & 1 deletion dealer/dealer.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func NewDealer(dealerAddr librespot.GetAddressFunc, accessToken librespot.GetLog
}

func (d *Dealer) connect() error {
accessToken, err := d.accessToken()
accessToken, err := d.accessToken(false)
if err != nil {
return fmt.Errorf("failed obtaining dealer access token: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion login5.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package go_librespot

// GetLogin5TokenFunc is a function that everytime it is called returns a valid login5 access token.
type GetLogin5TokenFunc func() (string, error)
type GetLogin5TokenFunc func(force bool) (string, error)
6 changes: 3 additions & 3 deletions login5/login5.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,14 @@ func (c *Login5) Login(credentials proto.Message) error {
}

func (c *Login5) AccessToken() librespot.GetLogin5TokenFunc {
return func() (string, error) {
return func(force bool) (string, error) {
c.loginOkLock.RLock()
if c.loginOk == nil {
panic("login5 not authenticated")
}

// if not expired, just return it
if c.loginOkExp.After(time.Now()) {
// if not asked to force a new token and not expired, just return it
if !force && c.loginOkExp.After(time.Now()) {
defer c.loginOkLock.RUnlock()
return c.loginOk.AccessToken, nil
}
Expand Down
26 changes: 19 additions & 7 deletions spclient/spclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,6 @@ func NewSpclient(addr librespot.GetAddressFunc, accessToken librespot.GetLogin5T
}

func (c *Spclient) request(method string, path string, query url.Values, header http.Header, body []byte) (*http.Response, error) {
accessToken, err := c.accessToken()
if err != nil {
return nil, fmt.Errorf("failed obtaining spclient access token: %w", err)
}

reqUrl := c.baseUrl.JoinPath(path)
if query != nil {
reqUrl.RawQuery = query.Encode()
Expand All @@ -68,7 +63,6 @@ func (c *Spclient) request(method string, path string, query url.Values, header
}

req.Header.Set("Client-Token", c.clientToken)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))

if body != nil {
req.GetBody = func() (io.ReadCloser, error) {
Expand All @@ -77,7 +71,25 @@ func (c *Spclient) request(method string, path string, query url.Values, header
req.Body, _ = req.GetBody()
}

resp, err := backoff.RetryWithData(func() (*http.Response, error) { return c.client.Do(req) }, backoff.NewExponentialBackOff())
var forceNewToken bool
resp, err := backoff.RetryWithData(func() (*http.Response, error) {
accessToken, err := c.accessToken(forceNewToken)
if err != nil {
return nil, fmt.Errorf("failed obtaining spclient access token: %w", err)
}

req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))

resp, err := c.client.Do(req)
if err != nil {
return nil, err
} else if resp.StatusCode == 401 {
forceNewToken = true
return nil, fmt.Errorf("unauthorized")
}

return resp, nil
}, backoff.NewExponentialBackOff())
if err != nil {
return nil, fmt.Errorf("spclient request failed: %w", err)
}
Expand Down

0 comments on commit ef364db

Please sign in to comment.