From fe093da7e039a0fc0cfcf5d2ae9d642323561dd4 Mon Sep 17 00:00:00 2001 From: Phil Adams Date: Tue, 22 Mar 2022 11:43:07 -0500 Subject: [PATCH] fix: retain http.Client config when retries are enabled (#157) This commit includes changes to the BaseService functionality so that enabling automatic retries will not cause the loss of any existing http.Client instance that might be set on the BaseService. A user can now use BaseService.SetHTTPClient() to set a custom client instance with or without retries being enabled. Also, the BaseService.GetHTTPClient() method was introduced which will return the correct http.Client instance regardless of whether retries are enabled or not. --- .secrets.baseline | 32 +++--- v5/core/base_service.go | 186 ++++++++++++++++++++--------------- v5/core/base_service_test.go | 143 ++++++++++++++++++++++++--- 3 files changed, 252 insertions(+), 109 deletions(-) diff --git a/.secrets.baseline b/.secrets.baseline index 89a9078..aa64516 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -3,7 +3,7 @@ "files": "go.sum|package-lock.json|^.secrets.baseline$", "lines": null }, - "generated_at": "2022-01-31T22:47:05Z", + "generated_at": "2022-03-17T20:39:01Z", "plugins_used": [ { "name": "AWSKeyDetector" @@ -116,7 +116,7 @@ "hashed_secret": "bc2f74c22f98f7b6ffbc2f67453dbfa99bce9a32", "is_secret": false, "is_verified": false, - "line_number": 603, + "line_number": 617, "type": "Secret Keyword", "verified_result": null } @@ -126,7 +126,7 @@ "hashed_secret": "1f5e25be9b575e9f5d39c82dfd1d9f4d73f1975c", "is_secret": false, "is_verified": false, - "line_number": 1224, + "line_number": 1251, "type": "Secret Keyword", "verified_result": null }, @@ -134,7 +134,7 @@ "hashed_secret": "84ba4ce8a59ed2d6e90726d57cdc4a927d3672b2", "is_secret": false, "is_verified": false, - "line_number": 1461, + "line_number": 1488, "type": "Secret Keyword", "verified_result": null }, @@ -142,7 +142,7 @@ "hashed_secret": "62cdb7020ff920e5aa642c3d4066950dd1f01f4d", "is_secret": false, "is_verified": false, - "line_number": 1504, + "line_number": 1531, "type": "Secret Keyword", "verified_result": null }, @@ -150,7 +150,7 @@ "hashed_secret": "ec7ec9d8ff520250fd5ca955c6474c6d70022407", "is_secret": false, "is_verified": false, - "line_number": 1512, + "line_number": 1539, "type": "JSON Web Token", "verified_result": null }, @@ -158,7 +158,7 @@ "hashed_secret": "40ce4379f5763c05b71c88f9a371809fdbce6a21", "is_secret": false, "is_verified": false, - "line_number": 1606, + "line_number": 1633, "type": "Secret Keyword", "verified_result": null }, @@ -166,7 +166,7 @@ "hashed_secret": "9addbf544119efa4a64223b649750a510f0d463f", "is_secret": false, "is_verified": false, - "line_number": 1632, + "line_number": 1659, "type": "Secret Keyword", "verified_result": null } @@ -210,7 +210,7 @@ "hashed_secret": "e4f50034475acff058e17b35679f8ef1e54f86c5", "is_secret": false, "is_verified": false, - "line_number": 50, + "line_number": 51, "type": "Secret Keyword", "verified_result": null }, @@ -218,7 +218,7 @@ "hashed_secret": "edbd5e119f94badb9f99a67ac6ff4c7a5204ad61", "is_secret": false, "is_verified": false, - "line_number": 58, + "line_number": 59, "type": "Secret Keyword", "verified_result": null }, @@ -226,7 +226,7 @@ "hashed_secret": "0e08371049c621b8a686d4b53a18ada4f7d15111", "is_secret": false, "is_verified": false, - "line_number": 65, + "line_number": 66, "type": "Secret Keyword", "verified_result": null }, @@ -234,7 +234,7 @@ "hashed_secret": "1e95707b2d2cc9086c651c60bb323bb85522b334", "is_secret": false, "is_verified": false, - "line_number": 68, + "line_number": 69, "type": "Secret Keyword", "verified_result": null }, @@ -242,7 +242,7 @@ "hashed_secret": "a7189814f2b74aea88acd2a8b24bed64c9ab43dd", "is_secret": false, "is_verified": false, - "line_number": 68, + "line_number": 69, "type": "Secret Keyword", "verified_result": null }, @@ -250,7 +250,7 @@ "hashed_secret": "f2e7745f43b0ef0e2c2faf61d6c6a28be2965750", "is_secret": false, "is_verified": false, - "line_number": 70, + "line_number": 71, "type": "Secret Keyword", "verified_result": null }, @@ -258,7 +258,7 @@ "hashed_secret": "2a68d46242baf9214502d1dc240a9075a7c6ed55", "is_secret": false, "is_verified": false, - "line_number": 78, + "line_number": 79, "type": "Secret Keyword", "verified_result": null }, @@ -266,7 +266,7 @@ "hashed_secret": "333f0f8814d63e7268f80e1e65e7549137d2350c", "is_secret": false, "is_verified": false, - "line_number": 87, + "line_number": 88, "type": "Secret Keyword", "verified_result": null } diff --git a/v5/core/base_service.go b/v5/core/base_service.go index 6894ec4..687a236 100644 --- a/v5/core/base_service.go +++ b/v5/core/base_service.go @@ -227,28 +227,53 @@ func (service *BaseService) SetDefaultHeaders(headers http.Header) { service.DefaultHeaders = headers } -// SetHTTPClient updates the client handling the requests. +// SetHTTPClient will set "client" as the http.Client instance to be used +// to invoke individual HTTP requests. +// If automatic retries are currently enabled on "service", then +// "client" will be set as the embedded client instance within +// the retryable client; otherwise "client" will be stored +// directly on "service". func (service *BaseService) SetHTTPClient(client *http.Client) { setMinimumTLSVersion(client) - service.Client = client + + if isRetryableClient(service.Client) { + // If "service" is currently holding a retryable client, + // then set "client" as the embedded client used for individual requests. + tr := service.Client.Transport.(*retryablehttp.RoundTripper) + tr.Client.HTTPClient = client + } else { + // Otherwise, just hang "client" directly off the base service. + service.Client = client + } +} + +// GetHTTPClient will return the http.Client instance used +// to invoke individual HTTP requests. +// If automatic retries are enabled, the returned value will +// be the http.Client instance embedded within the retryable client. +// If automatic retries are not enabled, then the returned value +// will simply be the "Client" field of the base service. +func (service *BaseService) GetHTTPClient() *http.Client { + if isRetryableClient(service.Client) { + tr := service.Client.Transport.(*retryablehttp.RoundTripper) + return tr.Client.HTTPClient + } + return service.Client } -// DisableSSLVerification skips SSL verification. -// This function sets a new http.Client instance on the service -// and configures it to bypass verification of server certificates -// and host names, making the client susceptible to "man-in-the-middle" -// attacks. This should be used only for testing. +// DisableSSLVerification will configure the service to +// skip the verification of server certificates and hostnames. +// This will make the client susceptible to "man-in-the-middle" +// attacks. This should be used only for testing or in secure +// environments. func (service *BaseService) DisableSSLVerification() { // Make sure we have a non-nil client hanging off the BaseService. if service.Client == nil { - service.SetHTTPClient(DefaultHTTPClient()) + service.Client = DefaultHTTPClient() } - // Grab the Transport instance used to invoke requests and set it - // to skip server ssl certificate verification. - tr := getClientTransportForSSL(service.Client) - if tr != nil { - + client := service.GetHTTPClient() + if tr, ok := client.Transport.(*http.Transport); tr != nil && ok { // If no TLS config, then create a new one. if tr.TLSClientConfig == nil { tr.TLSClientConfig = &tls.Config{} // #nosec G402 @@ -262,9 +287,9 @@ func (service *BaseService) DisableSSLVerification() { // IsSSLDisabled returns true if and only if the service's http.Client instance // is configured to skip verification of server SSL certificates. func (service *BaseService) IsSSLDisabled() bool { - if service.Client != nil { - tr := getClientTransportForSSL(service.Client) - if tr != nil { + client := service.GetHTTPClient() + if client != nil { + if tr, ok := client.Transport.(*http.Transport); tr != nil && ok { if tr.TLSClientConfig != nil { return tr.TLSClientConfig.InsecureSkipVerify } @@ -273,41 +298,9 @@ func (service *BaseService) IsSSLDisabled() bool { return false } -// getClientTransportForSSL() will return the http.Transport instance -// that needs to be modified to disable SSL. This is a bit tricky -// because we have to account for the retries-enabled scenario. -func getClientTransportForSSL(client *http.Client) *http.Transport { - // "client" will be passed in as the client instance that is hanging off - // the BaseService. This could be either a "normal" http.Client instance - // to be used when retries are not enabled, or it could be a "shim" http.Client - // instance that's used when retries are enabled. - // We determine which it is by checking the client's Transport field. - // If it is a "shim" client instance then the Transport field will be an - // instance of the retryablehttp.RoundTripper struct, otherwise the Transport - // field will be an instance of http.Transport. - - // if "client" is a shim http.Client instance to support retries, then just - // change "client" to point to the real http.Client instance that is embedded inside - // the retryablehttp.Client instance, since this is the one used to invoke - // individual requests when retries are enabled. - if tr, ok := client.Transport.(*retryablehttp.RoundTripper); tr != nil && ok { - client = tr.Client.HTTPClient - } - - // Next, - if client != nil { - if tr, ok := client.Transport.(*http.Transport); tr != nil && ok { - return tr - } - } - - return nil -} - // setMinimumTLSVersion sets the minimum TLS version required by the client to TLS v1.2 func setMinimumTLSVersion(client *http.Client) { - tr := getClientTransportForSSL(client) - if tr != nil { + if tr, ok := client.Transport.(*http.Transport); tr != nil && ok { if tr.TLSClientConfig == nil { tr.TLSClientConfig = &tls.Config{} // #nosec G402 } @@ -617,9 +610,20 @@ func getErrorMessage(responseMap map[string]interface{}, statusCode int) string return http.StatusText(statusCode) } -// EnableRetries will construct a "retryable" HTTP Client with the specified -// configuration, and then set it on the service instance. -// If maxRetries and/or maxRetryInterval are specified as 0, then default values +// isRetryableClient() will return true if and only if "client" is +// an http.Client instance that is configured for automatic retries. +// A retryable client is a client whose transport is a +// retryablehttp.RoundTripper instance. +func isRetryableClient(client *http.Client) bool { + var isRetryable bool = false + if client != nil && client.Transport != nil { + _, isRetryable = client.Transport.(*retryablehttp.RoundTripper) + } + return isRetryable +} + +// EnableRetries will configure the service to perform automatic retries of failed requests. +// If "maxRetries" and/or "maxRetryInterval" are specified as 0, then default values // are used instead. // // In a scenario where retries ARE NOT enabled: @@ -646,37 +650,41 @@ func getErrorMessage(responseMap map[string]interface{}, statusCode int) string // individual requests within the retry logic // - Result: Each request is invoked such that the automatic retry logic is employed func (service *BaseService) EnableRetries(maxRetries int, maxRetryInterval time.Duration) { - // Remember whether or not SSL verification has been disabled. - isSSLDisabled := service.IsSSLDisabled() - - // Create and configure the retryable client, then set it on the service. - client := NewRetryableHTTPClient() - if maxRetries > 0 { - client.RetryMax = maxRetries - } - if maxRetryInterval > 0 { - client.RetryWaitMax = maxRetryInterval - } - service.SetHTTPClient(client.StandardClient()) + if isRetryableClient(service.Client) { + // If retries are already enabled, then we just need to adjust + // the retryable client's config using "maxRetries" and "maxRetryInterval". + tr := service.Client.Transport.(*retryablehttp.RoundTripper) + if maxRetries > 0 { + tr.Client.RetryMax = maxRetries + } + if maxRetryInterval > 0 { + tr.Client.RetryWaitMax = maxRetryInterval + } + } else { + // Otherwise, we need to create a new retryable client instance + // and hang it off the base service. + client := NewRetryableClientWithHTTPClient(service.Client) + if maxRetries > 0 { + client.RetryMax = maxRetries + } + if maxRetryInterval > 0 { + client.RetryWaitMax = maxRetryInterval + } - // If SSL verification was previously disabled, then disable it now on the new client. - if isSSLDisabled { - service.DisableSSLVerification() + // Hang the retryable client off the base service via the "shim" client. + service.Client = client.StandardClient() } } -// DisableRetries will disable automatic retries by constructing a new -// default (non-retryable) HTTP Client instance and setting it on the service. +// DisableRetries will disable automatic retries in the service. func (service *BaseService) DisableRetries() { - // Remember whether or not SSL verification has been disabled. - isSSLDisabled := service.IsSSLDisabled() - - // Set a new standard client on the service. - service.SetHTTPClient(DefaultHTTPClient()) - - // If SSL verification was previously disabled, then disable it now on the new client. - if isSSLDisabled { - service.DisableSSLVerification() + if isRetryableClient(service.Client) { + // If the current client hanging off the base service is retryable, + // then we need to get ahold of the embedded http.Client instance + // and set that on the base service and effectively remove + // the retryable client instance. + tr := service.Client.Transport.(*retryablehttp.RoundTripper) + service.Client = tr.Client.HTTPClient } } @@ -698,15 +706,33 @@ func (l *httpLogger) Printf(format string, inserts ...interface{}) { } } -// NewRetryableHTTPClient returns a new instance of go-retryablehttp.Client +// NewRetryableHTTPClient returns a new instance of a retryable client // with a default configuration that supports Go SDK usage. func NewRetryableHTTPClient() *retryablehttp.Client { + return NewRetryableClientWithHTTPClient(nil) +} + +// NewRetryableClientWithHTTPClient will return a new instance of a +// retryable client, using "httpClient" as the embedded client used to +// invoke individual requests within the retry logic. +// If "httpClient" is passed in as nil, then a default HTTP client will be +// used as the embedded client instead. +func NewRetryableClientWithHTTPClient(httpClient *http.Client) *retryablehttp.Client { client := retryablehttp.NewClient() client.Logger = &httpLogger{} client.CheckRetry = IBMCloudSDKRetryPolicy client.Backoff = IBMCloudSDKBackoffPolicy client.ErrorHandler = retryablehttp.PassthroughErrorHandler - setMinimumTLSVersion(client.HTTPClient) + + if httpClient != nil { + // If a non-nil http client was passed in, then let's use that + // as our embedded client used to invoke individual requests. + client.HTTPClient = httpClient + } else { + // Otherwise, we'll use construct a default HTTP client and use that + client.HTTPClient = DefaultHTTPClient() + } + return client } diff --git a/v5/core/base_service_test.go b/v5/core/base_service_test.go index bd3340f..a3d937e 100644 --- a/v5/core/base_service_test.go +++ b/v5/core/base_service_test.go @@ -1107,6 +1107,14 @@ func TestDisableSSLVerification(t *testing.T) { assert.False(t, service.IsSSLDisabled()) service.DisableSSLVerification() assert.True(t, service.IsSSLDisabled()) + + // Try another test while setting the service client to nil + service, _ = NewBaseService(options) + service.Client = nil + assert.False(t, service.IsSSLDisabled()) + service.DisableSSLVerification() + assert.NotNil(t, service.Client) + assert.True(t, service.IsSSLDisabled()) } func TestDisableSSLVerificationWithRetries(t *testing.T) { @@ -1134,6 +1142,23 @@ func TestDisableSSLVerificationWithRetries(t *testing.T) { assert.True(t, service.IsSSLDisabled()) service.DisableRetries() assert.True(t, service.IsSSLDisabled()) + + // Verify that we can first enable retries, then disable SSL, etc. + service, _ = NewBaseService(options) + assert.False(t, service.IsSSLDisabled()) + assert.False(t, isRetryableClient(service.Client)) + + service.EnableRetries(0, 0) + assert.False(t, service.IsSSLDisabled()) + assert.True(t, isRetryableClient(service.Client)) + + service.DisableSSLVerification() + assert.True(t, service.IsSSLDisabled()) + assert.True(t, service.IsSSLDisabled()) + + service.DisableRetries() + assert.True(t, service.IsSSLDisabled()) + assert.False(t, isRetryableClient(service.Client)) } func TestRequestBasicAuth1(t *testing.T) { @@ -1679,23 +1704,108 @@ func TestSetUserAgent(t *testing.T) { assert.Equal(t, "my-user-agent", service.UserAgent) } +func getRetriesConfig(service *BaseService) (int, time.Duration) { + if isRetryableClient(service.Client) { + tr := service.Client.Transport.(*retryablehttp.RoundTripper) + return tr.Client.RetryMax, tr.Client.RetryWaitMax + } + + return -1, -1 * time.Second +} + func TestEnableRetries(t *testing.T) { - service, err := NewBaseService( - &ServiceOptions{ - Authenticator: &NoAuthAuthenticator{}, - }) + options := &ServiceOptions{ + Authenticator: &NoAuthAuthenticator{}, + } + service, err := NewBaseService(options) assert.Nil(t, err) assert.NotNil(t, service) - assert.NotNil(t, service.Client) + // Save the existing client for comparisons below. + client := service.Client + + // Enable retries, then verify the config + // and make sure our client survived. service.EnableRetries(5, 30*time.Second) + assert.True(t, isRetryableClient(service.Client)) + maxRetries, maxInterval := getRetriesConfig(service) + assert.Equal(t, 5, maxRetries) + assert.Equal(t, 30*time.Second, maxInterval) + assert.Equal(t, client, service.GetHTTPClient()) + + // Enable retries with a different config and verify. + service.EnableRetries(6, 60*time.Second) + assert.True(t, isRetryableClient(service.Client)) + maxRetries, maxInterval = getRetriesConfig(service) + assert.Equal(t, 6, maxRetries) + assert.Equal(t, 60*time.Second, maxInterval) + assert.Equal(t, client, service.GetHTTPClient()) + + // Disable retries and make sure the original client is still there. + service.DisableRetries() + assert.False(t, isRetryableClient(service.Client)) + assert.Equal(t, client, service.GetHTTPClient()) +} + +func TestClientWithRetries(t *testing.T) { + options := &ServiceOptions{ + Authenticator: &NoAuthAuthenticator{}, + } + service, err := NewBaseService(options) + assert.Nil(t, err) + assert.NotNil(t, service) assert.NotNil(t, service.Client) - assert.NotNil(t, service.Client.Transport) + // Create a customized client and set it on the service + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS13, + InsecureSkipVerify: true, + }, + }, + } + + service.SetHTTPClient(client) + actualClient := service.GetHTTPClient() + assert.Equal(t, client, actualClient) + assert.Equal(t, *client, *actualClient) + assert.Equal(t, client, service.Client) + + // Next, enable retries and make sure the client survived. + service.EnableRetries(4, 90*time.Second) + assert.True(t, isRetryableClient(service.Client)) + actualClient = service.GetHTTPClient() + assert.Equal(t, client, actualClient) + assert.Equal(t, *client, *actualClient) + + // Finally, disable retries and make sure + // we're left with the same client instance. service.DisableRetries() + assert.False(t, isRetryableClient(service.Client)) + actualClient = service.GetHTTPClient() + assert.Equal(t, client, actualClient) + assert.Equal(t, *client, *actualClient) + assert.Equal(t, client, service.Client) + + // Create a new service and perform the steps in a different order. + service, err = NewBaseService(options) + assert.Nil(t, err) + assert.NotNil(t, service) assert.NotNil(t, service.Client) - assert.NotNil(t, service.Client.Transport) + assert.NotEqual(t, client, service.Client) + + // Enable retries. + service.EnableRetries(0, 0) + assert.True(t, isRetryableClient(service.Client)) + + // Next, set our customized client on the service + service.SetHTTPClient(client) + assert.True(t, isRetryableClient(service.Client)) + actualClient = service.GetHTTPClient() + assert.Equal(t, client, actualClient) + assert.Equal(t, *client, *actualClient) } func TestSetEnableGzipCompression(t *testing.T) { @@ -1969,6 +2079,16 @@ func TestErrorMessage(t *testing.T) { "Internal Server Error") } +func getTLSVersion(service *BaseService) int { + var tlsVersion int = -1 + client := service.GetHTTPClient() + if client != nil { + tr := client.Transport.(*http.Transport) + tlsVersion = int(tr.TLSClientConfig.MinVersion) + } + return tlsVersion +} + func TestMinSSLVersion(t *testing.T) { service, err := NewBaseService( &ServiceOptions{ @@ -1979,8 +2099,7 @@ func TestMinSSLVersion(t *testing.T) { assert.NotNil(t, service.Client) // Check the default config. - minTLS := int(getClientTransportForSSL(service.Client).TLSClientConfig.MinVersion) - assert.Equal(t, minTLS, tls.VersionTLS12) + assert.Equal(t, getTLSVersion(service), tls.VersionTLS12) // Set a insecureClient with different value. insecureClient := &http.Client{} @@ -1990,11 +2109,9 @@ func TestMinSSLVersion(t *testing.T) { }, } service.SetHTTPClient(insecureClient) - minTLS = int(getClientTransportForSSL(service.Client).TLSClientConfig.MinVersion) - assert.Equal(t, minTLS, tls.VersionTLS12) + assert.Equal(t, getTLSVersion(service), tls.VersionTLS12) // Check retryable client config. service.EnableRetries(3, 30*time.Second) - minTLS = int(getClientTransportForSSL(service.Client).TLSClientConfig.MinVersion) - assert.Equal(t, minTLS, tls.VersionTLS12) + assert.Equal(t, getTLSVersion(service), tls.VersionTLS12) }