From 6015cdcadccddef15438dabdc027235b65f3d73f Mon Sep 17 00:00:00 2001 From: Anton Averchenkov Date: Thu, 14 Sep 2023 15:49:38 -0400 Subject: [PATCH] Allow for custom non-http transports --- client.go | 33 ++++++++++++++------ client_configuration.go | 41 +++++++++++++++++++++++++ client_configuration_test.go | 46 +++++++++++++++++++++++++++- generate/templates/client.handlebars | 33 ++++++++++++++------ 4 files changed, 132 insertions(+), 21 deletions(-) diff --git a/client.go b/client.go index 3e307b69..d0e01b2a 100644 --- a/client.go +++ b/client.go @@ -93,20 +93,33 @@ func newClient(configuration ClientConfiguration) (*Client, error) { } c.parsedBaseAddress = address - transport, ok := c.client.Transport.(*http.Transport) - if !ok { - return nil, fmt.Errorf("the configured base client's transport (%T) is not of type *http.Transport", c.client.Transport) - } - - // Adjust the dial context for unix domain socket addresses. if strings.HasPrefix(configuration.Address, "unix://") { - transport.DialContext = func(context.Context, string, string) (net.Conn, error) { - return net.Dial("unix", strings.TrimPrefix(configuration.Address, "unix://")) + // Adjust the dial context for unix domain socket addresses in the + // internal HTTP transport, if exists. + if httpTransport, ok := c.client.Transport.(*http.Transport); ok { + httpTransport.DialContext = func(context.Context, string, string) (net.Conn, error) { + return net.Dial("unix", strings.TrimPrefix(configuration.Address, "unix://")) + } + } else { + return nil, fmt.Errorf( + "the configured base client's transport (%T) is not of type *http.Transport and cannot be used with the unix:// address", + c.client.Transport, + ) } } - if err := configuration.TLS.applyTo(transport.TLSClientConfig); err != nil { - return nil, err + if !configuration.TLS.empty() { + // Apply TLS configuration to the internal HTTP transport, if exists. + if httpTransport, ok := c.client.Transport.(*http.Transport); ok { + if err := configuration.TLS.applyTo(httpTransport.TLSClientConfig); err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf( + "the configured base client's transport (%T) is not of type *http.Transport and cannot be used with TLS configuration", + c.client.Transport, + ) + } } c.Auth = Auth{ diff --git a/client_configuration.go b/client_configuration.go index 873da2b1..9183ddfc 100644 --- a/client_configuration.go +++ b/client_configuration.go @@ -359,6 +359,47 @@ func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bo return false, nil } +// empty returns true if t is equivalent to an empty TLSConfiguration{} object. +func (t *TLSConfiguration) empty() bool { + if t.ServerCertificate.FromFile != "" { + return false + } + + if len(t.ServerCertificate.FromBytes) != 0 { + return false + } + + if t.ServerCertificate.FromDirectory != "" { + return false + } + + if t.ClientCertificate.FromFile != "" { + return false + } + + if len(t.ClientCertificate.FromBytes) != 0 { + return false + } + + if t.ClientCertificateKey.FromFile != "" { + return false + } + + if len(t.ClientCertificateKey.FromBytes) != 0 { + return false + } + + if t.ServerName != "" { + return false + } + + if t.InsecureSkipVerify { + return false + } + + return true +} + // applyTo applies the user-defined TLS configuration to the given client's // *tls.Config pointer; it is used to configure the internal http.Client func (from *TLSConfiguration) applyTo(to *tls.Config) error { diff --git a/client_configuration_test.go b/client_configuration_test.go index 70684479..03fc0ccb 100644 --- a/client_configuration_test.go +++ b/client_configuration_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/require" ) -func Test_walkconfigurationfields(t *testing.T) { +func Test_walkConfigurationFields(t *testing.T) { var ( actual = []string{} expected = []string{ @@ -44,3 +44,47 @@ func Test_walkconfigurationfields(t *testing.T) { require.Subset(t, actual, expected) } + +func Test_TLSConfiguration_empty(t *testing.T) { + cases := map[string]struct { + tls TLSConfiguration + expectedEmpty bool + }{ + "empty": { + tls: TLSConfiguration{}, + expectedEmpty: true, + }, + "with-server-name": { + tls: TLSConfiguration{ + ServerName: "my-server", + }, + expectedEmpty: false, + }, + "with-server-certificate-from-file": { + tls: TLSConfiguration{ + ServerCertificate: ServerCertificateEntry{ + FromFile: "./cert.pem", + }, + }, + expectedEmpty: false, + }, + "with-server-certificate-from-bytes": { + tls: TLSConfiguration{ + ServerCertificate: ServerCertificateEntry{ + FromBytes: []byte{1, 1, 2, 3, 5}, + }, + }, + expectedEmpty: false, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + if tc.expectedEmpty { + require.True(t, tc.tls.empty()) + } else { + require.False(t, tc.tls.empty()) + } + }) + } +} diff --git a/generate/templates/client.handlebars b/generate/templates/client.handlebars index e4a4399a..eba1c343 100644 --- a/generate/templates/client.handlebars +++ b/generate/templates/client.handlebars @@ -87,20 +87,33 @@ func newClient(configuration ClientConfiguration) (*Client, error) { } c.parsedBaseAddress = address - transport, ok := c.client.Transport.(*http.Transport); - if !ok { - return nil, fmt.Errorf("the configured base client's transport (%T) is not of type *http.Transport", c.client.Transport) - } - - // Adjust the dial context for unix domain socket addresses. if strings.HasPrefix(configuration.Address, "unix://") { - transport.DialContext = func(context.Context, string, string) (net.Conn, error) { - return net.Dial("unix", strings.TrimPrefix(configuration.Address, "unix://")) + // Adjust the dial context for unix domain socket addresses in the + // internal HTTP transport, if exists. + if httpTransport, ok := c.client.Transport.(*http.Transport); ok { + httpTransport.DialContext = func(context.Context, string, string) (net.Conn, error) { + return net.Dial("unix", strings.TrimPrefix(configuration.Address, "unix://")) + } + } else { + return nil, fmt.Errorf( + "the configured base client's transport (%T) is not of type *http.Transport and cannot be used with the unix:// address", + c.client.Transport, + ) } } - if err := configuration.TLS.applyTo(transport.TLSClientConfig); err != nil { - return nil, err + if !configuration.TLS.empty() { + // Apply TLS configuration to the internal HTTP transport, if exists. + if httpTransport, ok := c.client.Transport.(*http.Transport); ok { + if err := configuration.TLS.applyTo(httpTransport.TLSClientConfig); err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf( + "the configured base client's transport (%T) is not of type *http.Transport and cannot be used with TLS configuration", + c.client.Transport, + ) + } } {{#with apiInfo}}