Skip to content

Commit

Permalink
add validation for tls private key and cert file values (#771)
Browse files Browse the repository at this point in the history
* add validation for tls private key and cert file values
  • Loading branch information
Devang Gaur authored May 18, 2021
1 parent 1d7e5b9 commit dc0b428
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 3 deletions.
20 changes: 17 additions & 3 deletions pkg/http-server/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,21 @@ func Start(port, certFile, privateKeyFile string) {
func (g *APIServer) start(routes []*Route, port, certFile, privateKeyFile string) {

var (
err error
logger = logging.GetDefaultLogger() // new logger
router = mux.NewRouter() // new router
err error
router = mux.NewRouter() // new router
)

logger.Info("registering routes...")

if privateKeyFile != "" || certFile != "" {
logger.Debugf("certfile is %s, privateKeyFile is %s", certFile, privateKeyFile)

if err := g.validateFiles(privateKeyFile, certFile); err != nil {
logger.Fatal(err)
}
}

// register all routes
for _, v := range routes {
logger.Info("Route ", v.verb, " - ", v.path)
Expand All @@ -72,19 +80,25 @@ func (g *APIServer) start(routes []*Route, port, certFile, privateKeyFile string
Handler: router,
}

message := make(chan string)
go func() {
var err error
if certFile != "" && privateKeyFile != "" {
// In case a certificate file is specified, the server support TLS
message <- "https server listening at port %v"
err = server.ListenAndServeTLS(certFile, privateKeyFile)
} else {
message <- "http server listening at port %v"
err = server.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
logger.Fatal(err)
}
}()
logger.Infof("http server listening at port %v", port)

logger.Infof(<-message, port)

close(message)

// Wait for interrupt signal to gracefully shutdown the server
quit := make(chan os.Signal, 1)
Expand Down
32 changes: 32 additions & 0 deletions pkg/http-server/validate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
Copyright (C) 2020 Accurics, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package httpserver

import "fmt"

func (g *APIServer) validateFiles(privateKeyFile, certFile string) error {
keylength := len(privateKeyFile)
certlength := len(certFile)

if keylength > 0 && certlength == 0 {
return fmt.Errorf("private key file provided but certficate file missing")
} else if keylength == 0 && certlength > 0 {
return fmt.Errorf("certificate file provided but private key file missing")
}

return nil
}
70 changes: 70 additions & 0 deletions pkg/http-server/validate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
Copyright (C) 2020 Accurics, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package httpserver

import (
"fmt"
"reflect"
"testing"
)

func TestValidateFiles(t *testing.T) {
server := APIServer{}
table := []struct {
name string
privateKeyFile string
certFile string
wantOutput interface{}
wantErr error
}{
{
name: "both file names provided",
privateKeyFile: "key",
certFile: "cert",
wantErr: nil,
},
{
name: "privatekey filename absent",
privateKeyFile: "",
certFile: "server.crt",
wantErr: fmt.Errorf("certificate file provided but private key file missing"),
},
{
name: "both file names blank",
privateKeyFile: "",
certFile: "",
wantErr: nil,
},
{
name: "cert filename absent",
privateKeyFile: "keyfile",
certFile: "",
wantErr: fmt.Errorf("private key file provided but certficate file missing"),
},
}

for _, tt := range table {
t.Run(tt.name, func(t *testing.T) {
gotErr := server.validateFiles(tt.privateKeyFile, tt.certFile)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
if tt.wantErr != nil && gotErr != nil && tt.wantErr.Error() != gotErr.Error() {
t.Errorf("error got: '%v', want: '%v'", gotErr, tt.wantErr)
}
}
})
}
}

0 comments on commit dc0b428

Please sign in to comment.