Skip to content

Commit

Permalink
Add X-Request-Id and User-Agent headers to attestation requests
Browse files Browse the repository at this point in the history
  • Loading branch information
hslatman committed May 21, 2024
1 parent 6a28ca4 commit 2c23fb4
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 2 deletions.
16 changes: 14 additions & 2 deletions tpm/attestation/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/x509"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
Expand Down Expand Up @@ -214,7 +215,7 @@ func (ac *Client) attest(ctx context.Context, info *tpm.Info, ek *tpm.EK, attest
}

attestURL := ac.baseURL.JoinPath("attest").String()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, attestURL, bytes.NewReader(body))
req, err := newRequest(ctx, http.MethodPost, attestURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("failed creating POST http request for %q: %w", attestURL, err)
}
Expand Down Expand Up @@ -258,7 +259,7 @@ func (ac *Client) secret(ctx context.Context, secret []byte) (*secretResponse, e
}

secretURL := ac.baseURL.JoinPath("secret").String()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, secretURL, bytes.NewReader(body))
req, err := newRequest(ctx, http.MethodPost, secretURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("failed creating POST http request for %q: %w", secretURL, err)
}
Expand All @@ -280,3 +281,14 @@ func (ac *Client) secret(ctx context.Context, secret []byte) (*secretResponse, e

return &secretResp, nil
}

func newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {

Check failure on line 285 in tpm/attestation/client.go

View workflow job for this annotation

GitHub Actions / ci / lint / lint

importShadow: shadow of imported package 'url' (gocritic)
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
return nil, err
}
enforceRequestID(req)
setUserAgent(req)

return req, nil
}
7 changes: 7 additions & 0 deletions tpm/attestation/client_simulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ func mustParseURL(t *testing.T, urlString string) *url.URL {

func TestClient_Attest(t *testing.T) {
ctx := context.Background()
ctx = NewRequestIDContext(ctx, "requestID")
instance := newSimulatedTPM(t)
ak, err := instance.CreateAK(ctx, "ak1")
require.NoError(t, err)
Expand Down Expand Up @@ -140,6 +141,9 @@ func TestClient_Attest(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/attest":
assert.Equal(t, "step-attestation-http-client/1.0", r.Header.Get("User-Agent"))
assert.Equal(t, "requestID", r.Header.Get("X-Request-Id"))

var ar attestationRequest
err := json.NewDecoder(r.Body).Decode(&ar)
require.NoError(t, err)
Expand All @@ -165,6 +169,9 @@ func TestClient_Attest(t *testing.T) {
Secret: encryptedCredentials.Secret,
})
case "/secret":
assert.Equal(t, "step-attestation-http-client/1.0", r.Header.Get("User-Agent"))
assert.Equal(t, "requestID", r.Header.Get("X-Request-Id"))

var sr secretRequest
err := json.NewDecoder(r.Body).Decode(&sr)
require.NoError(t, err)
Expand Down
52 changes: 52 additions & 0 deletions tpm/attestation/requestid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package attestation

import (
"context"
"net/http"

"go.step.sm/crypto/randutil"
)

type requestIDContextKey struct{}

// NewRequestIDContext returns a new context with the given request ID added to the
// context.
func NewRequestIDContext(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, requestIDContextKey{}, requestID)
}

// RequestIDFromContext returns the request ID from the context if it exists.
// and is not empty.
func RequestIDFromContext(ctx context.Context) (string, bool) {
v, ok := ctx.Value(requestIDContextKey{}).(string)
return v, ok && v != ""
}

// requestIDHeader is the header name used for propagating request IDs from
// the CA client to the CA and back again.
const requestIDHeader = "X-Request-Id"

// newRequestID generates a new random UUIDv4 request ID. If it fails,
// the request ID will be the empty string.
func newRequestID() string {
requestID, err := randutil.UUIDv4()
if err != nil {
return ""
}

return requestID
}

// enforceRequestID checks if the X-Request-Id HTTP header is filled. If it's
// empty, the context is searched for a request ID. If that's also empty, a new
// request ID is generated.
func enforceRequestID(r *http.Request) {
if requestID := r.Header.Get(requestIDHeader); requestID == "" {
if reqID, ok := RequestIDFromContext(r.Context()); ok {
requestID = reqID
} else {
requestID = newRequestID()
}
r.Header.Set(requestIDHeader, requestID)
}
}
10 changes: 10 additions & 0 deletions tpm/attestation/useragent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package attestation

import "net/http"

// UserAgent will set the User-Agent header in the client requests.
var UserAgent = "step-attestation-http-client/1.0"

func setUserAgent(r *http.Request) {
r.Header.Set("User-Agent", UserAgent)
}

0 comments on commit 2c23fb4

Please sign in to comment.