Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mTLS authentication w/ tests #4

Merged
merged 11 commits into from
Oct 29, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile.local
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ server client:

.PHONY: test
test:
$(GO) test -v -race ./...
$(GO) test -v -count=1 -race ./...

.PHONY: fmt
fmt:
Expand Down
22 changes: 22 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package client

import (
"context"
"time"

"google.golang.org/grpc"
)

func Dial(target string, timeout time.Duration, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
andrejtokarcik marked this conversation as resolved.
Show resolved Hide resolved
defer cancel()

return grpc.DialContext(
ctx,
target,
append(
opts,
grpc.WithReturnConnectionError(),
)...,
)
}
44 changes: 44 additions & 0 deletions cmd/client/main.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,48 @@
package main

import (
"flag"
"log"
"time"

"google.golang.org/grpc"

"github.com/andrejtokarcik/jobworker/client"
"github.com/andrejtokarcik/jobworker/mtls"
)

var (
serverAddress string
timeoutSeconds int
andrejtokarcik marked this conversation as resolved.
Show resolved Hide resolved
credsFiles mtls.CredsFiles
)

func init() {
flag.StringVar(&serverAddress, "server", "0.0.0.0:50051", "Address of the server to connect to")
andrejtokarcik marked this conversation as resolved.
Show resolved Hide resolved
flag.IntVar(&timeoutSeconds, "timeout", 5, "Connection timeout in seconds")

flag.StringVar(&credsFiles.Cert, "client-cert", "client.crt", "Certificate file to use for the client")
flag.StringVar(&credsFiles.Key, "client-key", "client.key", "Private key file to use for the client")
flag.StringVar(&credsFiles.PeerCACert, "server-ca-cert", "server-ca.crt", "Certificate file of the CA to authenticate the server")
}

func main() {
flag.Parse()

creds, err := mtls.NewClientCreds(credsFiles)
if err != nil {
log.Fatal("Failed to load mTLS credentials: ", err)
}

conn, err := client.Dial(
serverAddress,
time.Duration(timeoutSeconds)*time.Second,
grpc.WithTransportCredentials(creds),
)
if err != nil {
log.Fatal("Failed to dial server: ", err)
}
defer conn.Close()

log.Print("Successfully connected to server at ", serverAddress)
andrejtokarcik marked this conversation as resolved.
Show resolved Hide resolved
}
42 changes: 42 additions & 0 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,46 @@
package main

import (
"flag"
"fmt"
"log"
"net"

"google.golang.org/grpc"

"github.com/andrejtokarcik/jobworker/mtls"
"github.com/andrejtokarcik/jobworker/server"
)

var (
grpcPort int
credsFiles mtls.CredsFiles
)

func init() {
flag.IntVar(&grpcPort, "grpc-port", 50051, "Port to expose the gRPC server on")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Define default port number constant since it is used by both server and client.


flag.StringVar(&credsFiles.Cert, "server-cert", "server.crt", "Certificate file to use for the server")
flag.StringVar(&credsFiles.Key, "server-key", "server.key", "Private key file to use for the server")
flag.StringVar(&credsFiles.PeerCACert, "client-ca-cert", "client-ca.crt", "Certificate file of the CA to authenticate the clients")
}

func main() {
flag.Parse()

creds, err := mtls.NewServerCreds(credsFiles)
if err != nil {
log.Fatal("Failed to load mTLS credentials: ", err)
}
grpcServer := server.New(grpc.Creds(creds))

listener, err := net.Listen("tcp", fmt.Sprintf(":%d", grpcPort))
if err != nil {
log.Fatal("Failed to listen: ", err)
}

log.Print("Starting gRPC server at ", listener.Addr())
if err := grpcServer.Serve(listener); err != nil {
log.Fatal("Failed to serve: ", err)
}
}
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ go 1.15

require (
github.com/gogo/protobuf v1.3.1
google.golang.org/grpc v1.33.0
github.com/stretchr/testify v1.6.1
google.golang.org/grpc v1.33.1
)
18 changes: 16 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls=
github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1 h1:G5FRp8JnTd7RQH5kemVNlMeyXQAztQ3mOWV95KxsXH8=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
Expand All @@ -17,10 +20,17 @@ github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaW
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
Expand Down Expand Up @@ -55,7 +65,11 @@ google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
google.golang.org/grpc v1.33.0 h1:IBKSUNL2uBS2DkJBncPP+TwT0sp9tgA8A75NjHt6umg=
google.golang.org/grpc v1.33.0/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0=
google.golang.org/grpc v1.33.1 h1:DGeFlSan2f+WEtCERJ4J9GJWk15TxUi8QGagfI87Xyc=
google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
65 changes: 65 additions & 0 deletions mtls/creds.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package mtls

import (
"crypto/tls"
"crypto/x509"
"errors"
"io/ioutil"

"google.golang.org/grpc/credentials"
)

type CredsFiles struct {
Cert, Key, PeerCACert string
}

type loadedCredsFiles struct {
cert tls.Certificate
peerCAPool *x509.CertPool
}

func loadCredsFiles(credsFiles CredsFiles) (*loadedCredsFiles, error) {
cert, err := tls.LoadX509KeyPair(credsFiles.Cert, credsFiles.Key)
if err != nil {
return nil, err
}

peerCACert, err := ioutil.ReadFile(credsFiles.PeerCACert)
if err != nil {
return nil, err
}

peerCAPool := x509.NewCertPool()
if ok := peerCAPool.AppendCertsFromPEM(peerCACert); !ok {
return nil, errors.New("failed to append to peer CA cert pool")
}

return &loadedCredsFiles{cert, peerCAPool}, nil
}

