Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mTLS authentication w/ tests #4

Merged
merged 11 commits into from
Oct 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile.local
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ server client:

.PHONY: test
test:
$(GO) test -v -race ./...
$(GO) test -v -count=1 -race ./...

.PHONY: fmt
fmt:
Expand Down
22 changes: 22 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package client

import (
"context"
"time"

"google.golang.org/grpc"
)

func DialContextWithTimeout(ctx context.Context, timeout time.Duration, target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
ctxWithTimeout, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

return grpc.DialContext(
ctxWithTimeout,
target,
append(
opts,
grpc.WithReturnConnectionError(),
)...,
)
}
46 changes: 46 additions & 0 deletions cmd/client/main.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,50 @@
package main

import (
"context"
"flag"
"log"
"time"

"google.golang.org/grpc"

"github.com/andrejtokarcik/jobworker/client"
"github.com/andrejtokarcik/jobworker/mtls"
)

var (
serverAddress string
connTimeout time.Duration
credsFiles mtls.CredsFiles
)

func init() {
flag.StringVar(&serverAddress, "server", "127.0.0.1:50051", "Address of the server to connect to")
flag.DurationVar(&connTimeout, "timeout", 5*time.Second, "Connection timeout")

flag.StringVar(&credsFiles.Cert, "client-cert", "client.crt", "Certificate file to use for the client")
flag.StringVar(&credsFiles.Key, "client-key", "client.key", "Private key file to use for the client")
flag.StringVar(&credsFiles.PeerCACert, "server-ca-cert", "server-ca.crt", "Certificate file of the CA to authenticate the server")
}

func main() {
flag.Parse()

creds, err := mtls.NewClientCreds(credsFiles)
if err != nil {
log.Fatal("Failed to load mTLS credentials: ", err)
}

conn, err := client.DialContextWithTimeout(
context.Background(),
connTimeout,
serverAddress,
grpc.WithTransportCredentials(creds),
)
if err != nil {
log.Fatal("Failed to dial server: ", err)
}
defer conn.Close()

log.Print("Successfully connected to server at ", serverAddress)
andrejtokarcik marked this conversation as resolved.
Show resolved Hide resolved
}
42 changes: 42 additions & 0 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,46 @@
package main

import (
"flag"
"fmt"
"log"
"net"

"google.golang.org/grpc"

"github.com/andrejtokarcik/jobworker/mtls"
"github.com/andrejtokarcik/jobworker/server"
)

var (
grpcPort int
credsFiles mtls.CredsFiles
)

func init() {
flag.IntVar(&grpcPort, "grpc-port", 50051, "Port to expose the gRPC server on")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Define default port number constant since it is used by both server and client.


flag.StringVar(&credsFiles.Cert, "server-cert", "server.crt", "Certificate file to use for the server")
flag.StringVar(&credsFiles.Key, "server-key", "server.key", "Private key file to use for the server")
flag.StringVar(&credsFiles.PeerCACert, "client-ca-cert", "client-ca.crt", "Certificate file of the CA to authenticate the clients")
}

func main() {
flag.Parse()

creds, err := mtls.NewServerCreds(credsFiles)
if err != nil {
log.Fatal("Failed to load mTLS credentials: ", err)
}
grpcServer := server.New(grpc.Creds(creds))

listener, err := net.Listen("tcp", fmt.Sprintf(":%d", grpcPort))
if err != nil {
log.Fatal("Failed to listen: ", err)
}

log.Print("Starting gRPC server at ", listener.Addr())
if err := grpcServer.Serve(listener); err != nil {
log.Fatal("Failed to serve: ", err)
}
}
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ go 1.15

