Skip to content

Commit

Permalink
Echo.StartTLS: accept string or []byte as parameters. (#1277)
Browse files Browse the repository at this point in the history
If `certFile` or `keyFile` is `string` the values are treated as file paths.
If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is.
  • Loading branch information
mbana authored and vishr committed Feb 15, 2019
1 parent 3d73323 commit 8716acb
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 5 deletions.
32 changes: 27 additions & 5 deletions echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
stdLog "log"
"net"
"net/http"
Expand Down Expand Up @@ -269,6 +270,7 @@ var (
ErrRendererNotRegistered = errors.New("renderer not registered")
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
ErrCookieNotFound = errors.New("cookie not found")
ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
)

// Error handlers
Expand Down Expand Up @@ -605,20 +607,40 @@ func (e *Echo) Start(address string) error {
}

// StartTLS starts an HTTPS server.
func (e *Echo) StartTLS(address string, certFile, keyFile string) (err error) {
if certFile == "" || keyFile == "" {
return errors.New("invalid tls configuration")
// If `certFile` or `keyFile` is `string` the values are treated as file paths.
// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is.
func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err error) {
var cert []byte
if cert, err = filepathOrContent(certFile); err != nil {
return
}

var key []byte
if key, err = filepathOrContent(keyFile); err != nil {
return
}

s := e.TLSServer
s.TLSConfig = new(tls.Config)
s.TLSConfig.Certificates = make([]tls.Certificate, 1)
s.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
if s.TLSConfig.Certificates[0], err = tls.X509KeyPair(cert, key); err != nil {
return
}

return e.startTLS(address)
}

func filepathOrContent(fileOrContent interface{}) (content []byte, err error) {
switch v := fileOrContent.(type) {
case string:
return ioutil.ReadFile(v)
case []byte:
return v, nil
default:
return nil, ErrInvalidCertOrKeyType
}
}

// StartAutoTLS starts an HTTPS server using certificates automatically installed from https://letsencrypt.org.
func (e *Echo) StartAutoTLS(address string) error {
s := e.TLSServer
Expand Down
67 changes: 67 additions & 0 deletions echo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
stdContext "context"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"reflect"
Expand All @@ -12,6 +13,7 @@ import (
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type (
Expand Down Expand Up @@ -428,6 +430,71 @@ func TestEchoStartTLS(t *testing.T) {
e.Close()
}

func TestEchoStartTLSByteString(t *testing.T) {
cert, err := ioutil.ReadFile("_fixture/certs/cert.pem")
require.NoError(t, err)
key, err := ioutil.ReadFile("_fixture/certs/key.pem")
require.NoError(t, err)

testCases := []struct {
cert interface{}
key interface{}
expectedErr error
name string
}{
{
cert: "_fixture/certs/cert.pem",
key: "_fixture/certs/key.pem",
expectedErr: nil,
name: `ValidCertAndKeyFilePath`,
},
{
cert: cert,
key: key,
expectedErr: nil,
name: `ValidCertAndKeyByteString`,
},
{
cert: cert,
key: 1,
expectedErr: ErrInvalidCertOrKeyType,
name: `InvalidKeyType`,
},
{
cert: 0,
key: key,
expectedErr: ErrInvalidCertOrKeyType,
name: `InvalidCertType`,
},
{
cert: 0,
key: 1,
expectedErr: ErrInvalidCertOrKeyType,
name: `InvalidCertAndKeyTypes`,
},
}

for _, test := range testCases {
test := test
t.Run(test.name, func(t *testing.T) {
e := New()
e.HideBanner = true

go func() {
err := e.StartTLS(":0", test.cert, test.key)
if test.expectedErr != nil {
require.EqualError(t, err, test.expectedErr.Error())
} else if err != http.ErrServerClosed { // Prevent the test to fail after closing the servers
require.NoError(t, err)
}
}()
time.Sleep(200 * time.Millisecond)

require.NoError(t, e.Close())
})
}
}

func TestEchoStartAutoTLS(t *testing.T) {
e := New()
errChan := make(chan error, 0)
Expand Down

0 comments on commit 8716acb

Please sign in to comment.