Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
cesarghali committed Feb 23, 2018
1 parent 40ef281 commit a4d4431
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 28 deletions.
63 changes: 44 additions & 19 deletions credentials/alts/alts.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
// encapsulates all the state needed by a client to authenticate with a server
// using ALTS and make various assertions, e.g., about the client's identity,
// role, or whether it is authorized to make a particular call.
// This package is experimental.
package alts

import (
"errors"
"flag"
"fmt"
"net"
"sync"
"time"

"golang.org/x/net/context"
Expand All @@ -51,14 +53,26 @@ const (

var (
enableUntrustedALTS = flag.Bool("enable_untrusted_alts", false, "Enables ALTS in untrusted mode. Enabling this mode is risky since we cannot ensure that the application is running on GCP with a trusted handshaker service.")
once sync.Once
maxRPCVersion = &altspb.RpcProtocolVersions_Version{
Major: protocolVersionMaxMajor,
Minor: protocolVersionMaxMinor,
}
minRPCVersion = &altspb.RpcProtocolVersions_Version{
Major: protocolVersionMinMajor,
Minor: protocolVersionMinMinor,
}
// ErrUntrustedPlatform is returned from ClientHandshake and
// ServerHandshake is running on a platform where the trustworthiness of
// the handshaker service is not guaranteed.
ErrUntrustedPlatform = errors.New("untrusted platform, use enable_untrusted_alts flag at your own risk")
)

// AuthInfo exposes security information from the ALTS handshake to the
// application.
// application. This interface is to be implemented by ALTS. Users should not
// need a brand new implementation of this interface. For situations like
// testing, any new implementation should embed this interface. This allows
// ALTS to add new methods to this interface.
type AuthInfo interface {
// ApplicationProtocol returns application protocol negotiated for the
// ALTS connection.
Expand All @@ -77,8 +91,8 @@ type AuthInfo interface {
PeerRPCVersions() *altspb.RpcProtocolVersions
}

// altsTC is the credentials required for authenticating a connection using Google
// Transport Security. It implements credentials.TransportCredentials interface.
// altsTC is the credentials required for authenticating a connection using ALTS.
// It implements credentials.TransportCredentials interface.
type altsTC struct {
info *credentials.ProtocolInfo
hsAddr string
Expand All @@ -97,6 +111,8 @@ func NewServerALTS() credentials.TransportCredentials {
}

func newALTS(side core.Side, accounts []string) credentials.TransportCredentials {
// Make sure flags are parsed before accessing enableUntrustedALTS.
once.Do(func() { flag.Parse() })
if *enableUntrustedALTS {
grpclog.Warning("untrusted ALTS mode is enabled and we cannot guarantee the trustworthiness of the ALTS handshaker service.")
}
Expand Down Expand Up @@ -124,19 +140,29 @@ func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.C
}
// Do not close hsConn since it is shared with other handshakes.

// Possible context leak:
// The cancel function for the child context we create will only be
// called a non-nil error is returned.
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
defer func() {
if err != nil {
cancel()
}
}()

opts := handshaker.DefaultClientHandshakerOptions()
opts.TargetServiceAccounts = g.accounts
opts.RPCVersions = &altspb.RpcProtocolVersions{
MaxRpcVersion: &altspb.RpcProtocolVersions_Version{
Major: protocolVersionMaxMajor,
Minor: protocolVersionMaxMinor,
},
MinRpcVersion: &altspb.RpcProtocolVersions_Version{
Major: protocolVersionMinMajor,
Minor: protocolVersionMinMinor,
},
MaxRpcVersion: maxRPCVersion,
MinRpcVersion: minRPCVersion,
}
chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts)
defer func() {
if err != nil {
chs.Close()
}
}()
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -171,16 +197,15 @@ func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.Au
defer cancel()
opts := handshaker.DefaultServerHandshakerOptions()
opts.RPCVersions = &altspb.RpcProtocolVersions{
MaxRpcVersion: &altspb.RpcProtocolVersions_Version{
Major: protocolVersionMaxMajor,
Minor: protocolVersionMaxMinor,
},
MinRpcVersion: &altspb.RpcProtocolVersions_Version{
Major: protocolVersionMinMajor,
Minor: protocolVersionMinMinor,
},
MaxRpcVersion: maxRPCVersion,
MinRpcVersion: minRPCVersion,
}
shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, opts)
defer func() {
if err != nil {
shs.Close()
}
}()
if err != nil {
return nil, nil, err
}
Expand Down
13 changes: 7 additions & 6 deletions credentials/alts/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ func (k platformError) Error() string {

var (
// The following two variables will be reassigned in tests.
runningOS = runtime.GOOS
readerFunc = func() (io.Reader, error) {
runningOS = runtime.GOOS
manufacturerReader = func() (io.Reader, error) {
switch runningOS {
case "linux":
return os.Open(linuxProductNameFile)
Expand All @@ -72,7 +72,7 @@ var (

return nil, errors.New("cannot determine the machine's manufacturer")
default:
panic(platformError(runningOS))
return nil, platformError(runningOS)
}
}
vmOnGCP = isRunningOnGCP()
Expand All @@ -83,7 +83,7 @@ var (
func isRunningOnGCP() bool {
manufacturer, err := readManufacturer()
if err != nil {
log.Fatal(err)
log.Fatalf("failure to read manufacturer information: %v", err)
}
name := string(manufacturer)
switch runningOS {
Expand All @@ -96,12 +96,13 @@ func isRunningOnGCP() bool {
name = strings.Replace(name, "\r", "", -1)
return name == "Google"
default:
panic(platformError(runningOS))
log.Fatal(platformError(runningOS))
}
return false
}

func readManufacturer() ([]byte, error) {
reader, err := readerFunc()
reader, err := manufacturerReader()
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions credentials/alts/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,16 @@ func TestIsRunningOnGCP(t *testing.T) {

func setup(testOS string, testReader io.Reader) func() {
tmpOS := runningOS
tmpReader := readerFunc
tmpReader := manufacturerReader

// Set test OS and reader function.
runningOS = testOS
readerFunc = func() (io.Reader, error) {
manufacturerReader = func() (io.Reader, error) {
return testReader, nil
}

return func() {
runningOS = tmpOS
readerFunc = tmpReader
manufacturerReader = tmpReader
}
}

0 comments on commit a4d4431

Please sign in to comment.