func NewServerCreds(serverFiles CredsFiles) (credentials.TransportCredentials, error) {
loaded, err := loadCredsFiles(serverFiles)
if err != nil {
return nil, err
}

config := &tls.Config{
Certificates: []tls.Certificate{loaded.cert},
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: loaded.peerCAPool,
}
return credentials.NewTLS(config), nil
}

func NewClientCreds(clientFiles CredsFiles) (credentials.TransportCredentials, error) {
loaded, err := loadCredsFiles(clientFiles)
if err != nil {
return nil, err
}

config := &tls.Config{
Certificates: []tls.Certificate{loaded.cert},
RootCAs: loaded.peerCAPool,
}
return credentials.NewTLS(config), nil
}
79 changes: 79 additions & 0 deletions mtls/mtls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package mtls_test

import (
"errors"
"testing"

"github.com/stretchr/testify/suite"

"github.com/andrejtokarcik/jobworker/mtls"
"github.com/andrejtokarcik/jobworker/test"
"github.com/andrejtokarcik/jobworker/test/data"
)

type mTLSTestSuite struct {
test.BufconnSuite
}

type mTLSTestCase struct {
clientCredsFiles mtls.CredsFiles
expectedErr error
}

func (suite *mTLSTestSuite) SetupSuite() {
suite.SetupBufconnWithDefaultCreds()
}

func (suite *mTLSTestSuite) TearDownSuite() {
suite.TearDownBufconn()
}

func (suite *mTLSTestSuite) runTestCase(tc mTLSTestCase) {
clientCreds, err := mtls.NewClientCreds(
testdata.CredsFilesPaths(tc.clientCredsFiles),
)
suite.Require().Nil(err)

conn, err := suite.DialBufconn(clientCreds)
if conn != nil {
defer conn.Close()
}

if tc.expectedErr == nil {
suite.Require().Nil(err)
} else {
suite.Require().NotNil(err)
suite.Assert().Contains(err.Error(), tc.expectedErr.Error())
}
}

func validTestCase() mTLSTestCase {
return mTLSTestCase{
clientCredsFiles: testdata.DefaultClientCredsFiles(),
expectedErr: nil,
}
}

func (suite *mTLSTestSuite) TestValidCreds() {
tc := validTestCase()
suite.runTestCase(tc)
}

func (suite *mTLSTestSuite) TestWrongServerCA() {
tc := validTestCase()
tc.clientCredsFiles.PeerCACert = "server-ca2.crt"
tc.expectedErr = errors.New("x509: certificate signed by unknown authority")
suite.runTestCase(tc)
}

func (suite *mTLSTestSuite) TestSelfSignedClientCert() {
tc := validTestCase()
tc.clientCredsFiles.Cert = "self-signed.crt"
tc.clientCredsFiles.Key = "self-signed.key"
tc.expectedErr = errors.New("context deadline exceeded")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to improve the error handling for this case? Receiving a 'context deadline exceeded' here might not be too helpful in figuring out what went wrong.

Copy link
Owner Author

@andrejtokarcik andrejtokarcik Oct 27, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only option I see is to somehow inspect the logs emitted by the grpc framework. In the returned error object itself, the more precise message seems to be just a "connection error" or "connection closed".

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does grpc.WithBlock() (or some other dial option) help catch the TLS error early?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really. I already use grpc.WithReturnConnectionError() which implies grpc.WithBlock() and is even more helpful in this regard.

suite.runTestCase(tc)
}

func TestMutualTLS(t *testing.T) {
suite.Run(t, &mTLSTestSuite{test.NewBufconnSuite()})
}
19 changes: 19 additions & 0 deletions run_with_testdata.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/bin/bash

trap 'kill $SERVER_PID' TERM INT EXIT
export GRPC_GO_LOG_SEVERITY_LEVEL=warning

CERT_DIR=${CERT_DIR:-./test/data/x509}

./bin/server \
-server-cert $CERT_DIR/server-ca/server1.crt \
-server-key $CERT_DIR/server-ca/server1.key \
-client-ca-cert $CERT_DIR/client-ca.crt &
SERVER_PID=$!

sleep 1

./bin/client \
-client-cert $CERT_DIR/client-ca/client1.crt \
-client-key $CERT_DIR/client-ca/client1.key \
-server-ca-cert $CERT_DIR/server-ca.crt $@
32 changes: 32 additions & 0 deletions server/authn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package server

import (
"context"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)

type clientSubjectKey struct{}

func AttachClientSubject(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
p, ok := peer.FromContext(ctx)
if !ok {
return nil, status.Error(codes.Unauthenticated, "no peer found")
}

tlsAuth, ok := p.AuthInfo.(credentials.TLSInfo)
if !ok {
return nil, status.Error(codes.Unauthenticated, "unexpected peer transport credentials")
}

if len(tlsAuth.State.VerifiedChains) == 0 || len(tlsAuth.State.VerifiedChains[0]) == 0 {
andrejtokarcik marked this conversation as resolved.
Show resolved Hide resolved
return nil, status.Error(codes.Unauthenticated, "could not verify peer certificate")
}

newCtx := context.WithValue(ctx, clientSubjectKey{}, tlsAuth.State.VerifiedChains[0][0].Subject)
return handler(newCtx, req)
}
14 changes: 14 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package server

import (
"google.golang.org/grpc"
)

func New(opts ...grpc.ServerOption) *grpc.Server {
return grpc.NewServer(
append(
opts,
grpc.UnaryInterceptor(AttachClientSubject),
)...,
)
}
Loading