diff --git a/credentials/credentials.go b/credentials/credentials.go index 0c319b51f561..d6365fd35aae 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -56,6 +56,7 @@ type SecurityLevel int const ( // NoSecurity indicates a connection is insecure. + // The zero SecurityLevel value is invalid for backward compatibility. NoSecurity SecurityLevel = iota + 1 // IntegrityOnly indicates a connection only provides integrity protection. IntegrityOnly @@ -169,7 +170,7 @@ type Bundle interface { type RequestInfo struct { // The method passed to Invoke or NewStream for this RPC. (For proto methods, this has the format "/some.Service/Method") Method string - // AuthInfo contains the information resulted from a security handshake (TransportCredentials.ClientHandshake, TransportCredentials.ServerHandshake) + // AuthInfo contains the information from a security handshake (TransportCredentials.ClientHandshake, TransportCredentials.ServerHandshake) AuthInfo AuthInfo } @@ -185,27 +186,29 @@ func RequestInfoFromContext(ctx context.Context) (ri RequestInfo, ok bool) { } // CheckSecurityLevel checks if a connection's security level is greater than or equal to the specified one. -// It returns true if 1) the condition is satisified or 2) AuthInfo struct does not implement GetCommonAuthInfo() method +// It returns success if 1) the condition is satisified or 2) AuthInfo struct does not implement GetCommonAuthInfo() method // or 3) CommonAuthInfo.SecurityLevel has an invalid zero value. For 2) and 3), it is for the purpose of backward-compatibility. // // This API is experimental. -func CheckSecurityLevel(ctx context.Context, level SecurityLevel) bool { +func CheckSecurityLevel(ctx context.Context, level SecurityLevel) error { type internalInfo interface { GetCommonAuthInfo() *CommonAuthInfo } ri, _ := RequestInfoFromContext(ctx) if ri.AuthInfo == nil { - return false + return errors.New("RequestInfo does not contain AuthInfo struct") } if ci, ok := ri.AuthInfo.(internalInfo); ok { // CommonAuthInfo.SecurityLevel has an invalid value. if ci.GetCommonAuthInfo().SecurityLevel == 0 { - return true + return nil + } + if ci.GetCommonAuthInfo().SecurityLevel < level { + return fmt.Errorf("requires SecurityLevel %v; connection has %v", level, ci.GetCommonAuthInfo().SecurityLevel) } - return ci.GetCommonAuthInfo().SecurityLevel >= level } - // AuthInfo struct does not implement GetCommonAuthInfo() method. - return true + // The condition is satisfied or AuthInfo struct does not implement GetCommonAuthInfo() method. + return nil } func init() { diff --git a/credentials/credentials_test.go b/credentials/credentials_test.go index 82ae817dbe8d..f14f98e1ff58 100644 --- a/credentials/credentials_test.go +++ b/credentials/credentials_test.go @@ -88,8 +88,12 @@ func TestCheckSecurityLevel(t *testing.T) { }, } for _, tc := range testCases { - if got := CheckSecurityLevel(createTestContext(tc.authLevel), tc.testLevel); got != tc.want { - t.Fatalf("CheckSeurityLevel(%s, %s) returned %v but want %v", tc.authLevel.String(), tc.testLevel.String(), got, tc.want) + err := CheckSecurityLevel(createTestContext(tc.authLevel), tc.testLevel) + if tc.want && (err != nil) { + t.Fatalf("CheckSeurityLevel(%s, %s) returned failure but want success", tc.authLevel.String(), tc.testLevel.String()) + } else if !tc.want && (err == nil) { + t.Fatalf("CheckSeurityLevel(%s, %s) returned success but want failure", tc.authLevel.String(), tc.testLevel.String()) + } } } @@ -101,8 +105,8 @@ func TestCheckSecurityLevelNoGetCommonAuthInfoMethod(t *testing.T) { AuthInfo: auth, } ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, RequestInfo) context.Context)(context.Background(), ri) - if !CheckSecurityLevel(ctxWithRequestInfo, PrivacyAndIntegrity) { - t.Fatalf("CheckSeurityLevel() returned false, want true") + if err := CheckSecurityLevel(ctxWithRequestInfo, PrivacyAndIntegrity); err != nil { + t.Fatalf("CheckSeurityLevel() returned failure but want success") } } diff --git a/credentials/oauth/oauth.go b/credentials/oauth/oauth.go index c1f6e3356375..899e3372ce3c 100644 --- a/credentials/oauth/oauth.go +++ b/credentials/oauth/oauth.go @@ -42,8 +42,8 @@ func (ts TokenSource) GetRequestMetadata(ctx context.Context, uri ...string) (ma if err != nil { return nil, err } - if !credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity) { - return nil, fmt.Errorf("connection is not secure enough to transfer TokenSource PerRPCCredentials which requires PrivacyAndIntegrity") + if err = credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil { + return nil, fmt.Errorf("unable to transfer TokenSource PerRPCCredentials: %v", err) } return map[string]string{ "authorization": token.Type() + " " + token.AccessToken, @@ -82,8 +82,8 @@ func (j jwtAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[s if err != nil { return nil, err } - if !credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity) { - return nil, fmt.Errorf("connection is not secure enough to transfer JWTAccess PerRPCCredentials which requires PrivacyAndIntegrity") + if err = credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil { + return nil, fmt.Errorf("unable to transfer jwtAccess PerRPCCredentials: %v", err) } return map[string]string{ "authorization": token.Type() + " " + token.AccessToken, @@ -105,8 +105,8 @@ func NewOauthAccess(token *oauth2.Token) credentials.PerRPCCredentials { } func (oa oauthAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { - if !credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity) { - return nil, fmt.Errorf("connection is not secure enough to transfer OauthAccess PerRPCCredentials which requires PrivacyAndIntegrity") + if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil { + return nil, fmt.Errorf("unable to transfer oauthAccess PerRPCCredentials: %v", err) } return map[string]string{ "authorization": oa.token.Type() + " " + oa.token.AccessToken, @@ -142,8 +142,8 @@ func (s *serviceAccount) GetRequestMetadata(ctx context.Context, uri ...string) return nil, err } } - if !credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity) { - return nil, fmt.Errorf("connection is not secure enough to transfer ServiceAccount PerRPCCredentials which requires PrivacyAndIntegrity") + if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil { + return nil, fmt.Errorf("unable to transfer serviceAccount PerRPCCredentials: %v", err) } return map[string]string{ "authorization": s.t.Type() + " " + s.t.AccessToken,