Skip to content

Commit

Permalink
fix CheckSecurityLevel format
Browse files Browse the repository at this point in the history
  • Loading branch information
yihuazhang committed Dec 20, 2019
1 parent 37e65e3 commit bef9360
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 20 deletions.
19 changes: 11 additions & 8 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type SecurityLevel int

const (
// NoSecurity indicates a connection is insecure.
// The zero SecurityLevel value is invalid for backward compatibility.
NoSecurity SecurityLevel = iota + 1
// IntegrityOnly indicates a connection only provides integrity protection.
IntegrityOnly
Expand Down Expand Up @@ -169,7 +170,7 @@ type Bundle interface {
type RequestInfo struct {
// The method passed to Invoke or NewStream for this RPC. (For proto methods, this has the format "/some.Service/Method")
Method string
// AuthInfo contains the information resulted from a security handshake (TransportCredentials.ClientHandshake, TransportCredentials.ServerHandshake)
// AuthInfo contains the information from a security handshake (TransportCredentials.ClientHandshake, TransportCredentials.ServerHandshake)
AuthInfo AuthInfo
}

Expand All @@ -185,27 +186,29 @@ func RequestInfoFromContext(ctx context.Context) (ri RequestInfo, ok bool) {
}

// CheckSecurityLevel checks if a connection's security level is greater than or equal to the specified one.
// It returns true if 1) the condition is satisified or 2) AuthInfo struct does not implement GetCommonAuthInfo() method
// It returns success if 1) the condition is satisified or 2) AuthInfo struct does not implement GetCommonAuthInfo() method
// or 3) CommonAuthInfo.SecurityLevel has an invalid zero value. For 2) and 3), it is for the purpose of backward-compatibility.
//
// This API is experimental.
func CheckSecurityLevel(ctx context.Context, level SecurityLevel) bool {
func CheckSecurityLevel(ctx context.Context, level SecurityLevel) error {
type internalInfo interface {
GetCommonAuthInfo() *CommonAuthInfo
}
ri, _ := RequestInfoFromContext(ctx)
if ri.AuthInfo == nil {
return false
return errors.New("RequestInfo does not contain AuthInfo struct")
}
if ci, ok := ri.AuthInfo.(internalInfo); ok {
// CommonAuthInfo.SecurityLevel has an invalid value.
if ci.GetCommonAuthInfo().SecurityLevel == 0 {
return true
return nil
}
if ci.GetCommonAuthInfo().SecurityLevel < level {
return fmt.Errorf("requires SecurityLevel %v; connection has %v", level, ci.GetCommonAuthInfo().SecurityLevel)
}
return ci.GetCommonAuthInfo().SecurityLevel >= level
}
// AuthInfo struct does not implement GetCommonAuthInfo() method.
return true
// The condition is satisfied or AuthInfo struct does not implement GetCommonAuthInfo() method.
return nil
}

func init() {
Expand Down
12 changes: 8 additions & 4 deletions credentials/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,12 @@ func TestCheckSecurityLevel(t *testing.T) {
},
}
for _, tc := range testCases {
if got := CheckSecurityLevel(createTestContext(tc.authLevel), tc.testLevel); got != tc.want {
t.Fatalf("CheckSeurityLevel(%s, %s) returned %v but want %v", tc.authLevel.String(), tc.testLevel.String(), got, tc.want)
err := CheckSecurityLevel(createTestContext(tc.authLevel), tc.testLevel)
if tc.want && (err != nil) {
t.Fatalf("CheckSeurityLevel(%s, %s) returned failure but want success", tc.authLevel.String(), tc.testLevel.String())
} else if !tc.want && (err == nil) {
t.Fatalf("CheckSeurityLevel(%s, %s) returned success but want failure", tc.authLevel.String(), tc.testLevel.String())

}
}
}
Expand All @@ -101,8 +105,8 @@ func TestCheckSecurityLevelNoGetCommonAuthInfoMethod(t *testing.T) {
AuthInfo: auth,
}
ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, RequestInfo) context.Context)(context.Background(), ri)
if !CheckSecurityLevel(ctxWithRequestInfo, PrivacyAndIntegrity) {
t.Fatalf("CheckSeurityLevel() returned false, want true")
if err := CheckSecurityLevel(ctxWithRequestInfo, PrivacyAndIntegrity); err != nil {
t.Fatalf("CheckSeurityLevel() returned failure but want success")
}
}

Expand Down
16 changes: 8 additions & 8 deletions credentials/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ func (ts TokenSource) GetRequestMetadata(ctx context.Context, uri ...string) (ma
if err != nil {
return nil, err
}
if !credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity) {
return nil, fmt.Errorf("connection is not secure enough to transfer TokenSource PerRPCCredentials which requires PrivacyAndIntegrity")
if err = credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer TokenSource PerRPCCredentials: %v", err)
}
return map[string]string{
"authorization": token.Type() + " " + token.AccessToken,
Expand Down Expand Up @@ -82,8 +82,8 @@ func (j jwtAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[s
if err != nil {
return nil, err
}
if !credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity) {
return nil, fmt.Errorf("connection is not secure enough to transfer JWTAccess PerRPCCredentials which requires PrivacyAndIntegrity")
if err = credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer jwtAccess PerRPCCredentials: %v", err)
}
return map[string]string{
"authorization": token.Type() + " " + token.AccessToken,
Expand All @@ -105,8 +105,8 @@ func NewOauthAccess(token *oauth2.Token) credentials.PerRPCCredentials {
}

func (oa oauthAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
if !credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity) {
return nil, fmt.Errorf("connection is not secure enough to transfer OauthAccess PerRPCCredentials which requires PrivacyAndIntegrity")
if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer oauthAccess PerRPCCredentials: %v", err)
}
return map[string]string{
"authorization": oa.token.Type() + " " + oa.token.AccessToken,
Expand Down Expand Up @@ -142,8 +142,8 @@ func (s *serviceAccount) GetRequestMetadata(ctx context.Context, uri ...string)
return nil, err
}
}
if !credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity) {
return nil, fmt.Errorf("connection is not secure enough to transfer ServiceAccount PerRPCCredentials which requires PrivacyAndIntegrity")
if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer serviceAccount PerRPCCredentials: %v", err)
}
return map[string]string{
"authorization": s.t.Type() + " " + s.t.AccessToken,
Expand Down

0 comments on commit bef9360

Please sign in to comment.