Skip to content

Commit

Permalink
feat(dgraph): enabling TLS config in http zero (#6691) (#6867)
Browse files Browse the repository at this point in the history
  • Loading branch information
aman-bansal authored Nov 12, 2020
1 parent 046accd commit c66f67f
Show file tree
Hide file tree
Showing 8 changed files with 337 additions and 14 deletions.
79 changes: 66 additions & 13 deletions dgraph/cmd/zero/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,22 @@
package zero

import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"time"

"github.com/dgraph-io/dgraph/protos/pb"
"github.com/dgraph-io/dgraph/x"
"github.com/gogo/protobuf/jsonpb"
"github.com/golang/glog"
"github.com/soheilhy/cmux"
)

// intFromQueryParam checks for name as a query param, converts it to uint64 and returns it.
Expand Down Expand Up @@ -239,23 +244,71 @@ func (st *state) pingResponse(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("OK"))
}

func (st *state) serveHTTP(l net.Listener) {
srv := &http.Server{
ReadTimeout: 10 * time.Second,
WriteTimeout: 600 * time.Second,
IdleTimeout: 2 * time.Minute,
}
func (st *state) startListenHttpAndHttps(l net.Listener, tlsConf *tls.Config) {
m := cmux.New(l)
startServers(m, tlsConf)

go func() {
defer st.zero.closer.Done()
err := srv.Serve(l)
glog.Errorf("Stopped taking more http(s) requests. Err: %v", err)
ctx, cancel := context.WithTimeout(context.Background(), 630*time.Second)
defer cancel()
err = srv.Shutdown(ctx)
glog.Infoln("All http(s) requests finished.")
err := m.Serve()
if err != nil {
glog.Errorf("Http(s) shutdown err: %v", err)
glog.Errorf("error from cmux serve: %v", err)
}
}()
}

func startServers(m cmux.CMux, tlsConf *tls.Config) {
httpRule := m.Match(func(r io.Reader) bool {
// no tls config is provided. http is being used.
if tlsConf == nil {
return true
}
path, ok := parseRequestPath(r)
if !ok {
// not able to parse the request. Let it be resolved via TLS
return false
}
// health endpoint will always be available over http.
// This is necessary for orchestration. It needs to be worked for
// monitoring tools which operate without authentication.
if strings.HasPrefix(path, "/health") {
return true
}
return false
})
go startListen(httpRule)

// if tls is enabled, make tls encryption based connections as default
if tlsConf != nil {
httpsRule := m.Match(cmux.Any())
//this is chained listener. tls listener will decrypt the message and send it in plain text to HTTP server
go startListen(tls.NewListener(httpsRule, tlsConf))
}
}

func startListen(l net.Listener) {
srv := &http.Server{
ReadTimeout: 10 * time.Second,
WriteTimeout: 600 * time.Second,
IdleTimeout: 2 * time.Minute,
}

err := srv.Serve(l)
glog.Errorf("Stopped taking more http(s) requests. Err: %v", err)
ctx, cancel := context.WithTimeout(context.Background(), 630*time.Second)
defer cancel()
err = srv.Shutdown(ctx)
glog.Infoln("All http(s) requests finished.")
if err != nil {
glog.Errorf("Http(s) shutdown err: %v", err)
}
}

func parseRequestPath(r io.Reader) (path string, ok bool) {
request, err := http.ReadRequest(bufio.NewReader(r))
if err != nil {
return "", false
}

return request.URL.Path, true
}
8 changes: 7 additions & 1 deletion dgraph/cmd/zero/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ instances to achieve high-availability.
" exporter does not support annotation logs and would discard them.")
flag.Bool("ludicrous_mode", false, "Run zero in ludicrous mode")
flag.String("enterprise_license", "", "Path to the enterprise license file.")
// TLS configurations
flag.String("tls_dir", "", "Path to directory that has TLS certificates and keys.")
flag.Bool("tls_use_system_ca", true, "Include System CA into CA Certs.")
flag.String("tls_client_auth", "VERIFYIFGIVEN", "Enable TLS client authentication")

