Skip to content

Commit

Permalink
Allow requesting a join token with IAM method from the web api (#11339)
Browse files Browse the repository at this point in the history

Co-authored-by: Jim Bishopp <[email protected]>
Co-authored-by: Nic Klaassen <[email protected]>
  • Loading branch information
3 people authored Apr 19, 2022
1 parent 2eca7c7 commit 57cc2ed
Show file tree
Hide file tree
Showing 6 changed files with 510 additions and 29 deletions.
2 changes: 1 addition & 1 deletion api/types/provisioning.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func NewProvisionTokenFromSpec(token string, expires time.Time, spec ProvisionTo
}

// MustCreateProvisionToken returns a new valid provision token
// or panics, used in testes
// or panics, used in tests
func MustCreateProvisionToken(token string, roles SystemRoles, expires time.Time) ProvisionToken {
t, err := NewProvisionToken(token, roles, expires)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion lib/services/provisioning.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ type Provisioner interface {
}

// MustCreateProvisionToken returns a new valid provision token
// or panics, used in testes
// or panics, used in tests
func MustCreateProvisionToken(token string, roles types.SystemRoles, expires time.Time) types.ProvisionToken {
t, err := types.NewProvisionToken(token, roles, expires)
if err != nil {
Expand Down
34 changes: 29 additions & 5 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2008,9 +2008,11 @@ func (s *WebSuite) TestGetClusterDetails(c *C) {

func TestTokenGeneration(t *testing.T) {
tt := []struct {
name string
roles types.SystemRoles
shouldErr bool
name string
roles types.SystemRoles
shouldErr bool
joinMethod types.JoinMethod
allow []*types.TokenRule
}{
{
name: "single node role",
Expand All @@ -2037,6 +2039,19 @@ func TestTokenGeneration(t *testing.T) {
roles: types.SystemRoles{},
shouldErr: true,
},
{
name: "cannot request token with IAM join method without allow field",
roles: types.SystemRoles{types.RoleNode},
joinMethod: types.JoinMethodIAM,
shouldErr: true,
},
{
name: "can request token with IAM join method",
roles: types.SystemRoles{types.RoleNode},
joinMethod: types.JoinMethodIAM,
allow: []*types.TokenRule{{AWSAccount: "1234"}},
shouldErr: false,
},
}

for _, tc := range tt {
Expand All @@ -2047,8 +2062,10 @@ func TestTokenGeneration(t *testing.T) {
pack := proxy.authPack(t, "[email protected]")

endpoint := pack.clt.Endpoint("webapi", "token")
re, err := pack.clt.PostJSON(context.Background(), endpoint, createTokenRequest{
Roles: tc.roles,
re, err := pack.clt.PostJSON(context.Background(), endpoint, types.ProvisionTokenSpecV2{
Roles: tc.roles,
JoinMethod: tc.joinMethod,
Allow: tc.allow,
})

if tc.shouldErr {
Expand All @@ -2066,6 +2083,13 @@ func TestTokenGeneration(t *testing.T) {
generatedToken, err := proxy.auth.Auth().GetToken(context.Background(), responseToken.ID)
require.NoError(t, err)
require.Equal(t, tc.roles, generatedToken.GetRoles())

expectedJoinMethod := tc.joinMethod
if tc.joinMethod == "" {
expectedJoinMethod = types.JoinMethodToken
}
// if no joinMethod is provided, expect token method
require.Equal(t, expectedJoinMethod, generatedToken.GetJoinMethod())
})
}
}
Expand Down
140 changes: 120 additions & 20 deletions lib/web/join_tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ import (
"context"
"encoding/hex"
"fmt"
"hash/fnv"
"net/http"
"net/url"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"time"
Expand All @@ -46,6 +49,8 @@ type nodeJoinToken struct {
ID string `json:"id"`
// Expiry is token expiration time.
Expiry time.Time `json:"expiry,omitempty"`
// Method is the join method that the token supports
Method types.JoinMethod `json:"method"`
}

// scriptSettings is used to hold values which are passed into the function that
Expand All @@ -55,28 +60,75 @@ type scriptSettings struct {
appInstallMode bool
appName string
appURI string
}

// createTokenRequest is the expected request body of
// the endpoint to create token
type createTokenRequest struct {
Roles types.SystemRoles `json:"roles"`
joinMethod string
}

func (h *Handler) createTokenHandle(w http.ResponseWriter, r *http.Request, params httprouter.Params, ctx *SessionContext) (interface{}, error) {
var req createTokenRequest
var req types.ProvisionTokenSpecV2
if err := httplib.ReadJSON(r, &req); err != nil {
log.WithError(err).Error("error reading body")
return nil, trace.Wrap(err)
}

clt, err := ctx.GetClient()
if err != nil {
log.WithError(err).Error("error getting client")
return nil, trace.Wrap(err)
}

return createJoinToken(r.Context(), clt, req.Roles)
var expires time.Time
var tokenName string
switch req.JoinMethod {
case types.JoinMethodIAM:
// to prevent generation of redundant IAM tokens
// we generate a deterministic name for them
tokenName, err = generateIAMTokenName(req.Allow)
if err != nil {
return nil, trace.Wrap(err)
}
// if a token with this name is found and it has indeed the same rule set,
// return it. Otherwise, go ahead and create it
t, err := clt.GetToken(r.Context(), tokenName)
if err != nil && !trace.IsNotFound(err) {
return nil, trace.Wrap(err)
}

if err == nil {
// check if the token found has the right rules
if t.GetJoinMethod() != types.JoinMethodIAM || !isSameRuleSet(req.Allow, t.GetAllowRules()) {
return nil, trace.BadParameter("failed to create token: token with name %q already exists and does not have the expected allow rules", tokenName)
}

return &nodeJoinToken{
ID: t.GetName(),
Expiry: *t.GetMetadata().Expires,
Method: t.GetJoinMethod(),
}, nil
}

// IAM tokens should 'never' expire
expires = time.Now().UTC().AddDate(1000, 0, 0)
default:
tokenName, err = utils.CryptoRandomHex(auth.TokenLenBytes)
if err != nil {
return nil, trace.Wrap(err)
}
expires = time.Now().UTC().Add(defaults.NodeJoinTokenTTL)
}

provisionToken, err := types.NewProvisionTokenFromSpec(tokenName, expires, req)
if err != nil {
return nil, trace.Wrap(err)
}

err = clt.UpsertToken(r.Context(), provisionToken)
if err != nil {
return nil, trace.Wrap(err)
}

return &nodeJoinToken{
ID: tokenName,
Expiry: expires,
Method: provisionToken.GetJoinMethod(),
}, nil
}

func (h *Handler) createNodeTokenHandle(w http.ResponseWriter, r *http.Request, params httprouter.Params, ctx *SessionContext) (interface{}, error) {
Expand All @@ -99,6 +151,7 @@ func (h *Handler) getNodeJoinScriptHandle(w http.ResponseWriter, r *http.Request
settings := scriptSettings{
token: params.ByName("token"),
appInstallMode: false,
joinMethod: r.URL.Query().Get("method"),
}

script, err := getJoinScript(settings, h.GetProxyClient())
Expand Down Expand Up @@ -176,17 +229,20 @@ func createJoinToken(ctx context.Context, m nodeAPIGetter, roles types.SystemRol
}

func getJoinScript(settings scriptSettings, m nodeAPIGetter) (string, error) {
// This token does not need to be validated against the backend because it's not used to
// reveal any sensitive information. However, we still need to perform a simple input
// validation check by verifying that the token was auto-generated.
// Auto-generated tokens must be encoded and must have an expected length.
decodedToken, err := hex.DecodeString(settings.token)
if err != nil {
return "", trace.Wrap(err)
}
// Skip decoding validation for IAM tokens since they are generated with a different method
if settings.joinMethod != string(types.JoinMethodIAM) {
// This token does not need to be validated against the backend because it's not used to
// reveal any sensitive information. However, we still need to perform a simple input
// validation check by verifying that the token was auto-generated.
// Auto-generated tokens must be encoded and must have an expected length.
decodedToken, err := hex.DecodeString(settings.token)
if err != nil {
return "", trace.Wrap(err)
}

if len(decodedToken) != auth.TokenLenBytes {
return "", trace.BadParameter("invalid token length")
if len(decodedToken) != auth.TokenLenBytes {
return "", trace.BadParameter("invalid token length")
}
}

// Get hostname and port from proxy server address.
Expand Down Expand Up @@ -237,6 +293,7 @@ func getJoinScript(settings scriptSettings, m nodeAPIGetter) (string, error) {
"appInstallMode": strconv.FormatBool(settings.appInstallMode),
"appName": settings.appName,
"appURI": settings.appURI,
"joinMethod": settings.joinMethod,
})
if err != nil {
return "", trace.Wrap(err)
Expand All @@ -245,6 +302,49 @@ func getJoinScript(settings scriptSettings, m nodeAPIGetter) (string, error) {
return buf.String(), nil
}

// generateIAMTokenName makes a deterministic name for a iam join token
// based on its rule set
func generateIAMTokenName(rules []*types.TokenRule) (string, error) {
// sort the rules by (account ID, arn)
// to make sure a set of rules will produce the same hash,
// no matter the order they are in the slice
orderedRules := make([]*types.TokenRule, len(rules))
copy(orderedRules, rules)
sortRules(orderedRules)

h := fnv.New32a()
for _, r := range orderedRules {
s := fmt.Sprintf("%s%s", r.AWSAccount, r.AWSARN)
_, err := h.Write([]byte(s))
if err != nil {
return "", trace.Wrap(err)
}
}

return fmt.Sprintf("teleport-ui-iam-%d", h.Sum32()), nil
}

// sortRules sorts a slice of rules based on their AWS Account ID and ARN
func sortRules(rules []*types.TokenRule) {
sort.Slice(rules, func(i, j int) bool {
iAcct, jAcct := rules[i].AWSAccount, rules[j].AWSAccount
// if accountID is the same, sort based on arn
if iAcct == jAcct {
arn1, arn2 := rules[i].AWSARN, rules[j].AWSARN
return arn1 < arn2
}

return iAcct < jAcct
})
}

// isSameRuleSet check if r1 and r2 are the same rules, ignoring the order
func isSameRuleSet(r1 []*types.TokenRule, r2 []*types.TokenRule) bool {
sortRules(r1)
sortRules(r2)
return reflect.DeepEqual(r1, r2)
}

type nodeAPIGetter interface {
// GenerateToken creates a special provisioning token for a new SSH server
// that is valid for ttl period seconds.
Expand Down
Loading

0 comments on commit 57cc2ed

Please sign in to comment.