require (
github.com/gogo/protobuf v1.3.1
google.golang.org/grpc v1.33.0
github.com/stretchr/testify v1.6.1
google.golang.org/grpc v1.33.1
)
18 changes: 16 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls=
github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1 h1:G5FRp8JnTd7RQH5kemVNlMeyXQAztQ3mOWV95KxsXH8=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
Expand All @@ -17,10 +20,17 @@ github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaW
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
Expand Down Expand Up @@ -55,7 +65,11 @@ google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
google.golang.org/grpc v1.33.0 h1:IBKSUNL2uBS2DkJBncPP+TwT0sp9tgA8A75NjHt6umg=
google.golang.org/grpc v1.33.0/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0=
google.golang.org/grpc v1.33.1 h1:DGeFlSan2f+WEtCERJ4J9GJWk15TxUi8QGagfI87Xyc=
google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
65 changes: 65 additions & 0 deletions mtls/bufconn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package mtls_test

import (
"context"
"net"
"time"

"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
"google.golang.org/grpc/test/bufconn"

"github.com/andrejtokarcik/jobworker/client"
"github.com/andrejtokarcik/jobworker/server"
)

type BufconnConfig struct {
BufSize int
ClientTimeout time.Duration
}

type BufconnSuite struct {
suite.Suite
BufconnConfig
grpcServer *grpc.Server
listener *bufconn.Listener
}

func NewBufconnSuite() (suite BufconnSuite) {
suite.BufconnConfig = BufconnConfig{
BufSize: 1024 * 1024,
ClientTimeout: 1 * time.Second,
}
return
}

func (suite *BufconnSuite) SetupBufconn(opts ...grpc.ServerOption) {
suite.grpcServer = server.New(opts...)
suite.listener = bufconn.Listen(suite.BufSize)
go func() {
if err := suite.grpcServer.Serve(suite.listener); err != nil {
panic(err)
}
}()
}

func (suite *BufconnSuite) TearDownBufconn() {
suite.listener.Close()
suite.grpcServer.Stop()
}

func (suite *BufconnSuite) contextDialer(context.Context, string) (net.Conn, error) {
return suite.listener.Dial()
}

func (suite *BufconnSuite) DialBufconn(serverName string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
return client.DialContextWithTimeout(
context.Background(),
suite.ClientTimeout,
serverName,
append(
opts,
grpc.WithContextDialer(suite.contextDialer),
)...,
)
}
65 changes: 65 additions & 0 deletions mtls/creds.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package mtls

import (
"crypto/tls"
"crypto/x509"
"errors"
"io/ioutil"

"google.golang.org/grpc/credentials"
)

type CredsFiles struct {
Cert, Key, PeerCACert string
}

type loadedCredsFiles struct {
cert tls.Certificate
peerCAPool *x509.CertPool
}

func loadCredsFiles(credsFiles CredsFiles) (*loadedCredsFiles, error) {
cert, err := tls.LoadX509KeyPair(credsFiles.Cert, credsFiles.Key)
if err != nil {
return nil, err
}

peerCACert, err := ioutil.ReadFile(credsFiles.PeerCACert)
if err != nil {
return nil, err
}

peerCAPool := x509.NewCertPool()
if ok := peerCAPool.AppendCertsFromPEM(peerCACert); !ok {
return nil, errors.New("failed to append to peer CA cert pool")
}

return &loadedCredsFiles{cert, peerCAPool}, nil
}

func NewServerCreds(serverFiles CredsFiles) (credentials.TransportCredentials, error) {
loaded, err := loadCredsFiles(serverFiles)
if err != nil {
return nil, err
}

config := &tls.Config{
Certificates: []tls.Certificate{loaded.cert},
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: loaded.peerCAPool,
}
return credentials.NewTLS(config), nil
}

func NewClientCreds(clientFiles CredsFiles) (credentials.TransportCredentials, error) {
loaded, err := loadCredsFiles(clientFiles)
if err != nil {
return nil, err
}

config := &tls.Config{
Certificates: []tls.Certificate{loaded.cert},
RootCAs: loaded.peerCAPool,
}
return credentials.NewTLS(config), nil
}
Loading