Skip to content

Commit

Permalink
Set http client when doing token exchange. Fixes #541
Browse files Browse the repository at this point in the history
  • Loading branch information
plorenz committed Apr 11, 2024
1 parent 48fc225 commit c0ad497
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
26 changes: 16 additions & 10 deletions edge-apis/authwrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func (self *ZitiEdgeManagement) SetAllowOidcDynamicallyEnabled(allow bool) {
self.oidcDynamicallyEnabled = allow
}

func (self *ZitiEdgeManagement) RefreshApiSession(apiSession ApiSession) (ApiSession, error) {
func (self *ZitiEdgeManagement) RefreshApiSession(apiSession ApiSession, httpClient *http.Client) (ApiSession, error) {
switch s := apiSession.(type) {
case *ApiSessionLegacy:
params := manCurApiSession.NewGetCurrentAPISessionParams()
Expand All @@ -313,7 +313,7 @@ func (self *ZitiEdgeManagement) RefreshApiSession(apiSession ApiSession) (ApiSes

return s, nil
case *ApiSessionOidc:
tokens, err := self.ExchangeTokens(s.OidcTokens)
tokens, err := self.ExchangeTokens(s.OidcTokens, httpClient)

if err != nil {
return nil, err
Expand All @@ -327,8 +327,8 @@ func (self *ZitiEdgeManagement) RefreshApiSession(apiSession ApiSession) (ApiSes
return nil, errors.New("api session does not have any tokens")
}

func (self *ZitiEdgeManagement) ExchangeTokens(curTokens *oidc.Tokens[*oidc.IDTokenClaims]) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
return exchangeTokens(self.apiUrl.String(), curTokens)
func (self *ZitiEdgeManagement) ExchangeTokens(curTokens *oidc.Tokens[*oidc.IDTokenClaims], httpClient *http.Client) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
return exchangeTokens(getBaseUrl(self.apiUrl), curTokens, httpClient)
}

// ZitiEdgeClient is an alias of the go-swagger generated client that allows this package to add additional
Expand Down Expand Up @@ -414,7 +414,7 @@ func (self *ZitiEdgeClient) SetAllowOidcDynamicallyEnabled(allow bool) {
self.oidcDynamicallyEnabled = allow
}

func (self *ZitiEdgeClient) RefreshApiSession(apiSession ApiSession) (ApiSession, error) {
func (self *ZitiEdgeClient) RefreshApiSession(apiSession ApiSession, httpClient *http.Client) (ApiSession, error) {
switch s := apiSession.(type) {
case *ApiSessionLegacy:
params := clientApiSession.NewGetCurrentAPISessionParams()
Expand All @@ -430,7 +430,7 @@ func (self *ZitiEdgeClient) RefreshApiSession(apiSession ApiSession) (ApiSession

return newApiSession, nil
case *ApiSessionOidc:
tokens, err := self.ExchangeTokens(s.OidcTokens)
tokens, err := self.ExchangeTokens(s.OidcTokens, httpClient)

if err != nil {
return nil, err
Expand All @@ -444,12 +444,18 @@ func (self *ZitiEdgeClient) RefreshApiSession(apiSession ApiSession) (ApiSession
return nil, errors.New("api session does not have any tokens")
}

func (self *ZitiEdgeClient) ExchangeTokens(curTokens *oidc.Tokens[*oidc.IDTokenClaims]) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
return exchangeTokens(self.apiUrl.String(), curTokens)
func (self *ZitiEdgeClient) ExchangeTokens(curTokens *oidc.Tokens[*oidc.IDTokenClaims], httpClient *http.Client) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
return exchangeTokens(getBaseUrl(self.apiUrl), curTokens, httpClient)
}

func exchangeTokens(issuer string, curTokens *oidc.Tokens[*oidc.IDTokenClaims]) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
te, err := tokenexchange.NewTokenExchanger(issuer)
func getBaseUrl(apiUrl *url.URL) string {
urlCopy := *apiUrl
urlCopy.Path = ""
return urlCopy.String()
}

func exchangeTokens(issuer string, curTokens *oidc.Tokens[*oidc.IDTokenClaims], client *http.Client) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
te, err := tokenexchange.NewTokenExchanger(issuer, tokenexchange.WithHTTPClient(client))

if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion ziti/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (self *CtrlClient) GetCurrentApiSession() apis.ApiSession {
// Refresh will contact the controller extending the current ApiSession for legacy API Sessions
func (self *CtrlClient) Refresh() (*time.Time, error) {
if apiSession := self.GetCurrentApiSession(); apiSession != nil {
newApiSession, err := self.API.RefreshApiSession(apiSession)
newApiSession, err := self.API.RefreshApiSession(apiSession, self.HttpClient)

if err != nil {
return nil, err
Expand Down

0 comments on commit c0ad497

Please sign in to comment.