Skip to content

Commit

Permalink
add security level negotation
Browse files Browse the repository at this point in the history
  • Loading branch information
yihuazhang committed Dec 21, 2019
1 parent 01d201e commit bf6cafe
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 8 deletions.
2 changes: 2 additions & 0 deletions credentials/alts/internal/authinfo/authinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ var _ credentials.AuthInfo = (*altsAuthInfo)(nil)
// application. altsAuthInfo is immutable and implements credentials.AuthInfo.
type altsAuthInfo struct {
p *altspb.AltsContext
credentials.CommonAuthInfo
}

// New returns a new altsAuthInfo object given handshaker results.
Expand All @@ -48,6 +49,7 @@ func newAuthInfo(result *altspb.HandshakerResult) *altsAuthInfo {
LocalServiceAccount: result.GetLocalIdentity().GetServiceAccount(),
PeerRpcVersions: result.GetPeerRpcVersions(),
},
CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity},
}
}

Expand Down
75 changes: 73 additions & 2 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ package credentials // import "google.golang.org/grpc/credentials"
import (
"context"
"errors"
"fmt"
"net"

"github.com/golang/protobuf/proto"
Expand All @@ -50,6 +51,44 @@ type PerRPCCredentials interface {
RequireTransportSecurity() bool
}

// SecurityLevel defines the protection level on an established connection.
type SecurityLevel int

const (
// NoSecurity indicates a connection is insecure.
// The zero SecurityLevel value is invalid for backward compatibility.
NoSecurity SecurityLevel = iota + 1
// IntegrityOnly indicates a connection only provides integrity protection.
IntegrityOnly
// PrivacyAndIntegrity indicates a connection provides both privacy and integrity protection.
PrivacyAndIntegrity
)

// String returns SecurityLevel in a string format.
func (s SecurityLevel) String() string {
switch s {
case NoSecurity:
return "NoSecurity"
case IntegrityOnly:
return "IntegrityOnly"
case PrivacyAndIntegrity:
return "PrivacyAndIntegrity"
}
return fmt.Sprintf("invalid SecurityLevel: %v", int(s))
}

// CommonAuthInfo contains authenticated information common to AuthInfo implementations.
// It should be embedded in a struct implementing AuthInfo to provide additional information
// about the credentials.
type CommonAuthInfo struct {
SecurityLevel SecurityLevel
}

// GetCommonAuthInfo returns the pointer to CommonAuthInfo struct.
func (c *CommonAuthInfo) GetCommonAuthInfo() *CommonAuthInfo {
return c
}

// ProtocolInfo provides information regarding the gRPC wire protocol version,
// security protocol, security protocol version in use, server name, etc.
type ProtocolInfo struct {
Expand All @@ -64,6 +103,8 @@ type ProtocolInfo struct {
}

// AuthInfo defines the common interface for the auth information the users are interested in.
// A struct that implements AuthInfo should embed CommonAuthInfo by including additional
// information about the credentials in it.
type AuthInfo interface {
AuthType() string
}
Expand All @@ -78,7 +119,8 @@ type TransportCredentials interface {
// ClientHandshake does the authentication handshake specified by the corresponding
// authentication protocol on rawConn for clients. It returns the authenticated
// connection and the corresponding auth information about the connection.
// Implementations must use the provided context to implement timely cancellation.
// The auth information should embed CommonAuthInfo to return additional information about
// the credentials. Implementations must use the provided context to implement timely cancellation.
// gRPC will try to reconnect if the error returned is a temporary error
// (io.EOF, context.DeadlineExceeded or err.Temporary() == true).
// If the returned error is a wrapper error, implementations should make sure that
Expand All @@ -88,7 +130,8 @@ type TransportCredentials interface {
ClientHandshake(context.Context, string, net.Conn) (net.Conn, AuthInfo, error)
// ServerHandshake does the authentication handshake for servers. It returns
// the authenticated connection and the corresponding auth information about
// the connection.
// the connection. The auth information should embed CommonAuthInfo to return additional information
// about the credentials.
//
// If the returned net.Conn is closed, it MUST close the net.Conn provided.
ServerHandshake(net.Conn) (net.Conn, AuthInfo, error)
Expand Down Expand Up @@ -127,6 +170,8 @@ type Bundle interface {
type RequestInfo struct {
// The method passed to Invoke or NewStream for this RPC. (For proto methods, this has the format "/some.Service/Method")
Method string
// AuthInfo contains the information from a security handshake (TransportCredentials.ClientHandshake, TransportCredentials.ServerHandshake)
AuthInfo AuthInfo
}

// requestInfoKey is a struct to be used as the key when attaching a RequestInfo to a context object.
Expand All @@ -140,6 +185,32 @@ func RequestInfoFromContext(ctx context.Context) (ri RequestInfo, ok bool) {
return
}

// CheckSecurityLevel checks if a connection's security level is greater than or equal to the specified one.
// It returns success if 1) the condition is satisified or 2) AuthInfo struct does not implement GetCommonAuthInfo() method
// or 3) CommonAuthInfo.SecurityLevel has an invalid zero value. For 2) and 3), it is for the purpose of backward-compatibility.
//
// This API is experimental.
func CheckSecurityLevel(ctx context.Context, level SecurityLevel) error {
type internalInfo interface {
GetCommonAuthInfo() *CommonAuthInfo
}
ri, _ := RequestInfoFromContext(ctx)
if ri.AuthInfo == nil {
return errors.New("unable to obtain SecurityLevel from context")
}
if ci, ok := ri.AuthInfo.(internalInfo); ok {
// CommonAuthInfo.SecurityLevel has an invalid value.
if ci.GetCommonAuthInfo().SecurityLevel == 0 {
return nil
}
if ci.GetCommonAuthInfo().SecurityLevel < level {
return fmt.Errorf("requires SecurityLevel %v; connection has %v", level, ci.GetCommonAuthInfo().SecurityLevel)
}
}
// The condition is satisfied or AuthInfo struct does not implement GetCommonAuthInfo() method.
return nil
}

func init() {
internal.NewRequestInfoContext = func(ctx context.Context, ri RequestInfo) context.Context {
return context.WithValue(ctx, requestInfoKey{}, ri)
Expand Down
85 changes: 83 additions & 2 deletions credentials/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,90 @@ import (
"strings"
"testing"

"google.golang.org/grpc/internal"
"google.golang.org/grpc/testdata"
)

// A struct that implements AuthInfo interface but does not implement GetCommonAuthInfo() method.
type testAuthInfoNoGetCommonAuthInfoMethod struct{}

func (ta testAuthInfoNoGetCommonAuthInfoMethod) AuthType() string {
return "testAuthInfoNoGetCommonAuthInfoMethod"
}

// A struct that implements AuthInfo interface and implements CommonAuthInfo() method.
type testAuthInfo struct {
CommonAuthInfo
}

func (ta testAuthInfo) AuthType() string {
return "testAuthInfo"
}

func createTestContext(s SecurityLevel) context.Context {
auth := &testAuthInfo{CommonAuthInfo: CommonAuthInfo{SecurityLevel: s}}
ri := RequestInfo{
Method: "testInfo",
AuthInfo: auth,
}
return internal.NewRequestInfoContext.(func(context.Context, RequestInfo) context.Context)(context.Background(), ri)
}

func TestCheckSecurityLevel(t *testing.T) {
testCases := []struct {
authLevel SecurityLevel
testLevel SecurityLevel
want bool
}{
{
authLevel: PrivacyAndIntegrity,
testLevel: PrivacyAndIntegrity,
want: true,
},
{
authLevel: IntegrityOnly,
testLevel: PrivacyAndIntegrity,
want: false,
},
{
authLevel: IntegrityOnly,
testLevel: NoSecurity,
want: true,
},
{
authLevel: 0,
testLevel: IntegrityOnly,
want: true,
},
{
authLevel: 0,
testLevel: PrivacyAndIntegrity,
want: true,
},
}
for _, tc := range testCases {
err := CheckSecurityLevel(createTestContext(tc.authLevel), tc.testLevel)
if tc.want && (err != nil) {
t.Fatalf("CheckSeurityLevel(%s, %s) returned failure but want success", tc.authLevel.String(), tc.testLevel.String())
} else if !tc.want && (err == nil) {
t.Fatalf("CheckSeurityLevel(%s, %s) returned success but want failure", tc.authLevel.String(), tc.testLevel.String())

}
}
}

func TestCheckSecurityLevelNoGetCommonAuthInfoMethod(t *testing.T) {
auth := &testAuthInfoNoGetCommonAuthInfoMethod{}
ri := RequestInfo{
Method: "testInfo",
AuthInfo: auth,
}
ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, RequestInfo) context.Context)(context.Background(), ri)
if err := CheckSecurityLevel(ctxWithRequestInfo, PrivacyAndIntegrity); err != nil {
t.Fatalf("CheckSeurityLevel() returned failure but want success")
}
}

func TestTLSOverrideServerName(t *testing.T) {
expectedServerName := "server.name"
c := NewTLS(nil)
Expand Down Expand Up @@ -225,7 +306,7 @@ func tlsServerHandshake(conn net.Conn) (AuthInfo, error) {
if err != nil {
return nil, err
}
return TLSInfo{State: serverConn.ConnectionState()}, nil
return TLSInfo{State: serverConn.ConnectionState(), CommonAuthInfo: CommonAuthInfo{SecurityLevel: PrivacyAndIntegrity}}, nil
}

func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) {
Expand All @@ -234,7 +315,7 @@ func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) {
if err := clientConn.Handshake(); err != nil {
return nil, err
}
return TLSInfo{State: clientConn.ConnectionState()}, nil
return TLSInfo{State: clientConn.ConnectionState(), CommonAuthInfo: CommonAuthInfo{SecurityLevel: PrivacyAndIntegrity}}, nil
}

func TestAppendH2ToNextProtos(t *testing.T) {
Expand Down
12 changes: 12 additions & 0 deletions credentials/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ func (ts TokenSource) GetRequestMetadata(ctx context.Context, uri ...string) (ma
if err != nil {
return nil, err
}
if err = credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer TokenSource PerRPCCredentials: %v", err)
}
return map[string]string{
"authorization": token.Type() + " " + token.AccessToken,
}, nil
Expand Down Expand Up @@ -79,6 +82,9 @@ func (j jwtAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[s
if err != nil {
return nil, err
}
if err = credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer jwtAccess PerRPCCredentials: %v", err)
}
return map[string]string{
"authorization": token.Type() + " " + token.AccessToken,
}, nil
Expand All @@ -99,6 +105,9 @@ func NewOauthAccess(token *oauth2.Token) credentials.PerRPCCredentials {
}

func (oa oauthAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer oauthAccess PerRPCCredentials: %v", err)
}
return map[string]string{
"authorization": oa.token.Type() + " " + oa.token.AccessToken,
}, nil
Expand Down Expand Up @@ -133,6 +142,9 @@ func (s *serviceAccount) GetRequestMetadata(ctx context.Context, uri ...string)
return nil, err
}
}
if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer serviceAccount PerRPCCredentials: %v", err)
}
return map[string]string{
"authorization": s.t.Type() + " " + s.t.AccessToken,
}, nil
Expand Down
5 changes: 3 additions & 2 deletions credentials/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
// It implements the AuthInfo interface.
type TLSInfo struct {
State tls.ConnectionState
CommonAuthInfo
}

// AuthType returns the type of TLSInfo as a string.
Expand Down Expand Up @@ -90,15 +91,15 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon
case <-ctx.Done():
return nil, nil, ctx.Err()
}
return internal.WrapSyscallConn(rawConn, conn), TLSInfo{conn.ConnectionState()}, nil
return internal.WrapSyscallConn(rawConn, conn), TLSInfo{conn.ConnectionState(), CommonAuthInfo{PrivacyAndIntegrity}}, nil
}

func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
conn := tls.Server(rawConn, c.config)
if err := conn.Handshake(); err != nil {
return nil, nil, err
}
return internal.WrapSyscallConn(rawConn, conn), TLSInfo{conn.ConnectionState()}, nil
return internal.WrapSyscallConn(rawConn, conn), TLSInfo{conn.ConnectionState(), CommonAuthInfo{PrivacyAndIntegrity}}, nil
}

func (c *tlsCreds) Clone() TransportCredentials {
Expand Down
2 changes: 1 addition & 1 deletion internal/transport/handler_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
Addr: ht.RemoteAddr(),
}
if req.TLS != nil {
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS}
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS, CommonAuthInfo: credentials.CommonAuthInfo{credentials.PrivacyAndIntegrity}}
}
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
s.ctx = peer.NewContext(ctx, pr)
Expand Down
3 changes: 2 additions & 1 deletion internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,8 @@ func (t *http2Client) getPeer() *peer.Peer {
func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) ([]hpack.HeaderField, error) {
aud := t.createAudience(callHdr)
ri := credentials.RequestInfo{
Method: callHdr.Method,
Method: callHdr.Method,
AuthInfo: t.authInfo,
}
ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri)
authData, err := t.getTrAuthData(ctxWithRequestInfo, aud)
Expand Down

0 comments on commit bf6cafe

Please sign in to comment.