From c66f67f6ca5b5d451e57ffc9c2aa0228edbdcf70 Mon Sep 17 00:00:00 2001 From: aman bansal Date: Thu, 12 Nov 2020 21:57:15 +0530 Subject: [PATCH] feat(dgraph): enabling TLS config in http zero (#6691) (#6867) --- dgraph/cmd/zero/http.go | 79 +++++++++-- dgraph/cmd/zero/run.go | 8 +- go.mod | 1 + go.sum | 1 + .../all_routes_tls/all_routes_tls_test.go | 134 ++++++++++++++++++ .../all_routes_tls/docker-compose.yml | 37 +++++ tlstest/zero_https/no_tls/docker-compose.yml | 33 +++++ tlstest/zero_https/no_tls/no_tls_test.go | 58 ++++++++ 8 files changed, 337 insertions(+), 14 deletions(-) create mode 100644 tlstest/zero_https/all_routes_tls/all_routes_tls_test.go create mode 100644 tlstest/zero_https/all_routes_tls/docker-compose.yml create mode 100644 tlstest/zero_https/no_tls/docker-compose.yml create mode 100644 tlstest/zero_https/no_tls/no_tls_test.go diff --git a/dgraph/cmd/zero/http.go b/dgraph/cmd/zero/http.go index aecb76bc2b0..8d339bb3e8e 100644 --- a/dgraph/cmd/zero/http.go +++ b/dgraph/cmd/zero/http.go @@ -17,17 +17,22 @@ package zero import ( + "bufio" "context" + "crypto/tls" "fmt" + "io" "net" "net/http" "strconv" + "strings" "time" "github.com/dgraph-io/dgraph/protos/pb" "github.com/dgraph-io/dgraph/x" "github.com/gogo/protobuf/jsonpb" "github.com/golang/glog" + "github.com/soheilhy/cmux" ) // intFromQueryParam checks for name as a query param, converts it to uint64 and returns it. @@ -239,23 +244,71 @@ func (st *state) pingResponse(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("OK")) } -func (st *state) serveHTTP(l net.Listener) { - srv := &http.Server{ - ReadTimeout: 10 * time.Second, - WriteTimeout: 600 * time.Second, - IdleTimeout: 2 * time.Minute, - } +func (st *state) startListenHttpAndHttps(l net.Listener, tlsConf *tls.Config) { + m := cmux.New(l) + startServers(m, tlsConf) go func() { defer st.zero.closer.Done() - err := srv.Serve(l) - glog.Errorf("Stopped taking more http(s) requests. Err: %v", err) - ctx, cancel := context.WithTimeout(context.Background(), 630*time.Second) - defer cancel() - err = srv.Shutdown(ctx) - glog.Infoln("All http(s) requests finished.") + err := m.Serve() if err != nil { - glog.Errorf("Http(s) shutdown err: %v", err) + glog.Errorf("error from cmux serve: %v", err) } }() } + +func startServers(m cmux.CMux, tlsConf *tls.Config) { + httpRule := m.Match(func(r io.Reader) bool { + // no tls config is provided. http is being used. + if tlsConf == nil { + return true + } + path, ok := parseRequestPath(r) + if !ok { + // not able to parse the request. Let it be resolved via TLS + return false + } + // health endpoint will always be available over http. + // This is necessary for orchestration. It needs to be worked for + // monitoring tools which operate without authentication. + if strings.HasPrefix(path, "/health") { + return true + } + return false + }) + go startListen(httpRule) + + // if tls is enabled, make tls encryption based connections as default + if tlsConf != nil { + httpsRule := m.Match(cmux.Any()) + //this is chained listener. tls listener will decrypt the message and send it in plain text to HTTP server + go startListen(tls.NewListener(httpsRule, tlsConf)) + } +} + +func startListen(l net.Listener) { + srv := &http.Server{ + ReadTimeout: 10 * time.Second, + WriteTimeout: 600 * time.Second, + IdleTimeout: 2 * time.Minute, + } + + err := srv.Serve(l) + glog.Errorf("Stopped taking more http(s) requests. Err: %v", err) + ctx, cancel := context.WithTimeout(context.Background(), 630*time.Second) + defer cancel() + err = srv.Shutdown(ctx) + glog.Infoln("All http(s) requests finished.") + if err != nil { + glog.Errorf("Http(s) shutdown err: %v", err) + } +} + +func parseRequestPath(r io.Reader) (path string, ok bool) { + request, err := http.ReadRequest(bufio.NewReader(r)) + if err != nil { + return "", false + } + + return request.URL.Path, true +} diff --git a/dgraph/cmd/zero/run.go b/dgraph/cmd/zero/run.go index b188553f546..8111251c5fc 100644 --- a/dgraph/cmd/zero/run.go +++ b/dgraph/cmd/zero/run.go @@ -104,6 +104,10 @@ instances to achieve high-availability. " exporter does not support annotation logs and would discard them.") flag.Bool("ludicrous_mode", false, "Run zero in ludicrous mode") flag.String("enterprise_license", "", "Path to the enterprise license file.") + // TLS configurations + flag.String("tls_dir", "", "Path to directory that has TLS certificates and keys.") + flag.Bool("tls_use_system_ca", true, "Include System CA into CA Certs.") + flag.String("tls_client_auth", "VERIFYIFGIVEN", "Enable TLS client authentication") // Cache flags flag.Int64("cache_mb", 0, "Total size of cache (in MB) to be used in zero.") @@ -310,7 +314,9 @@ func run() { // Initialize the servers. var st state st.serveGRPC(grpcListener, store) - st.serveHTTP(httpListener) + tlsCfg, err := x.LoadServerTLSConfig(Zero.Conf, "node.crt", "node.key") + x.Check(err) + st.startListenHttpAndHttps(httpListener, tlsCfg) http.HandleFunc("/health", st.pingResponse) http.HandleFunc("/state", st.getState) diff --git a/go.mod b/go.mod index 3d079267c9a..e16e5d4bb80 100644 --- a/go.mod +++ b/go.mod @@ -49,6 +49,7 @@ require ( github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90 // indirect github.com/prometheus/common v0.4.1 // indirect github.com/prometheus/procfs v0.0.0-20190517135640-51af30a78b0e // indirect + github.com/soheilhy/cmux v0.1.4 github.com/spf13/cast v1.3.0 github.com/spf13/cobra v0.0.5 github.com/spf13/pflag v1.0.3 diff --git a/go.sum b/go.sum index 01091828463..8743649250f 100644 --- a/go.sum +++ b/go.sum @@ -383,6 +383,7 @@ github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1 github.com/smartystreets/goconvey v0.0.0-20190330032615-68dc04aab96a/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/soheilhy/cmux v0.1.4 h1:0HKaf1o97UwFjHH9o5XsHUOF+tqmdA7KEzXLpiyaw0E= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= diff --git a/tlstest/zero_https/all_routes_tls/all_routes_tls_test.go b/tlstest/zero_https/all_routes_tls/all_routes_tls_test.go new file mode 100644 index 00000000000..1f58b2d45f1 --- /dev/null +++ b/tlstest/zero_https/all_routes_tls/all_routes_tls_test.go @@ -0,0 +1,134 @@ +package all_routes_tls + +import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + "net/http" + "strings" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" +) + +type testCase struct { + url string + statusCode int + response string +} + +var testCasesHttp = []testCase{ + { + url: "http://localhost:6180/health", + response: "OK", + statusCode: 200, + }, + { + url: "http://localhost:6180/state", + response: "Client sent an HTTP request to an HTTPS server.\n", + statusCode: 400, + }, + { + url: "http://localhost:6180/removeNode?id=2&group=0", + response: "Client sent an HTTP request to an HTTPS server.\n", + statusCode: 400, + }, +} + +func TestZeroWithAllRoutesTLSWithHTTPClient(t *testing.T) { + client := http.Client{ + Timeout: time.Second * 10, + } + defer client.CloseIdleConnections() + for _, test := range testCasesHttp { + request, err := http.NewRequest("GET", test.url, nil) + require.NoError(t, err) + do, err := client.Do(request) + require.NoError(t, err) + if do != nil && do.StatusCode != test.statusCode { + t.Fatalf("status code is not same. Got: %d Expected: %d", do.StatusCode, test.statusCode) + } + + body := readResponseBody(t, do) + if test.response != string(body) { + t.Fatalf("response is not same. Got: %s Expected: %s", string(body), test.response) + } + } +} + +var testCasesHttps = []testCase{ + { + url: "https://localhost:6180/health", + response: "OK", + statusCode: 200, + }, + { + url: "https://localhost:6180/state", + response: "\"id\":\"1\",\"addr\":\"zero1:5180\",\"leader\":true", + statusCode: 200, + }, +} + +func TestZeroWithAllRoutesTLSWithTLSClient(t *testing.T) { + pool, err := generateCertPool("../../tls/ca.crt", true) + require.NoError(t, err) + + tlsCfg := &tls.Config{RootCAs: pool, ServerName: "localhost", InsecureSkipVerify: true} + tr := &http.Transport{ + IdleConnTimeout: 30 * time.Second, + DisableCompression: true, + TLSClientConfig: tlsCfg, + } + client := http.Client{ + Transport: tr, + } + + defer client.CloseIdleConnections() + for _, test := range testCasesHttps { + request, err := http.NewRequest("GET", test.url, nil) + require.NoError(t, err) + do, err := client.Do(request) + require.NoError(t, err) + if do != nil && do.StatusCode != test.statusCode { + t.Fatalf("status code is not same. Got: %d Expected: %d", do.StatusCode, test.statusCode) + } + + body := readResponseBody(t, do) + if !strings.Contains(string(body), test.response) { + t.Fatalf("response is not same. Got: %s Expected: %s", string(body), test.response) + } + } +} + +func readResponseBody(t *testing.T, do *http.Response) []byte { + defer func() { _ = do.Body.Close() }() + body, err := ioutil.ReadAll(do.Body) + require.NoError(t, err) + return body +} + +func generateCertPool(certPath string, useSystemCA bool) (*x509.CertPool, error) { + var pool *x509.CertPool + if useSystemCA { + var err error + if pool, err = x509.SystemCertPool(); err != nil { + return nil, err + } + } else { + pool = x509.NewCertPool() + } + + if len(certPath) > 0 { + caFile, err := ioutil.ReadFile(certPath) + if err != nil { + return nil, err + } + if !pool.AppendCertsFromPEM(caFile) { + return nil, errors.Errorf("error reading CA file %q", certPath) + } + } + + return pool, nil +} diff --git a/tlstest/zero_https/all_routes_tls/docker-compose.yml b/tlstest/zero_https/all_routes_tls/docker-compose.yml new file mode 100644 index 00000000000..0ef45e50f4a --- /dev/null +++ b/tlstest/zero_https/all_routes_tls/docker-compose.yml @@ -0,0 +1,37 @@ +version: "3.5" +services: + alpha1: + image: dgraph/dgraph:latest + container_name: alpha1 + working_dir: /data/alpha1 + labels: + cluster: test + ports: + - 8180:8180 + - 9180:9180 + volumes: + - type: bind + source: $GOPATH/bin + target: /gobin + read_only: true + command: /gobin/dgraph alpha -o 100 --my=alpha1:7180 --zero=zero1:5180 --logtostderr -v=2 --whitelist=10.0.0.0/8,172.16.0.0/12,192.168.0.0/16 + zero1: + image: dgraph/dgraph:latest + container_name: zero1 + working_dir: /data/zero1 + labels: + cluster: test + ports: + - 5180:5180 + - 6180:6180 + volumes: + - type: bind + source: $GOPATH/bin + target: /gobin + read_only: true + - type: bind + source: ../../tls + target: /dgraph-tls + read_only: true + command: /gobin/dgraph zero -o 100 --idx=1 --my=zero1:5180 --tls_dir /dgraph-tls -v=2 --bindall +volumes: {} \ No newline at end of file diff --git a/tlstest/zero_https/no_tls/docker-compose.yml b/tlstest/zero_https/no_tls/docker-compose.yml new file mode 100644 index 00000000000..af0c0ba8277 --- /dev/null +++ b/tlstest/zero_https/no_tls/docker-compose.yml @@ -0,0 +1,33 @@ +version: "3.5" +services: + alpha1: + image: dgraph/dgraph:latest + container_name: alpha1 + working_dir: /data/alpha1 + labels: + cluster: test + ports: + - 8180:8180 + - 9180:9180 + volumes: + - type: bind + source: $GOPATH/bin + target: /gobin + read_only: true + command: /gobin/dgraph alpha -o 100 --my=alpha1:7180 --zero=zero1:5180 --logtostderr -v=2 --whitelist=10.0.0.0/8,172.16.0.0/12,192.168.0.0/16 + zero1: + image: dgraph/dgraph:latest + container_name: zero1 + working_dir: /data/zero1 + labels: + cluster: test + ports: + - 5180:5180 + - 6180:6180 + volumes: + - type: bind + source: $GOPATH/bin + target: /gobin + read_only: true + command: /gobin/dgraph zero -o 100 --idx=1 --my=zero1:5180 -v=2 --bindall +volumes: {} \ No newline at end of file diff --git a/tlstest/zero_https/no_tls/no_tls_test.go b/tlstest/zero_https/no_tls/no_tls_test.go new file mode 100644 index 00000000000..19d7ceccb22 --- /dev/null +++ b/tlstest/zero_https/no_tls/no_tls_test.go @@ -0,0 +1,58 @@ +package no_tls + +import ( + "io/ioutil" + "net/http" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type testCase struct { + url string + statusCode int + response string +} + +var testCasesHttp = []testCase{ + { + url: "http://localhost:6180/health", + response: "OK", + statusCode: 200, + }, + { + url: "http://localhost:6180/state", + response: "\"id\":\"1\",\"addr\":\"zero1:5180\",\"leader\":true", + statusCode: 200, + }, +} + +func TestZeroWithNoTLS(t *testing.T) { + client := http.Client{ + Timeout: time.Second * 10, + } + defer client.CloseIdleConnections() + for _, test := range testCasesHttp { + request, err := http.NewRequest("GET", test.url, nil) + require.NoError(t, err) + do, err := client.Do(request) + require.NoError(t, err) + if do != nil && do.StatusCode != test.statusCode { + t.Fatalf("status code is not same. Got: %d Expected: %d", do.StatusCode, test.statusCode) + } + + body := readResponseBody(t, do) + if !strings.Contains(string(body), test.response) { + t.Fatalf("response is not same. Got: %s Expected: %s", string(body), test.response) + } + } +} + +func readResponseBody(t *testing.T, do *http.Response) []byte { + defer func() { _ = do.Body.Close() }() + body, err := ioutil.ReadAll(do.Body) + require.NoError(t, err) + return body +}