Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

credentials: create API for transport security level information #3214

Merged
merged 2 commits into from
Jan 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions credentials/alts/internal/authinfo/authinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ var _ credentials.AuthInfo = (*altsAuthInfo)(nil)
// application. altsAuthInfo is immutable and implements credentials.AuthInfo.
type altsAuthInfo struct {
p *altspb.AltsContext
credentials.CommonAuthInfo
}

// New returns a new altsAuthInfo object given handshaker results.
Expand All @@ -48,6 +49,7 @@ func newAuthInfo(result *altspb.HandshakerResult) *altsAuthInfo {
LocalServiceAccount: result.GetLocalIdentity().GetServiceAccount(),
PeerRpcVersions: result.GetPeerRpcVersions(),
},
CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity},
}
}

Expand Down
79 changes: 77 additions & 2 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ package credentials // import "google.golang.org/grpc/credentials"
import (
"context"
"errors"
"fmt"
"net"

"github.com/golang/protobuf/proto"
Expand All @@ -50,6 +51,48 @@ type PerRPCCredentials interface {
RequireTransportSecurity() bool
}

// SecurityLevel defines the protection level on an established connection.
//
// This API is experimental.
type SecurityLevel int
dfawley marked this conversation as resolved.
Show resolved Hide resolved

const (
// NoSecurity indicates a connection is insecure.
// The zero SecurityLevel value is invalid for backward compatibility.
NoSecurity SecurityLevel = iota + 1
dfawley marked this conversation as resolved.
Show resolved Hide resolved
// IntegrityOnly indicates a connection only provides integrity protection.
IntegrityOnly
// PrivacyAndIntegrity indicates a connection provides both privacy and integrity protection.
PrivacyAndIntegrity
)

// String returns SecurityLevel in a string format.
func (s SecurityLevel) String() string {
switch s {
case NoSecurity:
return "NoSecurity"
case IntegrityOnly:
return "IntegrityOnly"
case PrivacyAndIntegrity:
return "PrivacyAndIntegrity"
}
return fmt.Sprintf("invalid SecurityLevel: %v", int(s))
}

// CommonAuthInfo contains authenticated information common to AuthInfo implementations.
dfawley marked this conversation as resolved.
Show resolved Hide resolved
dfawley marked this conversation as resolved.
Show resolved Hide resolved
// It should be embedded in a struct implementing AuthInfo to provide additional information
// about the credentials.
//
// This API is experimental.
type CommonAuthInfo struct {
SecurityLevel SecurityLevel
}

// GetCommonAuthInfo returns the pointer to CommonAuthInfo struct.
func (c *CommonAuthInfo) GetCommonAuthInfo() *CommonAuthInfo {
return c
}

// ProtocolInfo provides information regarding the gRPC wire protocol version,
// security protocol, security protocol version in use, server name, etc.
type ProtocolInfo struct {
Expand All @@ -64,6 +107,8 @@ type ProtocolInfo struct {
}

// AuthInfo defines the common interface for the auth information the users are interested in.
// A struct that implements AuthInfo should embed CommonAuthInfo by including additional
// information about the credentials in it.
type AuthInfo interface {
AuthType() string
}
Expand All @@ -78,7 +123,8 @@ type TransportCredentials interface {
// ClientHandshake does the authentication handshake specified by the corresponding
// authentication protocol on rawConn for clients. It returns the authenticated
// connection and the corresponding auth information about the connection.
// Implementations must use the provided context to implement timely cancellation.
// The auth information should embed CommonAuthInfo to return additional information about
// the credentials. Implementations must use the provided context to implement timely cancellation.
// gRPC will try to reconnect if the error returned is a temporary error
// (io.EOF, context.DeadlineExceeded or err.Temporary() == true).
// If the returned error is a wrapper error, implementations should make sure that
Expand All @@ -88,7 +134,8 @@ type TransportCredentials interface {
ClientHandshake(context.Context, string, net.Conn) (net.Conn, AuthInfo, error)
// ServerHandshake does the authentication handshake for servers. It returns
// the authenticated connection and the corresponding auth information about
// the connection.
// the connection. The auth information should embed CommonAuthInfo to return additional information
// about the credentials.
//
// If the returned net.Conn is closed, it MUST close the net.Conn provided.
ServerHandshake(net.Conn) (net.Conn, AuthInfo, error)
Expand Down Expand Up @@ -127,6 +174,8 @@ type Bundle interface {
type RequestInfo struct {
// The method passed to Invoke or NewStream for this RPC. (For proto methods, this has the format "/some.Service/Method")
Method string
// AuthInfo contains the information from a security handshake (TransportCredentials.ClientHandshake, TransportCredentials.ServerHandshake)
AuthInfo AuthInfo
}

// requestInfoKey is a struct to be used as the key when attaching a RequestInfo to a context object.
Expand All @@ -140,6 +189,32 @@ func RequestInfoFromContext(ctx context.Context) (ri RequestInfo, ok bool) {
return
}

// CheckSecurityLevel checks if a connection's security level is greater than or equal to the specified one.
// It returns success if 1) the condition is satisified or 2) AuthInfo struct does not implement GetCommonAuthInfo() method
// or 3) CommonAuthInfo.SecurityLevel has an invalid zero value. For 2) and 3), it is for the purpose of backward-compatibility.
//
// This API is experimental.
func CheckSecurityLevel(ctx context.Context, level SecurityLevel) error {
type internalInfo interface {
GetCommonAuthInfo() *CommonAuthInfo
}
ri, _ := RequestInfoFromContext(ctx)
if ri.AuthInfo == nil {
return errors.New("unable to obtain SecurityLevel from context")
}
if ci, ok := ri.AuthInfo.(internalInfo); ok {
// CommonAuthInfo.SecurityLevel has an invalid value.
if ci.GetCommonAuthInfo().SecurityLevel == 0 {
return nil
}
if ci.GetCommonAuthInfo().SecurityLevel < level {
return fmt.Errorf("requires SecurityLevel %v; connection has %v", level, ci.GetCommonAuthInfo().SecurityLevel)
}
}
// The condition is satisfied or AuthInfo struct does not implement GetCommonAuthInfo() method.
return nil
}

func init() {
internal.NewRequestInfoContext = func(ctx context.Context, ri RequestInfo) context.Context {
return context.WithValue(ctx, requestInfoKey{}, ri)
Expand Down
85 changes: 83 additions & 2 deletions credentials/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,90 @@ import (
"strings"
"testing"

"google.golang.org/grpc/internal"
"google.golang.org/grpc/testdata"
)

// A struct that implements AuthInfo interface but does not implement GetCommonAuthInfo() method.
type testAuthInfoNoGetCommonAuthInfoMethod struct{}

func (ta testAuthInfoNoGetCommonAuthInfoMethod) AuthType() string {
return "testAuthInfoNoGetCommonAuthInfoMethod"
}

// A struct that implements AuthInfo interface and implements CommonAuthInfo() method.
type testAuthInfo struct {
CommonAuthInfo
}

func (ta testAuthInfo) AuthType() string {
return "testAuthInfo"
}

func createTestContext(s SecurityLevel) context.Context {
auth := &testAuthInfo{CommonAuthInfo: CommonAuthInfo{SecurityLevel: s}}
ri := RequestInfo{
Method: "testInfo",
AuthInfo: auth,
}
return internal.NewRequestInfoContext.(func(context.Context, RequestInfo) context.Context)(context.Background(), ri)
}

func TestCheckSecurityLevel(t *testing.T) {
testCases := []struct {
authLevel SecurityLevel
testLevel SecurityLevel
want bool
}{
{
authLevel: PrivacyAndIntegrity,
testLevel: PrivacyAndIntegrity,
want: true,
},
{
authLevel: IntegrityOnly,
testLevel: PrivacyAndIntegrity,
want: false,
},
{
authLevel: IntegrityOnly,
testLevel: NoSecurity,
want: true,
},
{
authLevel: 0,
testLevel: IntegrityOnly,
want: true,
},
{
authLevel: 0,
testLevel: PrivacyAndIntegrity,
want: true,
},
}
for _, tc := range testCases {
err := CheckSecurityLevel(createTestContext(tc.authLevel), tc.testLevel)
if tc.want && (err != nil) {
t.Fatalf("CheckSeurityLevel(%s, %s) returned failure but want success", tc.authLevel.String(), tc.testLevel.String())
} else if !tc.want && (err == nil) {
t.Fatalf("CheckSeurityLevel(%s, %s) returned success but want failure", tc.authLevel.String(), tc.testLevel.String())

}
}
}

func TestCheckSecurityLevelNoGetCommonAuthInfoMethod(t *testing.T) {
auth := &testAuthInfoNoGetCommonAuthInfoMethod{}
ri := RequestInfo{
Method: "testInfo",
AuthInfo: auth,
}
ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, RequestInfo) context.Context)(context.Background(), ri)
if err := CheckSecurityLevel(ctxWithRequestInfo, PrivacyAndIntegrity); err != nil {
t.Fatalf("CheckSeurityLevel() returned failure but want success")
}
}

func TestTLSOverrideServerName(t *testing.T) {
expectedServerName := "server.name"
c := NewTLS(nil)
Expand Down Expand Up @@ -225,7 +306,7 @@ func tlsServerHandshake(conn net.Conn) (AuthInfo, error) {
if err != nil {
return nil, err
}
return TLSInfo{State: serverConn.ConnectionState()}, nil
return TLSInfo{State: serverConn.ConnectionState(), CommonAuthInfo: CommonAuthInfo{SecurityLevel: PrivacyAndIntegrity}}, nil
}

func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) {
Expand All @@ -234,7 +315,7 @@ func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) {
if err := clientConn.Handshake(); err != nil {
return nil, err
}
return TLSInfo{State: clientConn.ConnectionState()}, nil
return TLSInfo{State: clientConn.ConnectionState(), CommonAuthInfo: CommonAuthInfo{SecurityLevel: PrivacyAndIntegrity}}, nil
}

func TestAppendH2ToNextProtos(t *testing.T) {
Expand Down
12 changes: 12 additions & 0 deletions credentials/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ func (ts TokenSource) GetRequestMetadata(ctx context.Context, uri ...string) (ma
if err != nil {
return nil, err
}
if err = credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer TokenSource PerRPCCredentials: %v", err)
}
return map[string]string{
"authorization": token.Type() + " " + token.AccessToken,
}, nil
Expand Down Expand Up @@ -79,6 +82,9 @@ func (j jwtAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[s
if err != nil {
return nil, err
}
if err = credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer jwtAccess PerRPCCredentials: %v", err)
}
return map[string]string{
"authorization": token.Type() + " " + token.AccessToken,
}, nil
Expand All @@ -99,6 +105,9 @@ func NewOauthAccess(token *oauth2.Token) credentials.PerRPCCredentials {
}

func (oa oauthAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer oauthAccess PerRPCCredentials: %v", err)
}
return map[string]string{
"authorization": oa.token.Type() + " " + oa.token.AccessToken,
}, nil
Expand Down Expand Up @@ -133,6 +142,9 @@ func (s *serviceAccount) GetRequestMetadata(ctx context.Context, uri ...string)
return nil, err
}
}
if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer serviceAccount PerRPCCredentials: %v", err)
}
return map[string]string{
"authorization": s.t.Type() + " " + s.t.AccessToken,
}, nil
Expand Down
5 changes: 3 additions & 2 deletions credentials/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
// It implements the AuthInfo interface.
type TLSInfo struct {
State tls.ConnectionState
CommonAuthInfo
dfawley marked this conversation as resolved.
Show resolved Hide resolved
dfawley marked this conversation as resolved.
Show resolved Hide resolved
}

// AuthType returns the type of TLSInfo as a string.
Expand Down Expand Up @@ -90,15 +91,15 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon
case <-ctx.Done():
return nil, nil, ctx.Err()
}
return internal.WrapSyscallConn(rawConn, conn), TLSInfo{conn.ConnectionState()}, nil
return internal.WrapSyscallConn(rawConn, conn), TLSInfo{conn.ConnectionState(), CommonAuthInfo{PrivacyAndIntegrity}}, nil
}

func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
conn := tls.Server(rawConn, c.config)
if err := conn.Handshake(); err != nil {
return nil, nil, err
}
return internal.WrapSyscallConn(rawConn, conn), TLSInfo{conn.ConnectionState()}, nil
return internal.WrapSyscallConn(rawConn, conn), TLSInfo{conn.ConnectionState(), CommonAuthInfo{PrivacyAndIntegrity}}, nil
}

func (c *tlsCreds) Clone() TransportCredentials {
Expand Down
2 changes: 1 addition & 1 deletion internal/transport/handler_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
Addr: ht.RemoteAddr(),
}
if req.TLS != nil {
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS}
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS, CommonAuthInfo: credentials.CommonAuthInfo{credentials.PrivacyAndIntegrity}}
}
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
s.ctx = peer.NewContext(ctx, pr)
Expand Down
3 changes: 2 additions & 1 deletion internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,8 @@ func (t *http2Client) getPeer() *peer.Peer {
func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) ([]hpack.HeaderField, error) {
aud := t.createAudience(callHdr)
ri := credentials.RequestInfo{
Method: callHdr.Method,
Method: callHdr.Method,
AuthInfo: t.authInfo,
}
ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri)
authData, err := t.getTrAuthData(ctxWithRequestInfo, aud)
Expand Down