Skip to content

Commit

Permalink
Allow for custom non-http transports (#233)
Browse files Browse the repository at this point in the history
  • Loading branch information
averche authored Sep 14, 2023
1 parent c2deccd commit 0bdf670
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 21 deletions.
33 changes: 23 additions & 10 deletions client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 41 additions & 0 deletions client_configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,47 @@ func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bo
return false, nil
}

// empty returns true if t is equivalent to an empty TLSConfiguration{} object.
func (t *TLSConfiguration) empty() bool {
if t.ServerCertificate.FromFile != "" {
return false
}

if len(t.ServerCertificate.FromBytes) != 0 {
return false
}

if t.ServerCertificate.FromDirectory != "" {
return false
}

if t.ClientCertificate.FromFile != "" {
return false
}

if len(t.ClientCertificate.FromBytes) != 0 {
return false
}

if t.ClientCertificateKey.FromFile != "" {
return false
}

if len(t.ClientCertificateKey.FromBytes) != 0 {
return false
}

if t.ServerName != "" {
return false
}

if t.InsecureSkipVerify {
return false
}

return true
}

// applyTo applies the user-defined TLS configuration to the given client's
// *tls.Config pointer; it is used to configure the internal http.Client
func (from *TLSConfiguration) applyTo(to *tls.Config) error {
Expand Down
46 changes: 45 additions & 1 deletion client_configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require"
)

func Test_walkconfigurationfields(t *testing.T) {
func Test_walkConfigurationFields(t *testing.T) {
var (
actual = []string{}
expected = []string{
Expand Down Expand Up @@ -44,3 +44,47 @@ func Test_walkconfigurationfields(t *testing.T) {

require.Subset(t, actual, expected)
}

func Test_TLSConfiguration_empty(t *testing.T) {
cases := map[string]struct {
tls TLSConfiguration
expectedEmpty bool
}{
"empty": {
tls: TLSConfiguration{},
expectedEmpty: true,
},
"with-server-name": {
tls: TLSConfiguration{
ServerName: "my-server",
},
expectedEmpty: false,
},
"with-server-certificate-from-file": {
tls: TLSConfiguration{
ServerCertificate: ServerCertificateEntry{
FromFile: "./cert.pem",
},
},
expectedEmpty: false,
},
"with-server-certificate-from-bytes": {
tls: TLSConfiguration{
ServerCertificate: ServerCertificateEntry{
FromBytes: []byte{1, 1, 2, 3, 5},
},
},
expectedEmpty: false,
},
}

for name, tc := range cases {
t.Run(name, func(t *testing.T) {
if tc.expectedEmpty {
require.True(t, tc.tls.empty())
} else {
require.False(t, tc.tls.empty())
}
})
}
}
33 changes: 23 additions & 10 deletions generate/templates/client.handlebars
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,33 @@ func newClient(configuration ClientConfiguration) (*Client, error) {
}
c.parsedBaseAddress = address

transport, ok := c.client.Transport.(*http.Transport);
if !ok {
return nil, fmt.Errorf("the configured base client's transport (%T) is not of type *http.Transport", c.client.Transport)
}

// Adjust the dial context for unix domain socket addresses.
if strings.HasPrefix(configuration.Address, "unix://") {
transport.DialContext = func(context.Context, string, string) (net.Conn, error) {
return net.Dial("unix", strings.TrimPrefix(configuration.Address, "unix://"))
// Adjust the dial context for unix domain socket addresses in the
// internal HTTP transport, if exists.
if httpTransport, ok := c.client.Transport.(*http.Transport); ok {
httpTransport.DialContext = func(context.Context, string, string) (net.Conn, error) {
return net.Dial("unix", strings.TrimPrefix(configuration.Address, "unix://"))
}
} else {
return nil, fmt.Errorf(
"the configured base client's transport (%T) is not of type *http.Transport and cannot be used with the unix:// address",
c.client.Transport,
)
}
}

if err := configuration.TLS.applyTo(transport.TLSClientConfig); err != nil {
return nil, err
if !configuration.TLS.empty() {
// Apply TLS configuration to the internal HTTP transport, if exists.
if httpTransport, ok := c.client.Transport.(*http.Transport); ok {
if err := configuration.TLS.applyTo(httpTransport.TLSClientConfig); err != nil {
return nil, err
}
} else {
return nil, fmt.Errorf(
"the configured base client's transport (%T) is not of type *http.Transport and cannot be used with TLS configuration",
c.client.Transport,
)
}
}

{{#with apiInfo}}
Expand Down

0 comments on commit 0bdf670

Please sign in to comment.