// Cache flags
flag.Int64("cache_mb", 0, "Total size of cache (in MB) to be used in zero.")
Expand Down Expand Up @@ -310,7 +314,9 @@ func run() {
// Initialize the servers.
var st state
st.serveGRPC(grpcListener, store)
st.serveHTTP(httpListener)
tlsCfg, err := x.LoadServerTLSConfig(Zero.Conf, "node.crt", "node.key")
x.Check(err)
st.startListenHttpAndHttps(httpListener, tlsCfg)

http.HandleFunc("/health", st.pingResponse)
http.HandleFunc("/state", st.getState)
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ require (
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90 // indirect
github.com/prometheus/common v0.4.1 // indirect
github.com/prometheus/procfs v0.0.0-20190517135640-51af30a78b0e // indirect
github.com/soheilhy/cmux v0.1.4
github.com/spf13/cast v1.3.0
github.com/spf13/cobra v0.0.5
github.com/spf13/pflag v1.0.3
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1
github.com/smartystreets/goconvey v0.0.0-20190330032615-68dc04aab96a/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/soheilhy/cmux v0.1.4 h1:0HKaf1o97UwFjHH9o5XsHUOF+tqmdA7KEzXLpiyaw0E=
github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
Expand Down
134 changes: 134 additions & 0 deletions tlstest/zero_https/all_routes_tls/all_routes_tls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package all_routes_tls

import (
"crypto/tls"
"crypto/x509"
"io/ioutil"
"net/http"
"strings"
"testing"
"time"

"github.com/pkg/errors"
"github.com/stretchr/testify/require"
)

type testCase struct {
url string
statusCode int
response string
}

var testCasesHttp = []testCase{
{
url: "http://localhost:6180/health",
response: "OK",
statusCode: 200,
},
{
url: "http://localhost:6180/state",
response: "Client sent an HTTP request to an HTTPS server.\n",
statusCode: 400,
},
{
url: "http://localhost:6180/removeNode?id=2&group=0",
response: "Client sent an HTTP request to an HTTPS server.\n",
statusCode: 400,
},
}

func TestZeroWithAllRoutesTLSWithHTTPClient(t *testing.T) {
client := http.Client{
Timeout: time.Second * 10,
}
defer client.CloseIdleConnections()
for _, test := range testCasesHttp {
request, err := http.NewRequest("GET", test.url, nil)
require.NoError(t, err)
do, err := client.Do(request)
require.NoError(t, err)
if do != nil && do.StatusCode != test.statusCode {
t.Fatalf("status code is not same. Got: %d Expected: %d", do.StatusCode, test.statusCode)
}

body := readResponseBody(t, do)
if test.response != string(body) {
t.Fatalf("response is not same. Got: %s Expected: %s", string(body), test.response)
}
}
}

var testCasesHttps = []testCase{
{
url: "https://localhost:6180/health",
response: "OK",
statusCode: 200,
},
{
url: "https://localhost:6180/state",
response: "\"id\":\"1\",\"addr\":\"zero1:5180\",\"leader\":true",
statusCode: 200,
},
}

func TestZeroWithAllRoutesTLSWithTLSClient(t *testing.T) {
pool, err := generateCertPool("../../tls/ca.crt", true)
require.NoError(t, err)

tlsCfg := &tls.Config{RootCAs: pool, ServerName: "localhost", InsecureSkipVerify: true}
tr := &http.Transport{
IdleConnTimeout: 30 * time.Second,
DisableCompression: true,
TLSClientConfig: tlsCfg,
}
client := http.Client{
Transport: tr,
}

defer client.CloseIdleConnections()
for _, test := range testCasesHttps {
request, err := http.NewRequest("GET", test.url, nil)
require.NoError(t, err)
do, err := client.Do(request)
require.NoError(t, err)
if do != nil && do.StatusCode != test.statusCode {
t.Fatalf("status code is not same. Got: %d Expected: %d", do.StatusCode, test.statusCode)
}

body := readResponseBody(t, do)
if !strings.Contains(string(body), test.response) {
t.Fatalf("response is not same. Got: %s Expected: %s", string(body), test.response)
}
}
}

func readResponseBody(t *testing.T, do *http.Response) []byte {
defer func() { _ = do.Body.Close() }()
body, err := ioutil.ReadAll(do.Body)
require.NoError(t, err)
return body
}

func generateCertPool(certPath string, useSystemCA bool) (*x509.CertPool, error) {
var pool *x509.CertPool
if useSystemCA {
var err error
if pool, err = x509.SystemCertPool(); err != nil {
return nil, err
}
} else {
pool = x509.NewCertPool()
}

if len(certPath) > 0 {
caFile, err := ioutil.ReadFile(certPath)
if err != nil {
return nil, err
}
if !pool.AppendCertsFromPEM(caFile) {
return nil, errors.Errorf("error reading CA file %q", certPath)
}
}

return pool, nil
}
37 changes: 37 additions & 0 deletions tlstest/zero_https/all_routes_tls/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
version: "3.5"
services:
alpha1:
image: dgraph/dgraph:latest
container_name: alpha1
working_dir: /data/alpha1
labels:
cluster: test
ports:
- 8180:8180
- 9180:9180
volumes:
- type: bind
source: $GOPATH/bin
target: /gobin
read_only: true
command: /gobin/dgraph alpha -o 100 --my=alpha1:7180 --zero=zero1:5180 --logtostderr -v=2 --whitelist=10.0.0.0/8,172.16.0.0/12,192.168.0.0/16
zero1:
image: dgraph/dgraph:latest
container_name: zero1
working_dir: /data/zero1
labels:
cluster: test
ports:
- 5180:5180
- 6180:6180
volumes:
- type: bind
source: $GOPATH/bin
target: /gobin
read_only: true
- type: bind
source: ../../tls
target: /dgraph-tls
read_only: true
command: /gobin/dgraph zero -o 100 --idx=1 --my=zero1:5180 --tls_dir /dgraph-tls -v=2 --bindall
volumes: {}
33 changes: 33 additions & 0 deletions tlstest/zero_https/no_tls/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
version: "3.5"
services:
alpha1:
image: dgraph/dgraph:latest
container_name: alpha1
working_dir: /data/alpha1
labels:
cluster: test
ports:
- 8180:8180
- 9180:9180
volumes:
- type: bind
source: $GOPATH/bin
target: /gobin
read_only: true
command: /gobin/dgraph alpha -o 100 --my=alpha1:7180 --zero=zero1:5180 --logtostderr -v=2 --whitelist=10.0.0.0/8,172.16.0.0/12,192.168.0.0/16
zero1:
image: dgraph/dgraph:latest
container_name: zero1
working_dir: /data/zero1
labels:
cluster: test
ports:
- 5180:5180
- 6180:6180
volumes:
- type: bind
source: $GOPATH/bin
target: /gobin
read_only: true
command: /gobin/dgraph zero -o 100 --idx=1 --my=zero1:5180 -v=2 --bindall
volumes: {}
Loading

0 comments on commit c66f67f

Please sign in to comment.