Skip to content

Commit

Permalink
Add IMDS version 2 support. (#1489)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelfoley1 authored Jun 2, 2021
1 parent bc5ede6 commit 219b23f
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 2 deletions.
37 changes: 35 additions & 2 deletions pkg/credentials/iam_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ const (
defaultECSRoleEndpoint = "http://169.254.170.2"
defaultSTSRoleEndpoint = "https://sts.amazonaws.com"
defaultIAMSecurityCredsPath = "/latest/meta-data/iam/security-credentials/"
tokenRequestTTLHeader = "X-aws-ec2-metadata-token-ttl-seconds"
tokenPath = "/latest/api/token"
tokenTTL = "21600"
tokenRequestHeader = "X-aws-ec2-metadata-token"
)

// NewIAM returns a pointer to a new Credentials object wrapping the IAM.
Expand Down Expand Up @@ -192,11 +196,14 @@ func getIAMRoleURL(endpoint string) (*url.URL, error) {
// with the current EC2 service. If there are no credentials,
// or there is an error making or receiving the request.
// http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/iam-roles-for-amazon-ec2.html
func listRoleNames(client *http.Client, u *url.URL) ([]string, error) {
func listRoleNames(client *http.Client, u *url.URL, token string) ([]string, error) {
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
if err != nil {
return nil, err
}
if token != "" {
req.Header.Add(tokenRequestHeader, token)
}
resp, err := client.Do(req)
if err != nil {
return nil, err
Expand Down Expand Up @@ -242,12 +249,35 @@ func getEcsTaskCredentials(client *http.Client, endpoint string) (ec2RoleCredRes
return respCreds, nil
}

func fetchIMDSToken(client *http.Client, endpoint string) (string, error) {
req, err := http.NewRequest(http.MethodPut, endpoint+tokenPath, nil)
if err != nil {
return "", err
}
req.Header.Add(tokenRequestTTLHeader, tokenTTL)
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
if resp.StatusCode != http.StatusOK {
return "", errors.New(resp.Status)
}
return string(data), nil
}

// getCredentials - obtains the credentials from the IAM role name associated with
// the current EC2 service.
//
// If the credentials cannot be found, or there is an error
// reading the response an error will be returned.
func getCredentials(client *http.Client, endpoint string) (ec2RoleCredRespBody, error) {
// https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html
token, _ := fetchIMDSToken(client, endpoint)

// http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/iam-roles-for-amazon-ec2.html
u, err := getIAMRoleURL(endpoint)
Expand All @@ -256,7 +286,7 @@ func getCredentials(client *http.Client, endpoint string) (ec2RoleCredRespBody,
}

// http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/iam-roles-for-amazon-ec2.html
roleNames, err := listRoleNames(client, u)
roleNames, err := listRoleNames(client, u, token)
if err != nil {
return ec2RoleCredRespBody{}, err
}
Expand All @@ -280,6 +310,9 @@ func getCredentials(client *http.Client, endpoint string) (ec2RoleCredRespBody,
if err != nil {
return ec2RoleCredRespBody{}, err
}
if token != "" {
req.Header.Add(tokenRequestHeader, token)
}

resp, err := client.Do(req)
if err != nil {
Expand Down
47 changes: 47 additions & 0 deletions pkg/credentials/iam_aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"net/http"
"net/http/httptest"
"os"
"strconv"
"testing"
"time"
)
Expand Down Expand Up @@ -105,6 +106,40 @@ func initTestServer(expireOn string, failAssume bool) *httptest.Server {
return server
}

// Instance Metadata Service with V1 disabled.
func initIMDSv2Server(expireOn string) *httptest.Server {
imdsToken := "IMDSTokenabc123=="
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Println(r.URL.Path)
fmt.Println(r.Method)
if r.URL.Path == "/latest/api/token" && r.Method == "PUT" {
ttlHeader := r.Header.Get("X-aws-ec2-metadata-token-ttl-seconds")
ttl, err := strconv.ParseInt(ttlHeader, 10, 32)
if err != nil || ttl < 0 || ttl > 21600 {
http.Error(w, "", http.StatusBadRequest)
return
}
w.Header().Set("X-Aws-Ec2-Metadata-Token-Ttl-Seconds", ttlHeader)
w.Write([]byte(imdsToken))
return
}
token := r.Header.Get("X-aws-ec2-metadata-token")
if token != imdsToken {
http.Error(w, r.URL.Path, http.StatusUnauthorized)
return
}

if r.URL.Path == "/latest/meta-data/iam/security-credentials/" {
fmt.Fprintln(w, "RoleName")
} else if r.URL.Path == "/latest/meta-data/iam/security-credentials/RoleName" {
fmt.Fprintf(w, credsRespTmpl, expireOn)
} else {
http.Error(w, "bad request", http.StatusBadRequest)
}
}))
return server
}

func initEcsTaskTestServer(expireOn string) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, credsRespEcsTaskTmpl, expireOn)
Expand Down Expand Up @@ -391,3 +426,15 @@ func TestStsCn(t *testing.T) {
t.Error("Expected creds to be expired.")
}
}

func TestIMDSv1Blocked(t *testing.T) {
server := initIMDSv2Server("2014-12-16T01:51:37Z")
p := &IAM{
Client: http.DefaultClient,
Endpoint: server.URL,
}
_, err := p.Retrieve()
if err != nil {
t.Errorf("Unexpected IMDSv2 failure %s", err)
}
}

0 comments on commit 219b23f

Please sign in to comment.