-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
445b728
commit 40ef281
Showing
31 changed files
with
5,269 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,262 @@ | ||
/* | ||
* | ||
* Copyright 2018 gRPC authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* | ||
*/ | ||
|
||
// Package alts implements the ALTS credential support by gRPC library, which | ||
// encapsulates all the state needed by a client to authenticate with a server | ||
// using ALTS and make various assertions, e.g., about the client's identity, | ||
// role, or whether it is authorized to make a particular call. | ||
package alts | ||
|
||
import ( | ||
"errors" | ||
"flag" | ||
"fmt" | ||
"net" | ||
"time" | ||
|
||
"golang.org/x/net/context" | ||
"google.golang.org/grpc/credentials" | ||
"google.golang.org/grpc/credentials/alts/core" | ||
"google.golang.org/grpc/credentials/alts/core/handshaker" | ||
"google.golang.org/grpc/credentials/alts/core/handshaker/service" | ||
altspb "google.golang.org/grpc/credentials/alts/core/proto" | ||
"google.golang.org/grpc/grpclog" | ||
) | ||
|
||
const ( | ||
// defaultTimeout specifies the server handshake timeout. | ||
defaultTimeout = 30.0 * time.Second | ||
// The following constants specify the minimum and maximum acceptable | ||
// protocol versions. | ||
protocolVersionMaxMajor = 2 | ||
protocolVersionMaxMinor = 1 | ||
protocolVersionMinMajor = 2 | ||
protocolVersionMinMinor = 1 | ||
) | ||
|
||
var ( | ||
enableUntrustedALTS = flag.Bool("enable_untrusted_alts", false, "Enables ALTS in untrusted mode. Enabling this mode is risky since we cannot ensure that the application is running on GCP with a trusted handshaker service.") | ||
// ErrUntrustedPlatform is returned from ClientHandshake and | ||
// ServerHandshake is running on a platform where the trustworthiness of | ||
// the handshaker service is not guaranteed. | ||
ErrUntrustedPlatform = errors.New("untrusted platform, use enable_untrusted_alts flag at your own risk") | ||
) | ||
|
||
// AuthInfo exposes security information from the ALTS handshake to the | ||
// application. | ||
type AuthInfo interface { | ||
// ApplicationProtocol returns application protocol negotiated for the | ||
// ALTS connection. | ||
ApplicationProtocol() string | ||
// RecordProtocol returns the record protocol negotiated for the ALTS | ||
// connection. | ||
RecordProtocol() string | ||
// SecurityLevel returns the security level of the created ALTS secure | ||
// channel. | ||
SecurityLevel() altspb.SecurityLevel | ||
// PeerServiceAccount returns the peer service account. | ||
PeerServiceAccount() string | ||
// LocalServiceAccount returns the local service account. | ||
LocalServiceAccount() string | ||
// PeerRPCVersions returns the RPC version supported by the peer. | ||
PeerRPCVersions() *altspb.RpcProtocolVersions | ||
} | ||
|
||
// altsTC is the credentials required for authenticating a connection using Google | ||
// Transport Security. It implements credentials.TransportCredentials interface. | ||
type altsTC struct { | ||
info *credentials.ProtocolInfo | ||
hsAddr string | ||
side core.Side | ||
accounts []string | ||
} | ||
|
||
// NewClientALTS constructs a client-side ALTS TransportCredentials object. | ||
func NewClientALTS(targetServiceAccounts []string) credentials.TransportCredentials { | ||
return newALTS(core.ClientSide, targetServiceAccounts) | ||
} | ||
|
||
// NewServerALTS constructs a server-side ALTS TransportCredentials object. | ||
func NewServerALTS() credentials.TransportCredentials { | ||
return newALTS(core.ServerSide, nil) | ||
} | ||
|
||
func newALTS(side core.Side, accounts []string) credentials.TransportCredentials { | ||
if *enableUntrustedALTS { | ||
grpclog.Warning("untrusted ALTS mode is enabled and we cannot guarantee the trustworthiness of the ALTS handshaker service.") | ||
} | ||
|
||
return &altsTC{ | ||
info: &credentials.ProtocolInfo{ | ||
SecurityProtocol: "alts", | ||
SecurityVersion: "1.0", | ||
}, | ||
side: side, | ||
accounts: accounts, | ||
} | ||
} | ||
|
||
// ClientHandshake implements the client side handshake protocol. | ||
func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) { | ||
if !*enableUntrustedALTS && !vmOnGCP { | ||
return nil, nil, ErrUntrustedPlatform | ||
} | ||
|
||
// Connecting to ALTS handshaker service. | ||
hsConn, err := service.Dial() | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
// Do not close hsConn since it is shared with other handshakes. | ||
|
||
opts := handshaker.DefaultClientHandshakerOptions() | ||
opts.TargetServiceAccounts = g.accounts | ||
opts.RPCVersions = &altspb.RpcProtocolVersions{ | ||
MaxRpcVersion: &altspb.RpcProtocolVersions_Version{ | ||
Major: protocolVersionMaxMajor, | ||
Minor: protocolVersionMaxMinor, | ||
}, | ||
MinRpcVersion: &altspb.RpcProtocolVersions_Version{ | ||
Major: protocolVersionMinMajor, | ||
Minor: protocolVersionMinMinor, | ||
}, | ||
} | ||
chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
secConn, authInfo, err := chs.ClientHandshake(ctx) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
altsAuthInfo, ok := authInfo.(AuthInfo) | ||
if !ok { | ||
return nil, nil, errors.New("client-side auth info is not of type alts.AuthInfo") | ||
} | ||
match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions()) | ||
if !match { | ||
return nil, nil, fmt.Errorf("server-side RPC versions are not compatible with this client, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions()) | ||
} | ||
return secConn, authInfo, nil | ||
} | ||
|
||
// ServerHandshake implements the server side ALTS handshaker. | ||
func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) { | ||
if !*enableUntrustedALTS && !vmOnGCP { | ||
return nil, nil, ErrUntrustedPlatform | ||
} | ||
// Connecting to ALTS handshaker service. | ||
hsConn, err := service.Dial() | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
// Do not close hsConn since it's shared with other handshakes. | ||
|
||
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) | ||
defer cancel() | ||
opts := handshaker.DefaultServerHandshakerOptions() | ||
opts.RPCVersions = &altspb.RpcProtocolVersions{ | ||
MaxRpcVersion: &altspb.RpcProtocolVersions_Version{ | ||
Major: protocolVersionMaxMajor, | ||
Minor: protocolVersionMaxMinor, | ||
}, | ||
MinRpcVersion: &altspb.RpcProtocolVersions_Version{ | ||
Major: protocolVersionMinMajor, | ||
Minor: protocolVersionMinMinor, | ||
}, | ||
} | ||
shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, opts) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
secConn, authInfo, err := shs.ServerHandshake(ctx) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
altsAuthInfo, ok := authInfo.(AuthInfo) | ||
if !ok { | ||
return nil, nil, errors.New("server-side auth info is not of type alts.AuthInfo") | ||
} | ||
match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions()) | ||
if !match { | ||
return nil, nil, fmt.Errorf("client-side RPC versions is not compatible with this server, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions()) | ||
} | ||
return secConn, authInfo, nil | ||
} | ||
|
||
func (g *altsTC) Info() credentials.ProtocolInfo { | ||
return *g.info | ||
} | ||
|
||
func (g *altsTC) Clone() credentials.TransportCredentials { | ||
info := *g.info | ||
return &altsTC{ | ||
info: &info, | ||
} | ||
} | ||
|
||
func (g *altsTC) OverrideServerName(serverNameOverride string) error { | ||
g.info.ServerName = serverNameOverride | ||
return nil | ||
} | ||
|
||
// compareRPCVersion returns 0 if v1 == v2, 1 if v1 > v2 and -1 if v1 < v2. | ||
func compareRPCVersions(v1, v2 *altspb.RpcProtocolVersions_Version) int { | ||
switch { | ||
case v1.GetMajor() > v2.GetMajor(): | ||
fallthrough | ||
case v1.GetMajor() == v2.GetMajor() && v1.GetMinor() > v2.GetMinor(): | ||
return 1 | ||
} | ||
switch { | ||
case v1.GetMajor() < v2.GetMajor(): | ||
fallthrough | ||
case v1.GetMajor() == v2.GetMajor() && v1.GetMinor() < v2.GetMinor(): | ||
return -1 | ||
} | ||
return 0 | ||
} | ||
|
||
// checkRPCVersions performs a version check between local and peer rpc protocol | ||
// versions. This function returns true if the check passes which means both | ||
// parties agreed on a common rpc protocol to use, and false otherwise. The | ||
// function also returns the highest common RPC protocol version both parties | ||
// agreed on. | ||
func checkRPCVersions(local, peer *altspb.RpcProtocolVersions) (bool, *altspb.RpcProtocolVersions_Version) { | ||
if local == nil || peer == nil { | ||
grpclog.Error("invalid checkRPCVersions argument, either local or peer is nil.") | ||
return false, nil | ||
} | ||
|
||
// maxCommonVersion is MIN(local.max, peer.max). | ||
maxCommonVersion := local.GetMaxRpcVersion() | ||
if compareRPCVersions(local.GetMaxRpcVersion(), peer.GetMaxRpcVersion()) > 0 { | ||
maxCommonVersion = peer.GetMaxRpcVersion() | ||
} | ||
|
||
// minCommonVersion is MAX(local.min, peer.min). | ||
minCommonVersion := peer.GetMinRpcVersion() | ||
if compareRPCVersions(local.GetMinRpcVersion(), peer.GetMinRpcVersion()) > 0 { | ||
minCommonVersion = local.GetMinRpcVersion() | ||
} | ||
|
||
if compareRPCVersions(maxCommonVersion, minCommonVersion) < 0 { | ||
return false, nil | ||
} | ||
return true, maxCommonVersion | ||
} |
Oops, something went wrong.