From 8716acb0ab758843852f807976c8f587ad10ff54 Mon Sep 17 00:00:00 2001 From: Mohamed Bana Date: Fri, 15 Feb 2019 17:51:54 +0000 Subject: [PATCH] `Echo.StartTLS`: accept `string` or `[]byte` as parameters. (#1277) If `certFile` or `keyFile` is `string` the values are treated as file paths. If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is. --- echo.go | 32 +++++++++++++++++++++---- echo_test.go | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 5 deletions(-) diff --git a/echo.go b/echo.go index 6a41e5611..4f8209b73 100644 --- a/echo.go +++ b/echo.go @@ -43,6 +43,7 @@ import ( "errors" "fmt" "io" + "io/ioutil" stdLog "log" "net" "net/http" @@ -269,6 +270,7 @@ var ( ErrRendererNotRegistered = errors.New("renderer not registered") ErrInvalidRedirectCode = errors.New("invalid redirect status code") ErrCookieNotFound = errors.New("cookie not found") + ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") ) // Error handlers @@ -605,20 +607,40 @@ func (e *Echo) Start(address string) error { } // StartTLS starts an HTTPS server. -func (e *Echo) StartTLS(address string, certFile, keyFile string) (err error) { - if certFile == "" || keyFile == "" { - return errors.New("invalid tls configuration") +// If `certFile` or `keyFile` is `string` the values are treated as file paths. +// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is. +func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err error) { + var cert []byte + if cert, err = filepathOrContent(certFile); err != nil { + return } + + var key []byte + if key, err = filepathOrContent(keyFile); err != nil { + return + } + s := e.TLSServer s.TLSConfig = new(tls.Config) s.TLSConfig.Certificates = make([]tls.Certificate, 1) - s.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { + if s.TLSConfig.Certificates[0], err = tls.X509KeyPair(cert, key); err != nil { return } + return e.startTLS(address) } +func filepathOrContent(fileOrContent interface{}) (content []byte, err error) { + switch v := fileOrContent.(type) { + case string: + return ioutil.ReadFile(v) + case []byte: + return v, nil + default: + return nil, ErrInvalidCertOrKeyType + } +} + // StartAutoTLS starts an HTTPS server using certificates automatically installed from https://letsencrypt.org. func (e *Echo) StartAutoTLS(address string) error { s := e.TLSServer diff --git a/echo_test.go b/echo_test.go index 3ac10a732..dec713ece 100644 --- a/echo_test.go +++ b/echo_test.go @@ -4,6 +4,7 @@ import ( "bytes" stdContext "context" "errors" + "io/ioutil" "net/http" "net/http/httptest" "reflect" @@ -12,6 +13,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type ( @@ -428,6 +430,71 @@ func TestEchoStartTLS(t *testing.T) { e.Close() } +func TestEchoStartTLSByteString(t *testing.T) { + cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") + require.NoError(t, err) + key, err := ioutil.ReadFile("_fixture/certs/key.pem") + require.NoError(t, err) + + testCases := []struct { + cert interface{} + key interface{} + expectedErr error + name string + }{ + { + cert: "_fixture/certs/cert.pem", + key: "_fixture/certs/key.pem", + expectedErr: nil, + name: `ValidCertAndKeyFilePath`, + }, + { + cert: cert, + key: key, + expectedErr: nil, + name: `ValidCertAndKeyByteString`, + }, + { + cert: cert, + key: 1, + expectedErr: ErrInvalidCertOrKeyType, + name: `InvalidKeyType`, + }, + { + cert: 0, + key: key, + expectedErr: ErrInvalidCertOrKeyType, + name: `InvalidCertType`, + }, + { + cert: 0, + key: 1, + expectedErr: ErrInvalidCertOrKeyType, + name: `InvalidCertAndKeyTypes`, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.name, func(t *testing.T) { + e := New() + e.HideBanner = true + + go func() { + err := e.StartTLS(":0", test.cert, test.key) + if test.expectedErr != nil { + require.EqualError(t, err, test.expectedErr.Error()) + } else if err != http.ErrServerClosed { // Prevent the test to fail after closing the servers + require.NoError(t, err) + } + }() + time.Sleep(200 * time.Millisecond) + + require.NoError(t, e.Close()) + }) + } +} + func TestEchoStartAutoTLS(t *testing.T) { e := New() errChan := make(chan error, 0)