diff --git a/aws-http-auth/credentials/credentials.go b/aws-http-auth/credentials/credentials.go new file mode 100644 index 00000000..f3eb253c --- /dev/null +++ b/aws-http-auth/credentials/credentials.go @@ -0,0 +1,14 @@ +// Package credentials exposes container types for AWS credentials. +package credentials + +import ( + "time" +) + +// Credentials describes a shared-secret AWS credential identity. +type Credentials struct { + AccessKeyID string + SecretAccessKey string + SessionToken string + Expires time.Time +} diff --git a/aws-http-auth/go.mod b/aws-http-auth/go.mod new file mode 100644 index 00000000..fed9a6fa --- /dev/null +++ b/aws-http-auth/go.mod @@ -0,0 +1,3 @@ +module github.com/aws/smithy-go/aws-http-auth + +go 1.21 diff --git a/aws-http-auth/go.sum b/aws-http-auth/go.sum new file mode 100644 index 00000000..e69de29b diff --git a/aws-http-auth/internal/v4/signer.go b/aws-http-auth/internal/v4/signer.go new file mode 100644 index 00000000..3ded1048 --- /dev/null +++ b/aws-http-auth/internal/v4/signer.go @@ -0,0 +1,225 @@ +package v4 + +import ( + "encoding/hex" + "fmt" + "io" + "net/http" + "sort" + "strings" + "time" + + "github.com/aws/smithy-go/aws-http-auth/credentials" + v4 "github.com/aws/smithy-go/aws-http-auth/v4" +) + +const ( + // TimeFormat is the full-width form to be used in the X-Amz-Date header. + TimeFormat = "20060102T150405Z" + + // ShortTimeFormat is the shortened form used in credential scope. + ShortTimeFormat = "20060102" +) + +// Signer is the implementation structure for all variants of v4 signing. +type Signer struct { + Request *http.Request + PayloadHash []byte + Time time.Time + Credentials credentials.Credentials + Options v4.SignerOptions + + // variant-specific inputs + Algorithm string + CredentialScope string + Finalizer Finalizer +} + +// Finalizer performs the final step in v4 signing, deriving a signature for +// the string-to-sign with algorithm-specific key material. +type Finalizer interface { + SignString(string) (string, error) +} + +// Do performs v4 signing, modifying the request in-place with the +// signature. +// +// Do should be called exactly once for a configured Signer. The behavior of +// doing otherwise is undefined. +func (s *Signer) Do() error { + if err := s.init(); err != nil { + return err + } + + s.setRequiredHeaders() + + canonicalRequest, signedHeaders := s.buildCanonicalRequest() + stringToSign := s.buildStringToSign(canonicalRequest) + signature, err := s.Finalizer.SignString(stringToSign) + if err != nil { + return err + } + + s.Request.Header.Set("Authorization", + s.buildAuthorizationHeader(signature, signedHeaders)) + + return nil +} + +func (s *Signer) init() error { + // it might seem like time should also get defaulted/normalized here, but + // in practice sigv4 and sigv4a both need to do that beforehand to + // calculate scope, so there's no point + + if s.Options.HeaderRules == nil { + s.Options.HeaderRules = defaultHeaderRules{} + } + + if err := s.resolvePayloadHash(); err != nil { + return err + } + + return nil +} + +// ensure we have a value for payload hash, whether that be explicit, implicit, +// or the unsigned sentinel +func (s *Signer) resolvePayloadHash() error { + if len(s.PayloadHash) > 0 { + return nil + } + + rs, ok := s.Request.Body.(io.ReadSeeker) + if !ok || s.Options.DisableImplicitPayloadHashing { + s.PayloadHash = v4.UnsignedPayload() + return nil + } + + p, err := rtosha(rs) + if err != nil { + return err + } + + s.PayloadHash = p + return nil +} + +func (s *Signer) setRequiredHeaders() { + headers := s.Request.Header + + s.Request.Header.Set("Host", s.Request.Host) + s.Request.Header.Set("X-Amz-Date", s.Time.Format(TimeFormat)) + + if len(s.Credentials.SessionToken) > 0 { + s.Request.Header.Set("X-Amz-Security-Token", s.Credentials.SessionToken) + } + if len(s.PayloadHash) > 0 && s.Options.AddPayloadHashHeader { + headers.Set("X-Amz-Content-Sha256", payloadHashString(s.PayloadHash)) + } +} + +func (s *Signer) buildCanonicalRequest() (string, string) { + canonPath := s.Request.URL.EscapedPath() + // https://docs.aws.amazon.com/IAM/latest/UserGuide/create-signed-request.html: + // if input has no path, "/" is used + if len(canonPath) == 0 { + canonPath = "/" + } + if !s.Options.DisableDoublePathEscape { + canonPath = uriEncode(canonPath) + } + + query := s.Request.URL.Query() + for key := range query { + sort.Strings(query[key]) + } + canonQuery := strings.Replace(query.Encode(), "+", "%20", -1) + + canonHeaders, signedHeaders := s.buildCanonicalHeaders() + + req := strings.Join([]string{ + s.Request.Method, + canonPath, + canonQuery, + canonHeaders, + signedHeaders, + payloadHashString(s.PayloadHash), + }, "\n") + + return req, signedHeaders +} + +func (s *Signer) buildCanonicalHeaders() (canon, signed string) { + var canonHeaders []string + signedHeaders := map[string][]string{} + + // step 1: find what we're signing + for header, values := range s.Request.Header { + lowercase := strings.ToLower(header) + if !s.Options.HeaderRules.IsSigned(lowercase) { + continue + } + + canonHeaders = append(canonHeaders, lowercase) + signedHeaders[lowercase] = values + } + sort.Strings(canonHeaders) + + // step 2: indexing off of the list we built previously (which guarantees + // alphabetical order), build the canonical list + var ch strings.Builder + for i := range canonHeaders { + ch.WriteString(canonHeaders[i]) + ch.WriteRune(':') + + // headers can have multiple values + values := signedHeaders[canonHeaders[i]] + for j, value := range values { + ch.WriteString(strings.TrimSpace(value)) + if j < len(values)-1 { + ch.WriteRune(',') + } + } + ch.WriteRune('\n') + } + + return ch.String(), strings.Join(canonHeaders, ";") +} + +func (s *Signer) buildStringToSign(canonicalRequest string) string { + return strings.Join([]string{ + s.Algorithm, + s.Time.Format(TimeFormat), + s.CredentialScope, + hex.EncodeToString(Stosha(canonicalRequest)), + }, "\n") +} + +func (s *Signer) buildAuthorizationHeader(signature, headers string) string { + return fmt.Sprintf("%s Credential=%s, SignedHeaders=%s, Signature=%s", + s.Algorithm, + s.Credentials.AccessKeyID+"/"+s.CredentialScope, + headers, + signature) +} + +func payloadHashString(p []byte) string { + if string(p) == "UNSIGNED-PAYLOAD" { + return string(p) // sentinel, do not hex-encode + } + return hex.EncodeToString(p) +} + +// ResolveTime initializes a time value for signing. +func ResolveTime(t time.Time) time.Time { + if t.IsZero() { + return time.Now().UTC() + } + return t.UTC() +} + +type defaultHeaderRules struct{} + +func (defaultHeaderRules) IsSigned(h string) bool { + return h == "host" || strings.HasPrefix(h, "x-amz-") +} diff --git a/aws-http-auth/internal/v4/signer_test.go b/aws-http-auth/internal/v4/signer_test.go new file mode 100644 index 00000000..842800f8 --- /dev/null +++ b/aws-http-auth/internal/v4/signer_test.go @@ -0,0 +1,410 @@ +package v4 + +import ( + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/aws/smithy-go/aws-http-auth/credentials" + v4 "github.com/aws/smithy-go/aws-http-auth/v4" +) + +// Tests herein are meant to verify individual components of the v4 signer +// implementation and should generally not be calling Do() directly. +// +// The full algorithm contained in Do() is covered by tests for the +// Sigv4/Sigv4a APIs. + +func seekable(v string) io.ReadSeekCloser { + return readseekcloser{strings.NewReader(v)} +} + +type readseekcloser struct { + io.ReadSeeker +} + +func (readseekcloser) Close() error { return nil } + +type identityFinalizer struct{} + +func (identityFinalizer) SignString(v string) (string, error) { + return v, nil +} + +func TestBuildCanonicalRequest_SignedPayload(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, + "https://service.region.amazonaws.com", + seekable("{}")) + if err != nil { + t.Fatal(err) + } + + req.URL.Path = "/path1/path 2" + req.URL.RawQuery = "a=b" + req.Header.Set("Host", "service.region.amazonaws.com") + req.Header.Set("X-Amz-Foo", "\t \tbar ") + s := &Signer{ + Request: req, + PayloadHash: Stosha("{}"), + Options: v4.SignerOptions{ + HeaderRules: defaultHeaderRules{}, + }, + } + + expect := `POST +/path1/path%25202 +a=b +host:service.region.amazonaws.com +x-amz-foo:bar + +host;x-amz-foo +44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a` + + actual, _ := s.buildCanonicalRequest() + if expect != actual { + t.Errorf("canonical request\n%s\n!=\n%s", expect, actual) + } +} + +func TestBuildCanonicalRequest_NoPath(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, + "https://service.region.amazonaws.com", + seekable("{}")) + if err != nil { + t.Fatal(err) + } + + req.URL.Path = "" + req.URL.RawQuery = "a=b" + req.Header.Set("Host", "service.region.amazonaws.com") + req.Header.Set("X-Amz-Foo", "\t \tbar ") + s := &Signer{ + Request: req, + PayloadHash: Stosha("{}"), + Options: v4.SignerOptions{ + HeaderRules: defaultHeaderRules{}, + }, + } + + expect := `POST +/ +a=b +host:service.region.amazonaws.com +x-amz-foo:bar + +host;x-amz-foo +44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a` + + actual, _ := s.buildCanonicalRequest() + if expect != actual { + t.Errorf("canonical request\n%s\n!=\n%s", expect, actual) + } +} + +func TestBuildCanonicalRequest_DoubleHeader(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, + "https://service.region.amazonaws.com", + seekable("{}")) + if err != nil { + t.Fatal(err) + } + + req.URL.Path = "/" + req.Header.Set("X-Amz-Foo", "\t \tbar ") + req.Header.Set("Host", "service.region.amazonaws.com") + req.Header.Set("dontsignit", "dontsignit") // should be skipped + req.Header.Add("X-Amz-Foo", "\t \tbaz ") + s := &Signer{ + Request: req, + PayloadHash: Stosha("{}"), + Options: v4.SignerOptions{ + HeaderRules: defaultHeaderRules{}, + }, + } + + expect := `POST +/ + +host:service.region.amazonaws.com +x-amz-foo:bar,baz + +host;x-amz-foo +44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a` + + actual, _ := s.buildCanonicalRequest() + if expect != actual { + t.Errorf("canonical request\n%s\n!=\n%s", expect, actual) + } +} + +func TestBuildCanonicalRequest_SortQuery(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, + "https://service.region.amazonaws.com", + seekable("{}")) + if err != nil { + t.Fatal(err) + } + + req.URL.Path = "/" + req.URL.RawQuery = "a=b&%20b=c" + req.Header.Set("Host", "service.region.amazonaws.com") + s := &Signer{ + Request: req, + PayloadHash: v4.UnsignedPayload(), + Options: v4.SignerOptions{ + HeaderRules: defaultHeaderRules{}, + }, + } + + expect := `POST +/ +%20b=c&a=b +host:service.region.amazonaws.com + +host +UNSIGNED-PAYLOAD` + + actual, _ := s.buildCanonicalRequest() + if expect != actual { + t.Errorf("canonical request\n%s\n!=\n%s", expect, actual) + } +} + +func TestBuildCanonicalRequest_EmptyQuery(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, + "https://service.region.amazonaws.com", + seekable("{}")) + if err != nil { + t.Fatal(err) + } + + req.URL.Path = "/" + req.URL.RawQuery = "foo" + req.Header.Set("Host", "service.region.amazonaws.com") + s := &Signer{ + Request: req, + PayloadHash: v4.UnsignedPayload(), + Options: v4.SignerOptions{ + HeaderRules: defaultHeaderRules{}, + }, + } + + expect := `POST +/ +foo= +host:service.region.amazonaws.com + +host +UNSIGNED-PAYLOAD` + + actual, _ := s.buildCanonicalRequest() + if expect != actual { + t.Errorf("canonical request\n%s\n!=\n%s", expect, actual) + } +} + +func TestBuildCanonicalRequest_UnsignedPayload(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, + "https://service.region.amazonaws.com", + seekable("{}")) + if err != nil { + t.Fatal(err) + } + + req.URL.Path = "/path1/path 2" + req.URL.RawQuery = "a=b" + req.Header.Set("Host", "service.region.amazonaws.com") + req.Header.Set("X-Amz-Foo", "\t \tbar ") + s := &Signer{ + Request: req, + PayloadHash: []byte("UNSIGNED-PAYLOAD"), + Options: v4.SignerOptions{ + HeaderRules: defaultHeaderRules{}, + }, + } + + expect := `POST +/path1/path%25202 +a=b +host:service.region.amazonaws.com +x-amz-foo:bar + +host;x-amz-foo +UNSIGNED-PAYLOAD` + + actual, _ := s.buildCanonicalRequest() + if expect != actual { + t.Errorf("canonical request\n%s\n!=\n%s", expect, actual) + } +} + +func TestBuildCanonicalRequest_DisableDoubleEscape(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, + "https://service.region.amazonaws.com", + seekable("{}")) + if err != nil { + t.Fatal(err) + } + + req.URL.Path = "/path1/path 2" + req.URL.RawQuery = "a=b" + req.Header.Set("Host", "service.region.amazonaws.com") + req.Header.Set("X-Amz-Foo", "\t \tbar ") + s := &Signer{ + Request: req, + PayloadHash: []byte("UNSIGNED-PAYLOAD"), + Options: v4.SignerOptions{ + HeaderRules: defaultHeaderRules{}, + DisableDoublePathEscape: true, + }, + } + + expect := `POST +/path1/path%202 +a=b +host:service.region.amazonaws.com +x-amz-foo:bar + +host;x-amz-foo +UNSIGNED-PAYLOAD` + + actual, _ := s.buildCanonicalRequest() + if expect != actual { + t.Errorf("canonical request\n%s\n!=\n%s", expect, actual) + } +} + +func TestResolvePayloadHash_AlreadySet(t *testing.T) { + expect := "already set" + s := &Signer{ + PayloadHash: []byte(expect), + } + + err := s.resolvePayloadHash() + if err != nil { + t.Fatalf("expect no err, got %v", err) + } + + if expect != string(s.PayloadHash) { + t.Fatalf("hash %q != %q", expect, s.PayloadHash) + } +} + +func TestResolvePayloadHash_Disabled(t *testing.T) { + expect := "UNSIGNED-PAYLOAD" + s := &Signer{ + Request: &http.Request{Body: seekable("foo")}, + Options: v4.SignerOptions{ + DisableImplicitPayloadHashing: true, + }, + } + + err := s.resolvePayloadHash() + if err != nil { + t.Fatalf("expect no err, got %v", err) + } + + if expect != string(s.PayloadHash) { + t.Fatalf("hash %q != %q", expect, s.PayloadHash) + } +} + +type seekexploder struct { + io.ReadCloser +} + +func (seekexploder) Seek(int64, int) (int64, error) { + return 0, fmt.Errorf("boom") +} + +func TestResolvePayloadHash_SeekBlowsUp(t *testing.T) { + s := &Signer{ + Request: &http.Request{ + Body: seekexploder{seekable("foo")}, + }, + } + + err := s.resolvePayloadHash() + if err == nil { + t.Fatalf("expect err, got none") + } +} + +func TestResolvePayloadHash_OK(t *testing.T) { + expect := string(Stosha("foo")) + s := &Signer{ + Request: &http.Request{Body: seekable("foo")}, + } + + err := s.resolvePayloadHash() + if err != nil { + t.Fatalf("expect no err, got %v", err) + } + + if expect != string(s.PayloadHash) { + t.Fatalf("hash %q != %q", expect, s.PayloadHash) + } +} + +func TestSetRequiredHeaders_All(t *testing.T) { + s := &Signer{ + Request: &http.Request{ + Host: "foo.service.com", + Header: http.Header{}, + }, + PayloadHash: []byte{0, 1, 2, 3}, + Time: time.Unix(0, 0).UTC(), + Credentials: credentials.Credentials{ + SessionToken: "session_token", + }, + Options: v4.SignerOptions{ + AddPayloadHashHeader: true, + }, + } + + s.setRequiredHeaders() + if actual := s.Request.Header.Get("Host"); s.Request.Host != actual { + t.Errorf("region header %q != %q", s.Request.Host, actual) + } + if expect, actual := "19700101T000000Z", s.Request.Header.Get("X-Amz-Date"); expect != actual { + t.Errorf("date header %q != %q", expect, actual) + } + if expect, actual := "session_token", s.Request.Header.Get("X-Amz-Security-Token"); expect != actual { + t.Errorf("token header %q != %q", expect, actual) + } + if expect, actual := "00010203", s.Request.Header.Get("X-Amz-Content-Sha256"); expect != actual { + t.Errorf("sha256 header %q != %q", expect, actual) + } +} + +func TestSetRequiredHeaders_UnsignedPayload(t *testing.T) { + s := &Signer{ + Request: &http.Request{ + Host: "foo.service.com", + Header: http.Header{}, + }, + PayloadHash: []byte("UNSIGNED-PAYLOAD"), + Time: time.Unix(0, 0).UTC(), + Credentials: credentials.Credentials{}, + Options: v4.SignerOptions{ + AddPayloadHashHeader: true, + }, + } + + s.setRequiredHeaders() + if actual := s.Request.Header.Get("Host"); s.Request.Host != actual { + t.Errorf("region header %q != %q", s.Request.Host, actual) + } + if expect, actual := "19700101T000000Z", s.Request.Header.Get("X-Amz-Date"); expect != actual { + t.Errorf("date header %q != %q", expect, actual) + } + if expect, actual := "", s.Request.Header.Get("X-Amz-Security-Token"); expect != actual { + t.Errorf("token header %q != %q", expect, actual) + } + if expect, actual := "UNSIGNED-PAYLOAD", s.Request.Header.Get("X-Amz-Content-Sha256"); expect != actual { + t.Errorf("sha256 header %q != %q", expect, actual) + } +} diff --git a/aws-http-auth/internal/v4/strings.go b/aws-http-auth/internal/v4/strings.go new file mode 100644 index 00000000..723b122f --- /dev/null +++ b/aws-http-auth/internal/v4/strings.go @@ -0,0 +1,59 @@ +package v4 + +import ( + "bytes" + "crypto/sha256" + "fmt" + "io" +) + +var noEscape [256]bool + +func init() { + for i := 0; i < len(noEscape); i++ { + // AWS expects every character except these to be escaped + noEscape[i] = (i >= 'A' && i <= 'Z') || + (i >= 'a' && i <= 'z') || + (i >= '0' && i <= '9') || + i == '-' || + i == '.' || + i == '_' || + i == '~' || + i == '/' + } +} + +// uriEncode implements "Amazon-style" URL escaping. +func uriEncode(path string) string { + var buf bytes.Buffer + for i := 0; i < len(path); i++ { + c := path[i] + if noEscape[c] { + buf.WriteByte(c) + } else { + fmt.Fprintf(&buf, "%%%02X", c) + } + } + return buf.String() +} + +// rtosha computes the sha256 hash of the input Reader and rewinds it before +// returning. +func rtosha(r io.ReadSeeker) ([]byte, error) { + h := sha256.New() + if _, err := io.Copy(h, r); err != nil { + return nil, err + } + if _, err := r.Seek(0, io.SeekStart); err != nil { + return nil, err + } + + return h.Sum(nil), nil +} + +// Stosha computes the sha256 hash of the given string. +func Stosha(s string) []byte { + h := sha256.New() + h.Write([]byte(s)) + return h.Sum(nil) +} diff --git a/aws-http-auth/sigv4/e2e_test.go b/aws-http-auth/sigv4/e2e_test.go new file mode 100644 index 00000000..0a6cf56f --- /dev/null +++ b/aws-http-auth/sigv4/e2e_test.go @@ -0,0 +1,154 @@ +//go:build e2e +// +build e2e + +package sigv4 + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "math/rand" + "net/http" + "os" + "testing" + + "github.com/aws/smithy-go/aws-http-auth/credentials" +) + +type closer struct{ io.ReadSeeker } + +func (closer) Close() error { return nil } + +type SQSClient struct { + Region string + Credentials credentials.Credentials + + HTTPClient *http.Client + Signer *Signer +} + +// all of these method definitions are very repetitive, it would be useful if +// there was some sort of API model we could generate code from... +func (c *SQSClient) CreateQueue(ctx context.Context, in *CreateQueueInput) (*CreateQueueOutput, error) { + var out CreateQueueOutput + if err := c.do(ctx, "CreateQueue", in, &out); err != nil { + return nil, err + } + return &out, nil +} + +type CreateQueueInput struct { + QueueName string `json:"QueueName,omitempty"` // This member is required. + + Attributes map[string]string `json:"Attributes,omitempty"` + Tags map[string]string `json:"Tags,omitempty"` +} + +type CreateQueueOutput struct { + QueueURL string `json:"QueueUrl"` +} + +func (c *SQSClient) DeleteQueue(ctx context.Context, in *DeleteQueueInput) (*DeleteQueueOutput, error) { + var out DeleteQueueOutput + if err := c.do(ctx, "DeleteQueue", in, &out); err != nil { + return nil, err + } + return &out, nil +} + +type DeleteQueueInput struct { + QueueURL string `json:"QueueUrl,omitempty"` // This member is required. +} + +type DeleteQueueOutput struct{} + +func (c *SQSClient) do(ctx context.Context, target string, in, out any) error { + // init (featuring budget resolve endpoint) + endpt := fmt.Sprintf("https://sqs.%s.amazonaws.com", c.Region) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpt, http.NoBody) + if err != nil { + return fmt.Errorf("new http request: %w", err) + } + + // serialize + req.URL.Path = "/" + req.Header.Set("X-Amz-Target", fmt.Sprintf("AmazonSQS.%s", target)) + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + payload, err := json.Marshal(in) + if err != nil { + return fmt.Errorf("serialize request: %w", err) + } + req.Body = closer{bytes.NewReader(payload)} + req.ContentLength = int64(len(payload)) + + // sign + err = c.Signer.SignRequest(&SignRequestInput{ + Request: req, + Credentials: c.Credentials, + Service: "sqs", + Region: c.Region, + }) + if err != nil { + return fmt.Errorf("sign request: %w", err) + } + + // round-trip + resp, err := c.HTTPClient.Do(req) + if err != nil { + return fmt.Errorf("do request: %w", err) + } + defer resp.Body.Close() + + // deserialize + data, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("read response body: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("request error: %s: %s", resp.Status, data) + } + if len(data) == 0 { + return nil + } + if err := json.Unmarshal(data, out); err != nil { + return fmt.Errorf("deserialize response: %w", err) + } + + return nil +} + +func TestE2E_SQS(t *testing.T) { + svc := &SQSClient{ + Region: "us-east-1", + Credentials: credentials.Credentials{ + AccessKeyID: os.Getenv("AWS_ACCESS_KEY_ID"), + SecretAccessKey: os.Getenv("AWS_SECRET_ACCESS_KEY"), + SessionToken: os.Getenv("AWS_SESSION_TOKEN"), + }, + HTTPClient: http.DefaultClient, + Signer: New(), + } + + queueName := fmt.Sprintf("aws-http-auth-e2etest-%d", rand.Int()%(2<<15)) + + out, err := svc.CreateQueue(context.Background(), &CreateQueueInput{ + QueueName: queueName, + }) + if err != nil { + t.Fatalf("create queue: %v", err) + } + + queueURL := out.QueueURL + t.Logf("created test queue %s", queueURL) + + _, err = svc.DeleteQueue(context.Background(), &DeleteQueueInput{ + QueueURL: queueURL, + }) + if err != nil { + t.Fatalf("delete queue: %v", err) + } + + t.Log("deleted test queue") +} diff --git a/aws-http-auth/sigv4/sigv4.go b/aws-http-auth/sigv4/sigv4.go new file mode 100644 index 00000000..fcaab808 --- /dev/null +++ b/aws-http-auth/sigv4/sigv4.go @@ -0,0 +1,154 @@ +// Package sigv4 implements request signing for the basic form AWS Signature +// Version 4. +package sigv4 + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "net/http" + "strings" + "time" + + "github.com/aws/smithy-go/aws-http-auth/credentials" + v4internal "github.com/aws/smithy-go/aws-http-auth/internal/v4" + v4 "github.com/aws/smithy-go/aws-http-auth/v4" +) + +const algorithm = "AWS4-HMAC-SHA256" + +// Signer signs requests with AWS Signature version 4. +type Signer struct { + options v4.SignerOptions +} + +// New returns an instance of Signer with applied options. +func New(opts ...v4.SignerOption) *Signer { + options := v4.SignerOptions{} + for _, opt := range opts { + opt(&options) + } + + return &Signer{options} +} + +// SignRequestInput is the set of inputs for Sigv4 signing. +type SignRequestInput struct { + // The input request, which will modified in-place during signing. + Request *http.Request + + // The SHA256 hash of the input request body. + // + // This value is NOT required to sign the request, but it is recommended to + // provide it (or provide a Body on the HTTP request that implements + // io.Seeker such that the signer can calculate it for you). Many services + // do not accept requests with unsigned payloads. + // + // If a value is not provided, and DisableImplicitPayloadHashing has not + // been set on SignerOptions, the signer will attempt to derive the payload + // hash itself. The request's Body MUST implement io.Seeker in order to do + // this, if it does not, the magic value for unsigned payload is used. If + // the body does implement io.Seeker, but a call to Seek returns an error, + // the signer will forward that error. + PayloadHash []byte + + // The identity used to sign the request. + Credentials credentials.Credentials + + // The service and region for which this request is to be signed. + // + // The appropriate values for these fields are determined by the service + // vendor. + Service, Region string + + // Wall-clock time used for calculating the signature. + // + // If the zero-value is given (generally by the caller not setting it), the + // signer will instead use the current system clock time for the signature. + Time time.Time +} + +// SignRequest signs an HTTP request with AWS Signature Version 4, modifying +// the request in-place by adding the headers that constitute the signature. +// +// SignRequest will modify the request by setting the following headers: +// - Host: required in general for HTTP/1.1 as well as for v4-signed requests +// - X-Amz-Date: required for v4-signed requests +// - X-Amz-Security-Token: required for v4-signed requests IF present on +// credentials used to sign, otherwise this header will not be set +// - Authorization: contains the v4 signature string +// +// The request MUST have a Host value set at the time that this API is called, +// such that it can be included in the signature calculation. Standard library +// HTTP clients set this as a request header by default, meaning that a request +// signed without a Host value will end up transmitting with the Host header +// anyway, which will cause the request to be rejected by the service due to +// signature mismatch (the Host header is required to be signed with Sigv4). +// +// Generally speaking, using http.NewRequest will ensure that request instances +// are sufficiently initialized to be used with this API, though it is not +// strictly required. +// +// SignRequest may be called any number of times on an http.Request instance, +// the header values set as part of the signature will simply be overwritten +// with newer or re-calculated values (such as a new set of credentials with a +// new session token, which would in turn result in a different signature). +func (s *Signer) SignRequest(in *SignRequestInput, opts ...v4.SignerOption) error { + options := s.options + for _, opt := range opts { + opt(&options) + } + + tm := v4internal.ResolveTime(in.Time) + signer := v4internal.Signer{ + Request: in.Request, + PayloadHash: in.PayloadHash, + Time: tm, + Credentials: in.Credentials, + Options: options, + + Algorithm: algorithm, + CredentialScope: scope(tm, in.Region, in.Service), + Finalizer: &finalizer{ + Secret: in.Credentials.SecretAccessKey, + Service: in.Service, + Region: in.Region, + Time: tm, + }, + } + if err := signer.Do(); err != nil { + return err + } + + return nil +} + +func scope(signingTime time.Time, region, service string) string { + return strings.Join([]string{ + signingTime.Format(v4internal.ShortTimeFormat), + region, + service, + "aws4_request", + }, "/") +} + +type finalizer struct { + Secret string + Service, Region string + Time time.Time +} + +func (f *finalizer) SignString(toSign string) (string, error) { + key := hmacSHA256([]byte("AWS4"+f.Secret), []byte(f.Time.Format(v4internal.ShortTimeFormat))) + key = hmacSHA256(key, []byte(f.Region)) + key = hmacSHA256(key, []byte(f.Service)) + key = hmacSHA256(key, []byte("aws4_request")) + + return hex.EncodeToString(hmacSHA256(key, []byte(toSign))), nil +} + +func hmacSHA256(key, data []byte) []byte { + hash := hmac.New(sha256.New, key) + hash.Write(data) + return hash.Sum(nil) +} diff --git a/aws-http-auth/sigv4/sigv4_test.go b/aws-http-auth/sigv4/sigv4_test.go new file mode 100644 index 00000000..6576b192 --- /dev/null +++ b/aws-http-auth/sigv4/sigv4_test.go @@ -0,0 +1,211 @@ +package sigv4 + +import ( + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/aws/smithy-go/aws-http-auth/credentials" + v4internal "github.com/aws/smithy-go/aws-http-auth/internal/v4" + v4 "github.com/aws/smithy-go/aws-http-auth/v4" +) + +var credsSession = credentials.Credentials{ + AccessKeyID: "AKID", + SecretAccessKey: "SECRET", + SessionToken: "SESSION", +} + +var credsNoSession = credentials.Credentials{ + AccessKeyID: "AKID", + SecretAccessKey: "SECRET", +} + +type signAll struct{} + +func (signAll) IsSigned(string) bool { return true } + +func seekable(v string) io.ReadSeekCloser { + return readseekcloser{strings.NewReader(v)} +} + +func nonseekable(v string) io.ReadCloser { + return io.NopCloser(strings.NewReader(v)) // io.NopCloser elides Seek() +} + +type readseekcloser struct { + io.ReadSeeker +} + +func (readseekcloser) Close() error { return nil } + +func newRequest(body io.ReadCloser, opts ...func(*http.Request)) *http.Request { + // we initialize via NewRequest because it sets basic things like host and + // proto and is generally how we recommend the signing APIs are used + // + // the url doesn't actually need to match the signing name / region + req, err := http.NewRequest(http.MethodPost, "https://service.region.amazonaws.com", body) + if err != nil { + panic(err) + } + + for _, opt := range opts { + opt(req) + } + return req +} + +func expectSignature(t *testing.T, signed *http.Request, expectSignature, expectDate, expectToken string) { + if actual := signed.Header.Get("Authorization"); expectSignature != actual { + t.Errorf("expect signature:\n%s\n!=\n%s", expectSignature, actual) + } + if actual := signed.Header.Get("X-Amz-Date"); expectDate != actual { + t.Errorf("expect date: %s != %s", expectDate, actual) + } + if actual := signed.Header.Get("X-Amz-Security-Token"); expectToken != actual { + t.Errorf("expect token: %s != %s", expectToken, actual) + } +} + +func TestSignRequest(t *testing.T) { + for name, tt := range map[string]struct { + Input *SignRequestInput + Opts v4.SignerOption + ExpectSignature string + ExpectDate string + ExpectToken string + }{ + "minimal case, nonseekable": { + Input: &SignRequestInput{ + Request: newRequest(nonseekable("{}")), + Credentials: credsSession, + Service: "dynamodb", + Region: "us-east-1", + Time: time.Unix(0, 0), + }, + ExpectSignature: "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=671ed63777ad2f28bfefd733087414652c1498b3301d9bdf272e44a3172c28c0", + ExpectDate: "19700101T000000Z", + ExpectToken: "SESSION", + }, + "minimal case, seekable": { + Input: &SignRequestInput{ + Request: newRequest(seekable("{}")), + Credentials: credsSession, + Service: "dynamodb", + Region: "us-east-1", + Time: time.Unix(0, 0), + }, + ExpectSignature: "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=e75efbd4e2b3d3a8218d8fc0125e8fc888844510125ca6f33be555fd76d9aa18", + ExpectDate: "19700101T000000Z", + ExpectToken: "SESSION", + }, + "minimal case, no session": { + Input: &SignRequestInput{ + Request: newRequest(nonseekable("{}")), + Credentials: credsNoSession, + Service: "dynamodb", + Region: "us-east-1", + Time: time.Unix(0, 0), + }, + ExpectSignature: "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=host;x-amz-date, Signature=6f249a4b86fd230f28cae603cdf92c2657b1d1ffc3fcccbd938e1339c4542e14", + ExpectDate: "19700101T000000Z", + ExpectToken: "", + }, + "explicit unsigned payload": { + Input: &SignRequestInput{ + Request: newRequest(seekable("{}")), + PayloadHash: v4.UnsignedPayload(), + Credentials: credsSession, + Service: "dynamodb", + Region: "us-east-1", + Time: time.Unix(0, 0), + }, + ExpectSignature: "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=671ed63777ad2f28bfefd733087414652c1498b3301d9bdf272e44a3172c28c0", + ExpectDate: "19700101T000000Z", + ExpectToken: "SESSION", + }, + "explicit payload hash": { + Input: &SignRequestInput{ + Request: newRequest(seekable("{}")), + PayloadHash: v4internal.Stosha("{}"), + Credentials: credsSession, + Service: "dynamodb", + Region: "us-east-1", + Time: time.Unix(0, 0), + }, + ExpectSignature: "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=e75efbd4e2b3d3a8218d8fc0125e8fc888844510125ca6f33be555fd76d9aa18", + ExpectDate: "19700101T000000Z", + ExpectToken: "SESSION", + }, + "sign all headers": { + Input: &SignRequestInput{ + Request: newRequest(seekable("{}"), func(r *http.Request) { + r.Header.Set("Content-Type", "application/json") + r.Header.Set("Foo", "bar") + r.Header.Set("Bar", "baz") + }), + PayloadHash: v4internal.Stosha("{}"), + Credentials: credsSession, + Service: "dynamodb", + Region: "us-east-1", + Time: time.Unix(0, 0), + }, + Opts: func(o *v4.SignerOptions) { + o.HeaderRules = signAll{} + }, + ExpectSignature: "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=bar;content-type;foo;host;x-amz-date;x-amz-security-token, Signature=90673d8f57147fd36dbde4d4fe156f643ea25627e7b4d14c157c6369e685b80a", + ExpectDate: "19700101T000000Z", + ExpectToken: "SESSION", + }, + "disable implicit payload hash": { + Input: &SignRequestInput{ + Request: newRequest(seekable("{}")), + Credentials: credsSession, + Service: "dynamodb", + Region: "us-east-1", + Time: time.Unix(0, 0), + }, + Opts: func(o *v4.SignerOptions) { + o.DisableImplicitPayloadHashing = true + }, + ExpectSignature: "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=671ed63777ad2f28bfefd733087414652c1498b3301d9bdf272e44a3172c28c0", + ExpectDate: "19700101T000000Z", + ExpectToken: "SESSION", + }, + "s3 settings": { + Input: &SignRequestInput{ + Request: newRequest(seekable("{}")), + Credentials: credsSession, + Service: "s3", + Region: "us-east-1", + Time: time.Unix(0, 0), + }, + Opts: func(o *v4.SignerOptions) { + o.DisableDoublePathEscape = true + o.AddPayloadHashHeader = true + }, + ExpectSignature: "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date;x-amz-security-token, Signature=0232da513d9e9830b12cf0d9f374834671494bc362ee173adb2a50267d0339e0", + ExpectDate: "19700101T000000Z", + ExpectToken: "SESSION", + }, + } { + t.Run(name, func(t *testing.T) { + opt := tt.Opts + if opt == nil { + opt = func(o *v4.SignerOptions) {} + } + signer := New(opt) + if err := signer.SignRequest(tt.Input); err != nil { + t.Fatalf("expect no err, got %v", err) + } + + req := tt.Input.Request + expectSignature(t, req, tt.ExpectSignature, tt.ExpectDate, tt.ExpectToken) + if host := req.Header.Get("Host"); req.Host != host { + t.Errorf("expect host header: %s != %s", req.Host, host) + } + }) + } +} diff --git a/aws-http-auth/sigv4a/credentials.go b/aws-http-auth/sigv4a/credentials.go new file mode 100644 index 00000000..48aa0e95 --- /dev/null +++ b/aws-http-auth/sigv4a/credentials.go @@ -0,0 +1,179 @@ +package sigv4a + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/hmac" + "crypto/sha256" + "encoding/binary" + "fmt" + "hash" + "math" + "math/big" + "sync" + + "github.com/aws/smithy-go/aws-http-auth/credentials" +) + +var ( + p256 elliptic.Curve + nMinusTwoP256 *big.Int + + one = new(big.Int).SetInt64(1) +) + +func init() { + p256 = elliptic.P256() + + nMinusTwoP256 = new(big.Int).SetBytes(p256.Params().N.Bytes()) + nMinusTwoP256 = nMinusTwoP256.Sub(nMinusTwoP256, new(big.Int).SetInt64(2)) +} + +// ecdsaCache stores the result of deriving an ECDSA private key from a +// shared-secret identity. +type ecdsaCache struct { + mu sync.Mutex + + akid string + priv *ecdsa.PrivateKey +} + +// Derive computes and caches the ECDSA key-pair for the identity, returning +// the result. +// +// Future calls to Derive with the same set of credentials (identified by AKID) +// will short-circuit. Future calls with a different set of credentials +// (identified by AKID) will re-derive the value, overwriting the old result. +func (c *ecdsaCache) Derive(creds credentials.Credentials) (*ecdsa.PrivateKey, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if creds.AccessKeyID == c.akid { + return c.priv, nil + } + + priv, err := derivePrivateKey(creds) + if err != nil { + return nil, err + } + + c.akid = creds.AccessKeyID + c.priv = priv + return priv, nil +} + +// derivePrivateKey derives a NIST P-256 PrivateKey from the given IAM +// AccessKey and SecretKey pair. +// +// Based on FIPS.186-4 Appendix B.4.2 +func derivePrivateKey(creds credentials.Credentials) (*ecdsa.PrivateKey, error) { + akid := creds.AccessKeyID + secret := creds.SecretAccessKey + + params := p256.Params() + bitLen := params.BitSize // Testing random candidates does not require an additional 64 bits + counter := 0x01 + + buffer := make([]byte, 1+len(akid)) // 1 byte counter + len(accessKey) + kdfContext := bytes.NewBuffer(buffer) + + inputKey := append([]byte("AWS4A"), []byte(secret)...) + + d := new(big.Int) + for { + kdfContext.Reset() + kdfContext.WriteString(akid) + kdfContext.WriteByte(byte(counter)) + + key, err := deriveHMACKey(sha256.New, bitLen, inputKey, []byte(algorithm), kdfContext.Bytes()) + if err != nil { + return nil, err + } + + cmp, err := cmpConst(key, nMinusTwoP256.Bytes()) + if err != nil { + return nil, err + } + if cmp == -1 { + d.SetBytes(key) + break + } + + counter++ + if counter > 0xFF { + return nil, fmt.Errorf("exhausted single byte external counter") + } + } + d = d.Add(d, one) + + priv := new(ecdsa.PrivateKey) + priv.PublicKey.Curve = p256 + priv.D = d + priv.PublicKey.X, priv.PublicKey.Y = p256.ScalarBaseMult(d.Bytes()) + + return priv, nil +} + +// deriveHMACKey provides an implementation of a NIST-800-108 of a KDF (Key +// Derivation Function) in Counter Mode. HMAC is used as the pseudorandom +// function, where the value of `r` is defined as a 4-byte counter. +func deriveHMACKey(hash func() hash.Hash, bitLen int, key []byte, label, context []byte) ([]byte, error) { + // verify that we won't overflow the counter + n := int64(math.Ceil((float64(bitLen) / 8) / float64(hash().Size()))) + if n > 0x7FFFFFFF { + return nil, fmt.Errorf("unable to derive key of size %d using 32-bit counter", bitLen) + } + + // verify the requested bit length is not larger then the length encoding size + if int64(bitLen) > 0x7FFFFFFF { + return nil, fmt.Errorf("bitLen is greater than 32-bits") + } + + fixedInput := bytes.NewBuffer(nil) + fixedInput.Write(label) + fixedInput.WriteByte(0x00) + fixedInput.Write(context) + if err := binary.Write(fixedInput, binary.BigEndian, int32(bitLen)); err != nil { + return nil, fmt.Errorf("failed to write bit length to fixed input string: %v", err) + } + + var output []byte + + h := hmac.New(hash, key) + + for i := int64(1); i <= n; i++ { + h.Reset() + if err := binary.Write(h, binary.BigEndian, int32(i)); err != nil { + return nil, err + } + _, err := h.Write(fixedInput.Bytes()) + if err != nil { + return nil, err + } + output = append(output, h.Sum(nil)...) + } + + return output[:bitLen/8], nil +} + +// constant-time byte slice compare +func cmpConst(x, y []byte) (int, error) { + if len(x) != len(y) { + return 0, fmt.Errorf("slice lengths do not match") + } + + xLarger, yLarger := 0, 0 + + for i := 0; i < len(x); i++ { + xByte, yByte := int(x[i]), int(y[i]) + + x := ((yByte - xByte) >> 8) & 1 + y := ((xByte - yByte) >> 8) & 1 + + xLarger |= x &^ yLarger + yLarger |= y &^ xLarger + } + + return xLarger - yLarger, nil +} diff --git a/aws-http-auth/sigv4a/e2e_test.go b/aws-http-auth/sigv4a/e2e_test.go new file mode 100644 index 00000000..bc12be45 --- /dev/null +++ b/aws-http-auth/sigv4a/e2e_test.go @@ -0,0 +1,535 @@ +//go:build e2e +// +build e2e + +package sigv4a + +import ( + "bytes" + "context" + "encoding/xml" + "fmt" + "io" + "math/rand" + "net/http" + "os" + "testing" + "time" + + "github.com/aws/smithy-go/aws-http-auth/credentials" + "github.com/aws/smithy-go/aws-http-auth/sigv4" + v4 "github.com/aws/smithy-go/aws-http-auth/v4" +) + +type closer struct{ io.ReadSeeker } + +func (closer) Close() error { return nil } + +type ToXML interface { + ToXML() []byte +} + +type S3Client struct { + Region string + AccountID string + Credentials credentials.Credentials + + HTTPClient *http.Client + V4 *sigv4.Signer + V4A *Signer +} + +func (c *S3Client) CreateBucket(ctx context.Context, in *CreateBucketInput) (*CreateBucketOutput, error) { + var out CreateBucketOutput + endpoint := fmt.Sprintf("https://%s.s3.%s.amazonaws.com", in.Bucket, c.Region) + method := http.MethodPut + path := "/" + + sign := signV4(c.V4, c.Credentials, c.Region) + if err := c.do(ctx, method, endpoint, path, in, &out, sign); err != nil { + return nil, err + } + return &out, nil +} + +type CreateBucketInput struct { + Bucket string +} + +func (*CreateBucketInput) ToXML() []byte { + return []byte("") +} + +type CreateBucketOutput struct{} + +func (c *S3Client) DeleteBucket(ctx context.Context, in *DeleteBucketInput) (*DeleteBucketOutput, error) { + var out DeleteBucketOutput + endpoint := fmt.Sprintf("https://%s.s3.%s.amazonaws.com", in.Bucket, c.Region) + method := http.MethodDelete + path := "/" + + sign := signV4(c.V4, c.Credentials, c.Region) + if err := c.do(ctx, method, endpoint, path, in, &out, sign); err != nil { + return nil, err + } + return &out, nil +} + +type DeleteBucketInput struct { + Bucket string +} + +func (*DeleteBucketInput) ToXML() []byte { + return []byte("") +} + +type DeleteBucketOutput struct{} + +func (c *S3Client) PutObjectMRAP(ctx context.Context, in *PutObjectMRAPInput) (*PutObjectMRAPOutput, error) { + var out PutObjectMRAPOutput + endpoint := fmt.Sprintf("https://%s.accesspoint.s3-global.amazonaws.com", in.MRAPAlias) + method := http.MethodPut + path := "/" + in.Key + + sign := signV4A(c.V4A, c.Credentials, true) // unsigned payload + if err := c.do(ctx, method, endpoint, path, in, &out, sign); err != nil { + return nil, err + } + return &out, nil +} + +type PutObjectMRAPInput struct { + MRAPAlias string + Key string + + ObjectData string +} + +func (i *PutObjectMRAPInput) ToXML() []byte { + // not actually XML but good enough to get the object data into the request + // body + return []byte(i.ObjectData) +} + +type PutObjectMRAPOutput struct{} + +func (c *S3Client) DeleteObjectMRAP(ctx context.Context, in *DeleteObjectMRAPInput) (*DeleteObjectMRAPOutput, error) { + var out DeleteObjectMRAPOutput + endpoint := fmt.Sprintf("https://%s.accesspoint.s3-global.amazonaws.com", in.MRAPAlias) + method := http.MethodDelete + path := "/" + in.Key + + sign := signV4A(c.V4A, c.Credentials, false) + if err := c.do(ctx, method, endpoint, path, in, &out, sign); err != nil { + return nil, err + } + return &out, nil +} + +type DeleteObjectMRAPInput struct { + MRAPAlias string + Key string +} + +func (i *DeleteObjectMRAPInput) ToXML() []byte { + return []byte("") +} + +type DeleteObjectMRAPOutput struct{} + +func signV4(signer *sigv4.Signer, creds credentials.Credentials, region string) func(*http.Request) error { + return func(r *http.Request) error { + return signer.SignRequest(&sigv4.SignRequestInput{ + Request: r, + Credentials: creds, + Service: "s3", + Region: region, + }) + } +} + +func signV4A(signer *Signer, creds credentials.Credentials, isUnsignedPayload bool) func(*http.Request) error { + var payloadHash []byte + if isUnsignedPayload { + payloadHash = v4.UnsignedPayload() + } + return func(r *http.Request) error { + err := signer.SignRequest(&SignRequestInput{ + Request: r, + PayloadHash: payloadHash, + Credentials: creds, + Service: "s3", + RegionSet: []string{"*"}, + }) + + fmt.Println("signed request ------------------------------------------") + fmt.Printf("%s %s\n", r.Method, r.URL.EscapedPath()) + for h := range r.Header { + fmt.Printf("%s: %s\n", h, r.Header.Get(h)) + } + fmt.Println("---------------------------------------------------------") + + return err + } +} + +func (c *S3Client) do(ctx context.Context, method, endpoint, path string, in ToXML, out any, signRequest func(*http.Request) error) error { + // init + req, err := http.NewRequestWithContext(ctx, method, endpoint, http.NoBody) + if err != nil { + return fmt.Errorf("new http request: %w", err) + } + + // serialize + req.URL.Path = path + req.Header.Set("Content-Type", "application/xml") + payload := in.ToXML() + req.Body = closer{bytes.NewReader(payload)} + req.ContentLength = int64(len(payload)) + + // sign + err = signRequest(req) + if err != nil { + return fmt.Errorf("sign request: %w", err) + } + + // round-trip + resp, err := c.HTTPClient.Do(req) + if err != nil { + return fmt.Errorf("do request: %w", err) + } + defer resp.Body.Close() + + // deserialize + data, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("read response body: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("request error: %s: %s", resp.Status, data) + } + if len(data) == 0 { + return nil + } + if err := xml.Unmarshal(data, out); err != nil { + return fmt.Errorf("deserialize response: %w", err) + } + + return nil +} + +type S3ControlClient struct { + // s3control only does us-west-2 + // Region string + AccountID string + Credentials credentials.Credentials + + HTTPClient *http.Client + Signer *sigv4.Signer +} + +func (c *S3ControlClient) GetMRAP(ctx context.Context, in *GetMRAPInput) (*GetMRAPOutput, error) { + var out GetMRAPOutput + method := http.MethodGet + path := "/v20180820/mrap/instances/" + in.Name + if err := c.do(ctx, c.AccountID, method, path, in, &out); err != nil { + return nil, err + } + return &out, nil +} + +type GetMRAPInput struct { + Name string +} + +func (i *GetMRAPInput) ToXML() []byte { + return []byte("") +} + +type GetMRAPOutput struct { + AccessPoint struct { + Alias string + } +} + +func (c *S3ControlClient) CreateMRAP(ctx context.Context, in *CreateMRAPInput) (*CreateMRAPOutput, error) { + var out CreateMRAPOutput + method := http.MethodPost + path := "/v20180820/async-requests/mrap/create" + if err := c.do(ctx, c.AccountID, method, path, in, &out); err != nil { + return nil, err + } + return &out, nil +} + +type CreateMRAPInput struct { + Name string + Bucket string +} + +func (i *CreateMRAPInput) ToXML() []byte { + const tmpl = ` + + %s +
+ %s + + + %s + + +
+
` + + token := fmt.Sprintf("%d", rand.Int31()) + return []byte(fmt.Sprintf(tmpl, token, i.Name, i.Bucket)) +} + +type CreateMRAPOutput struct { + RequestToken string `xml:"RequestTokenARN"` +} + +func (c *S3ControlClient) DescribeMRAPOperation(ctx context.Context, in *DescribeMRAPOperationInput) (*DescribeMRAPOperationOutput, error) { + var out DescribeMRAPOperationOutput + method := http.MethodGet + path := "/v20180820/async-requests/mrap/" + in.RequestToken + if err := c.do(ctx, c.AccountID, method, path, in, &out); err != nil { + return nil, err + } + return &out, nil +} + +type DescribeMRAPOperationInput struct { + RequestToken string +} + +func (i *DescribeMRAPOperationInput) ToXML() []byte { + return []byte("") +} + +type DescribeMRAPOperationOutput struct { + AsyncOperation struct { + RequestStatus string + } +} + +func (c *S3ControlClient) DeleteMRAP(ctx context.Context, in *DeleteMRAPInput) (*DeleteMRAPOutput, error) { + var out DeleteMRAPOutput + method := http.MethodPost + path := "/v20180820/async-requests/mrap/delete" + if err := c.do(ctx, c.AccountID, method, path, in, &out); err != nil { + return nil, err + } + return &out, nil +} + +type DeleteMRAPInput struct { + Name string +} + +func (i *DeleteMRAPInput) ToXML() []byte { + const tmpl = ` + + %s +
+ %s +
+
` + + token := fmt.Sprintf("%d", rand.Int31()) + return []byte(fmt.Sprintf(tmpl, token, i.Name)) +} + +type DeleteMRAPOutput struct { + RequestToken string `xml:"RequestTokenARN"` +} + +func (c *S3ControlClient) do(ctx context.Context, accountID, method, path string, in ToXML, out any) error { + // init + endpoint := fmt.Sprintf("https://%s.s3-control.us-west-2.amazonaws.com", accountID) + req, err := http.NewRequestWithContext(ctx, method, endpoint, http.NoBody) + if err != nil { + return fmt.Errorf("new http request: %w", err) + } + + // serialize + req.URL.Path = path + req.Header.Set("Content-Type", "application/xml") + req.Header.Set("X-Amz-Account-Id", accountID) + payload := in.ToXML() + req.Body = closer{bytes.NewReader(payload)} + req.ContentLength = int64(len(payload)) + + // sign + err = c.Signer.SignRequest(&sigv4.SignRequestInput{ + Request: req, + Credentials: c.Credentials, + Service: "s3", + Region: "us-west-2", + }) + if err != nil { + return fmt.Errorf("sign request: %w", err) + } + + // round-trip + resp, err := c.HTTPClient.Do(req) + if err != nil { + return fmt.Errorf("do request: %w", err) + } + defer resp.Body.Close() + + // deserialize + data, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("read response body: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("request error: %s: %s", resp.Status, data) + } + if len(data) == 0 { + return nil + } + if err := xml.Unmarshal(data, out); err != nil { + return fmt.Errorf("deserialize response: %w", err) + } + + return nil +} + +// WARNING: this test takes a while, because creating an MRAP is asynchronous +// and slow +// +// 1. creates a bucket in us-east-1 +// 2. creates an MRAP that points to that bucket +// 3. polls MRAP status until created +// 4. puts object to MRAP +// 5. deletes object +// 6. deletes MRAP +// 7. deletes bucket +func TestE2E_S3MRAP(t *testing.T) { + svc := &S3Client{ + Region: "us-east-1", + AccountID: os.Getenv("AWS_ACCOUNT_ID"), + Credentials: credentials.Credentials{ + AccessKeyID: os.Getenv("AWS_ACCESS_KEY_ID"), + SecretAccessKey: os.Getenv("AWS_SECRET_ACCESS_KEY"), + SessionToken: os.Getenv("AWS_SESSION_TOKEN"), + }, + HTTPClient: http.DefaultClient, + V4: sigv4.New(func(o *v4.SignerOptions) { + o.DisableDoublePathEscape = true + o.AddPayloadHashHeader = true + }), + V4A: New(func(o *v4.SignerOptions) { + o.DisableDoublePathEscape = true + o.AddPayloadHashHeader = true + }), + } + controlsvc := &S3ControlClient{ + AccountID: os.Getenv("AWS_ACCOUNT_ID"), + Credentials: credentials.Credentials{ + AccessKeyID: os.Getenv("AWS_ACCESS_KEY_ID"), + SecretAccessKey: os.Getenv("AWS_SECRET_ACCESS_KEY"), + SessionToken: os.Getenv("AWS_SESSION_TOKEN"), + }, + HTTPClient: http.DefaultClient, + Signer: sigv4.New(func(o *v4.SignerOptions) { + o.AddPayloadHashHeader = true + }), + } + + testid := rand.Int() % (2 << 15) + bucket := fmt.Sprintf("aws-http-auth-e2etest-bucket-%d", testid) + mrap := fmt.Sprintf("aws-http-auth-e2etest-mrap-%d", testid) + + _, err := svc.CreateBucket(context.Background(), &CreateBucketInput{ + Bucket: bucket, + }) + if err != nil { + t.Fatalf("create bucket: %v", err) + } + + t.Logf("created test bucket: %s", bucket) + + createMRAPOutput, err := controlsvc.CreateMRAP(context.Background(), &CreateMRAPInput{ + Name: mrap, + Bucket: bucket, + }) + if err != nil { + t.Fatalf("create mrap: %v", err) + } + t.Logf("started mrap create... token %s", createMRAPOutput.RequestToken) + awaitS3ControlOperation(t, context.Background(), controlsvc, createMRAPOutput.RequestToken) + + t.Logf("created test mrap: %s", mrap) + + getMRAPOutput, err := controlsvc.GetMRAP(context.Background(), &GetMRAPInput{ + Name: mrap, + }) + if err != nil { + t.Fatalf("get mrap info: %v", err) + } + + t.Logf("retrieved mrap alias: %s", getMRAPOutput.AccessPoint.Alias) + + _, err = svc.PutObjectMRAP(context.Background(), &PutObjectMRAPInput{ + MRAPAlias: getMRAPOutput.AccessPoint.Alias, + Key: "path1 / path2", // verify single-encode behavior + ObjectData: mrap, + }) + if err != nil { + t.Fatalf("put object mrap: %v", err) + } + + _, err = svc.DeleteObjectMRAP(context.Background(), &DeleteObjectMRAPInput{ + MRAPAlias: getMRAPOutput.AccessPoint.Alias, + Key: "path1 / path2", + }) + if err != nil { + t.Fatalf("delete object mrap: %v", err) + } + + deleteMRAPOutput, err := controlsvc.DeleteMRAP(context.Background(), &DeleteMRAPInput{ + Name: mrap, + }) + if err != nil { + t.Fatalf("delete mrap: %v", err) + } + t.Logf("started mrap delete... token %s", deleteMRAPOutput.RequestToken) + awaitS3ControlOperation(t, context.Background(), controlsvc, deleteMRAPOutput.RequestToken) + + _, err = svc.DeleteBucket(context.Background(), &DeleteBucketInput{ + Bucket: bucket, + }) + if err != nil { + t.Fatalf("delete bucket: %v", err) + } + + t.Logf("deleted test bucket: %s", bucket) +} + +func awaitS3ControlOperation(t *testing.T, ctx context.Context, svc *S3ControlClient, requestToken string) { + t.Helper() + + start := time.Now() + for { + out, err := svc.DescribeMRAPOperation(ctx, &DescribeMRAPOperationInput{ + RequestToken: requestToken, + }) + if err != nil { + t.Fatalf("describe mrap operation: %v", err) + } + + t.Logf("poll status: %s\n", out.AsyncOperation.RequestStatus) + time.Sleep(5 * time.Second) + + // S3Control does not document the values for this field. + // Anecdotally: + // - returns NEW a few seconds after the operation starts + // - returns INPROGRESS until complete + // - returns SUCCEEDED when complete + if out.AsyncOperation.RequestStatus == "SUCCEEDED" { + break + } + } + t.Logf("operation completed after %v", time.Now().Sub(start)) +} diff --git a/aws-http-auth/sigv4a/sigv4a.go b/aws-http-auth/sigv4a/sigv4a.go new file mode 100644 index 00000000..f78aab4c --- /dev/null +++ b/aws-http-auth/sigv4a/sigv4a.go @@ -0,0 +1,187 @@ +// Package sigv4a implements request signing for AWS Signature Version 4a +// (asymmetric). +// +// The algorithm for Signature Version 4a is identical to that of plain v4 +// apart from the following: +// - A request can be signed for multiple regions. This is represented in the +// signature using the X-Amz-Region-Set header. The credential scope string +// used in the calculation correspondingly lacks the region component from +// that of plain v4. +// - The string-to-sign component of the calculation is instead signed with +// an ECDSA private key. This private key is typically derived from your +// regular AWS credentials. +package sigv4a + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rand" + "encoding/hex" + "net/http" + "strings" + "time" + + "github.com/aws/smithy-go/aws-http-auth/credentials" + v4internal "github.com/aws/smithy-go/aws-http-auth/internal/v4" + v4 "github.com/aws/smithy-go/aws-http-auth/v4" +) + +const algorithm = "AWS4-ECDSA-P256-SHA256" + +// Signer signs requests with AWS Signature Version 4a. +// +// Unlike Sigv4, AWS SigV4a signs requests with an ECDSA private key. This is +// derived automatically from the AWS credential identity passed to +// SignRequest. This derivation result is cached on the Signer and is uniquely +// identified by the access key ID (AKID) of the credentials that were +// provided. +// +// Because of this, the caller is encouraged to create multiple instances of +// Signer for different underlying identities (e.g. IAM roles). +type Signer struct { + options v4.SignerOptions + + // derived asymmetric credentials + privCache *ecdsaCache +} + +// New returns an instance of Signer with applied options. +func New(opts ...v4.SignerOption) *Signer { + options := v4.SignerOptions{} + + for _, opt := range opts { + opt(&options) + } + + return &Signer{ + options: options, + privCache: &ecdsaCache{}, + } +} + +// SignRequestInput is the set of inputs for the Sigv4a signing process. +type SignRequestInput struct { + // The input request, which will modified in-place during signing. + Request *http.Request + + // The SHA256 hash of the input request body. + // + // This value is NOT required to sign the request, but it is recommended to + // provide it (or provide a Body on the HTTP request that implements + // io.Seeker such that the signer can calculate it for you). Many services + // do not accept requests with unsigned payloads. + // + // If a value is not provided, and DisableImplicitPayloadHashing has not + // been set on SignerOptions, the signer will attempt to derive the payload + // hash itself. The request's Body MUST implement io.Seeker in order to do + // this, if it does not, the magic value for unsigned payload is used. If + // the body does implement io.Seeker, but a call to Seek returns an error, + // the signer will forward that error. + PayloadHash []byte + + // The identity used to sign the request. + Credentials credentials.Credentials + + // The service for which this request is to be signed. + // + // The appropriate value for this field is determined by the service + // vendor. + Service string + + // The set of regions for which this request is to be signed. + // + // The sentinel {"*"} is used to indicate that the signed request is valid + // in all regions. Callers MUST set a value for this field - the API will + // not fill in a default and the resulting signature will ultimately be + // invalid. + // + // The acceptable values for list entries of this field are determined by + // the service vendor. + RegionSet []string + + // Wall-clock time used for calculating the signature. + // + // If the zero-value is given (generally by the caller not setting it), the + // signer will instead use the current system clock time for the signature. + Time time.Time +} + +// SignRequest signs an HTTP request with AWS Signature Version 4, modifying +// the request in-place by adding the headers that constitute the signature. +// +// SignRequest will modify the request by setting the following headers: +// - Host: required in general for HTTP/1.1 as well as for v4-signed requests +// - X-Amz-Date: required for v4a-signed requests +// - X-Amz-Region-Set: used to convey the regions for which the request is +// signed in v4a +// - X-Amz-Security-Token: required for v4a-signed requests IF present on +// credentials used to sign, otherwise this header will not be set +// - Authorization: contains the v4a signature string +// +// The request MUST have a Host value set at the time that this API is called, +// such that it can be included in the signature calculation. Standard library +// HTTP clients set this as a request header by default, meaning that a request +// signed without a Host value will end up transmitting with the Host header +// anyway, which will cause the request to be rejected by the service due to +// signature mismatch (the Host header is required to be signed with Sigv4). +// +// Generally speaking, using http.NewRequest will ensure that request instances +// are sufficiently initialized to be used with this API, though it is not +// strictly required. +// +// SignRequest may be called any number of times on an http.Request instance, +// the header values set as part of the signature will simply be re-calculated. +// Note that v4a signatures are non-deterministic due to the random component +// of ECDSA signing, callers should not expect two calls to SignRequest() to +// produce an identical signature. +func (s *Signer) SignRequest(in *SignRequestInput, opts ...v4.SignerOption) error { + options := s.options + for _, fn := range opts { + fn(&options) + } + + priv, err := s.privCache.Derive(in.Credentials) + if err != nil { + return err + } + + in.Request.Header.Set("X-Amz-Region-Set", strings.Join(in.RegionSet, ",")) + + tm := v4internal.ResolveTime(in.Time) + signer := &v4internal.Signer{ + Request: in.Request, + PayloadHash: in.PayloadHash, + Time: tm, + Credentials: in.Credentials, + Options: options, + + Algorithm: algorithm, + CredentialScope: scope(tm, in.Service), + Finalizer: &finalizer{priv}, + } + if err := signer.Do(); err != nil { + return err + } + + return nil +} + +func scope(tm time.Time, service string) string { + return strings.Join([]string{ + tm.Format(v4internal.ShortTimeFormat), + service, + "aws4_request", + }, "/") +} + +type finalizer struct { + Secret *ecdsa.PrivateKey +} + +func (f *finalizer) SignString(strToSign string) (string, error) { + sig, err := f.Secret.Sign(rand.Reader, v4internal.Stosha(strToSign), crypto.SHA256) + if err != nil { + return "", err + } + return hex.EncodeToString(sig), nil +} diff --git a/aws-http-auth/sigv4a/sigv4a_test.go b/aws-http-auth/sigv4a/sigv4a_test.go new file mode 100644 index 00000000..fc545aae --- /dev/null +++ b/aws-http-auth/sigv4a/sigv4a_test.go @@ -0,0 +1,406 @@ +package sigv4a + +import ( + "crypto/ecdsa" + "crypto/rand" + "encoding/asn1" + "encoding/hex" + "fmt" + "io" + "math/big" + "net/http" + "strings" + "testing" + "time" + + "github.com/aws/smithy-go/aws-http-auth/credentials" + v4internal "github.com/aws/smithy-go/aws-http-auth/internal/v4" + v4 "github.com/aws/smithy-go/aws-http-auth/v4" +) + +const ( + accessKey = "AKISORANDOMAASORANDOM" + secretKey = "q+jcrXGc+0zWN6uzclKVhvMmUsIfRPa4rlRandom" + sessionToken = "TOKEN" +) + +type signAll struct{} + +func (signAll) IsSigned(string) bool { return true } + +type ecdsaSignature struct { + R, S *big.Int +} + +var credsSession = credentials.Credentials{ + AccessKeyID: "AKID", + SecretAccessKey: "SECRET", + SessionToken: "SESSION", +} + +var credsNoSession = credentials.Credentials{ + AccessKeyID: "AKID", + SecretAccessKey: "SECRET", +} + +func seekable(v string) io.ReadSeekCloser { + return readseekcloser{strings.NewReader(v)} +} + +func nonseekable(v string) io.ReadCloser { + return io.NopCloser(strings.NewReader(v)) // io.NopCloser elides Seek() +} + +type readseekcloser struct { + io.ReadSeeker +} + +func (readseekcloser) Close() error { return nil } + +func verifySignature(key *ecdsa.PublicKey, hash []byte, signature []byte) (bool, error) { + var sig ecdsaSignature + + _, err := asn1.Unmarshal(signature, &sig) + if err != nil { + return false, err + } + + return ecdsa.Verify(key, hash, sig.R, sig.S), nil +} + +func TestDeriveECDSAKeyPairFromSecret(t *testing.T) { + t.Skip() + creds := credentials.Credentials{ + AccessKeyID: accessKey, + SecretAccessKey: secretKey, + } + privateKey, err := derivePrivateKey(creds) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + expectedX := func() *big.Int { + t.Helper() + b, ok := new(big.Int).SetString("15D242CEEBF8D8169FD6A8B5A746C41140414C3B07579038DA06AF89190FFFCB", 16) + if !ok { + t.Fatalf("failed to parse big integer") + } + return b + }() + expectedY := func() *big.Int { + t.Helper() + b, ok := new(big.Int).SetString("515242CEDD82E94799482E4C0514B505AFCCF2C0C98D6A553BF539F424C5EC0", 16) + if !ok { + t.Fatalf("failed to parse big integer") + } + return b + }() + + if privateKey.X.Cmp(expectedX) != 0 { + t.Errorf("expected % X, got % X", expectedX, privateKey.X) + } + if privateKey.Y.Cmp(expectedY) != 0 { + t.Errorf("expected % X, got % X", expectedY, privateKey.Y) + } +} + +func newRequest(body io.ReadCloser, opts ...func(*http.Request)) *http.Request { + req, err := http.NewRequest(http.MethodPost, "https://service.region.amazonaws.com", body) + if err != nil { + panic(err) + } + + for _, opt := range opts { + opt(req) + } + return req +} + +func TestSignRequest(t *testing.T) { + for name, tt := range map[string]struct { + Input *SignRequestInput + Opts v4.SignerOption + ExpectPreamble string + ExpectSignedHeaders string + ExpectStringToSign string + ExpectDate string + ExpectToken string + ExpectRegionSetHeader string + }{ + "minimal case, nonseekable": { + Input: &SignRequestInput{ + Request: newRequest(nonseekable("{}")), + Credentials: credsSession, + Service: "dynamodb", + RegionSet: []string{"us-east-1", "us-west-1"}, + Time: time.Unix(0, 0), + }, + ExpectPreamble: "AWS4-ECDSA-P256-SHA256 Credential=AKID/19700101/dynamodb/aws4_request", + ExpectSignedHeaders: "SignedHeaders=host;x-amz-date;x-amz-region-set;x-amz-security-token", + ExpectStringToSign: `AWS4-ECDSA-P256-SHA256 +19700101T000000Z +19700101/dynamodb/aws4_request +968265b4e87c6b10c8ec6bcfd63e8002814cb3256d74c6c381f0c31268c80b53`, + ExpectDate: "19700101T000000Z", + ExpectToken: "SESSION", + ExpectRegionSetHeader: "us-east-1,us-west-1", + }, + "minimal case, seekable": { + Input: &SignRequestInput{ + Request: newRequest(seekable("{}")), + Credentials: credsSession, + Service: "dynamodb", + RegionSet: []string{"us-east-1"}, + Time: time.Unix(0, 0), + }, + ExpectPreamble: "AWS4-ECDSA-P256-SHA256 Credential=AKID/19700101/dynamodb/aws4_request", + ExpectSignedHeaders: "SignedHeaders=host;x-amz-date;x-amz-region-set;x-amz-security-token", + ExpectStringToSign: `AWS4-ECDSA-P256-SHA256 +19700101T000000Z +19700101/dynamodb/aws4_request +6fbe2f6247e506a47694e695d825477af6c604184f775050ce3b83e04674d9aa`, + ExpectDate: "19700101T000000Z", + ExpectToken: "SESSION", + ExpectRegionSetHeader: "us-east-1", + }, + "minimal case, no session": { + Input: &SignRequestInput{ + Request: newRequest(nonseekable("{}")), + Credentials: credsNoSession, + Service: "dynamodb", + RegionSet: []string{"us-east-1"}, + Time: time.Unix(0, 0), + }, + ExpectPreamble: "AWS4-ECDSA-P256-SHA256 Credential=AKID/19700101/dynamodb/aws4_request", + ExpectSignedHeaders: "SignedHeaders=host;x-amz-date;x-amz-region-set", + ExpectStringToSign: `AWS4-ECDSA-P256-SHA256 +19700101T000000Z +19700101/dynamodb/aws4_request +825ea1f5e80bdb91ac8802e832504d1ff1c3b05b7619ffc273a1565a7600ff5a`, + ExpectDate: "19700101T000000Z", + ExpectToken: "", + ExpectRegionSetHeader: "us-east-1", + }, + "explicit unsigned payload": { + Input: &SignRequestInput{ + Request: newRequest(seekable("{}")), + PayloadHash: v4.UnsignedPayload(), + Credentials: credsSession, + Service: "dynamodb", + RegionSet: []string{"us-east-1"}, + Time: time.Unix(0, 0), + }, + ExpectPreamble: "AWS4-ECDSA-P256-SHA256 Credential=AKID/19700101/dynamodb/aws4_request", + ExpectSignedHeaders: "SignedHeaders=host;x-amz-date;x-amz-region-set;x-amz-security-token", + ExpectStringToSign: `AWS4-ECDSA-P256-SHA256 +19700101T000000Z +19700101/dynamodb/aws4_request +69e5041f5ff858ee8f53a30e5f98cdb4c6bcbfe0f8e61b8aba537d2713bf41a4`, + ExpectDate: "19700101T000000Z", + ExpectToken: "SESSION", + ExpectRegionSetHeader: "us-east-1", + }, + "explicit payload hash": { + Input: &SignRequestInput{ + Request: newRequest(seekable("{}")), + PayloadHash: v4internal.Stosha("{}"), + Credentials: credsSession, + Service: "dynamodb", + RegionSet: []string{"us-east-1"}, + Time: time.Unix(0, 0), + }, + ExpectPreamble: "AWS4-ECDSA-P256-SHA256 Credential=AKID/19700101/dynamodb/aws4_request", + ExpectSignedHeaders: "SignedHeaders=host;x-amz-date;x-amz-region-set;x-amz-security-token", + ExpectStringToSign: `AWS4-ECDSA-P256-SHA256 +19700101T000000Z +19700101/dynamodb/aws4_request +6fbe2f6247e506a47694e695d825477af6c604184f775050ce3b83e04674d9aa`, + ExpectDate: "19700101T000000Z", + ExpectToken: "SESSION", + ExpectRegionSetHeader: "us-east-1", + }, + "sign all headers": { + Input: &SignRequestInput{ + Request: newRequest(seekable("{}"), func(r *http.Request) { + r.Header.Set("Content-Type", "application/json") + r.Header.Set("Foo", "bar") + r.Header.Set("Bar", "baz") + }), + PayloadHash: v4internal.Stosha("{}"), + Credentials: credsSession, + Service: "dynamodb", + RegionSet: []string{"us-east-1"}, + Time: time.Unix(0, 0), + }, + Opts: func(o *v4.SignerOptions) { + o.HeaderRules = signAll{} + }, + ExpectPreamble: "AWS4-ECDSA-P256-SHA256 Credential=AKID/19700101/dynamodb/aws4_request", + ExpectSignedHeaders: "SignedHeaders=bar;content-type;foo;host;x-amz-date;x-amz-region-set;x-amz-security-token", + ExpectStringToSign: `AWS4-ECDSA-P256-SHA256 +19700101T000000Z +19700101/dynamodb/aws4_request +b28cca9faeaa86f4dbfcc3113b05b38f53cd41f41448a41e08e0171cea8ec363`, + ExpectDate: "19700101T000000Z", + ExpectToken: "SESSION", + ExpectRegionSetHeader: "us-east-1", + }, + "disable implicit payload hash": { + Input: &SignRequestInput{ + Request: newRequest(seekable("{}")), + Credentials: credsSession, + Service: "dynamodb", + RegionSet: []string{"us-east-1"}, + Time: time.Unix(0, 0), + }, + Opts: func(o *v4.SignerOptions) { + o.DisableImplicitPayloadHashing = true + }, + ExpectPreamble: "AWS4-ECDSA-P256-SHA256 Credential=AKID/19700101/dynamodb/aws4_request", + ExpectSignedHeaders: "SignedHeaders=host;x-amz-date;x-amz-region-set;x-amz-security-token", + ExpectStringToSign: `AWS4-ECDSA-P256-SHA256 +19700101T000000Z +19700101/dynamodb/aws4_request +69e5041f5ff858ee8f53a30e5f98cdb4c6bcbfe0f8e61b8aba537d2713bf41a4`, + ExpectDate: "19700101T000000Z", + ExpectToken: "SESSION", + ExpectRegionSetHeader: "us-east-1", + }, + "s3 settings": { + Input: &SignRequestInput{ + Request: newRequest(seekable("{}")), + Credentials: credsSession, + Service: "s3", + RegionSet: []string{"us-east-1"}, + Time: time.Unix(0, 0), + }, + Opts: func(o *v4.SignerOptions) { + o.DisableDoublePathEscape = true + o.AddPayloadHashHeader = true + }, + ExpectPreamble: "AWS4-ECDSA-P256-SHA256 Credential=AKID/19700101/s3/aws4_request", + ExpectSignedHeaders: "SignedHeaders=host;x-amz-content-sha256;x-amz-date;x-amz-region-set;x-amz-security-token", + ExpectStringToSign: `AWS4-ECDSA-P256-SHA256 +19700101T000000Z +19700101/s3/aws4_request +3cf4ba7f150421e93dbc22112484147af6e355676d08a75799cfd32424458d7f`, + ExpectDate: "19700101T000000Z", + ExpectToken: "SESSION", + ExpectRegionSetHeader: "us-east-1", + }, + } { + t.Run(name, func(t *testing.T) { + opt := tt.Opts + if opt == nil { + opt = func(o *v4.SignerOptions) {} + } + signer := New(opt) + if err := signer.SignRequest(tt.Input); err != nil { + t.Fatalf("expect no err, got %v", err) + } + + req := tt.Input.Request + expectSignature(t, req, tt.Input.Credentials, + tt.ExpectPreamble, tt.ExpectSignedHeaders, tt.ExpectStringToSign, + tt.ExpectDate, tt.ExpectToken, tt.ExpectRegionSetHeader) + if host := req.Header.Get("Host"); req.Host != host { + t.Errorf("expect host header: %s != %s", req.Host, host) + } + }) + } +} + +// v4a signatures are random because ECDSA itself is random +// to verify the signature, we effectively have to formally verify the other +// side of the ECDSA calculation +// +// note that this implicitly verifies the correctness of derivePrivateKey, +// otherwise signatures wouldn't match +func expectSignature( + t *testing.T, signed *http.Request, creds credentials.Credentials, + expectPreamble, expectSignedHeaders string, // fixed header components + expectStrToSign string, // for manual signature verification + expectDate, expectToken, expectRegionSet string, // fixed headers +) { + t.Helper() + + preamble, signedHeaders, signature, err := getSignature(signed) + if err != nil { + t.Fatalf("get signature: %v", err) + } + + if expectPreamble != preamble { + t.Errorf("preamble:\n%s\n!=\n%s", expectPreamble, preamble) + } + if signedHeaders != expectSignedHeaders { + t.Errorf("signed headers:\n%s\n!=\n%s", expectSignedHeaders, signedHeaders) + } + + priv, err := derivePrivateKey(creds) + if err != nil { + t.Fatalf("derive private key: %v", err) + } + + ok, err := verifySignature(&priv.PublicKey, v4internal.Stosha(expectStrToSign), signature) + if err != nil { + t.Fatalf("verify signature: %v", err) + } + if !ok { + t.Errorf("signature mismatch") + } +} + +func getSignature(r *http.Request) ( + preamble, headers string, signature []byte, err error, +) { + auth := r.Header.Get("Authorization") + if len(auth) == 0 { + err = fmt.Errorf("no authorization header") + return + } + + parts := strings.Split(auth, ", ") + if len(parts) != 3 { + err = fmt.Errorf("auth header is malformed: %q", auth) + return + } + + sigpart := parts[2] + sigparts := strings.Split(sigpart, "=") + if len(sigparts) != 2 { + err = fmt.Errorf("Signature= component is malformed: %s", sigpart) + return + } + + sig, err := hex.DecodeString(sigparts[1]) + if err != nil { + err = fmt.Errorf("decode signature hex: %w", err) + return + } + + return parts[0], parts[1], sig, nil +} + +type readexploder struct{} + +func (readexploder) Read([]byte) (int, error) { + return 0, fmt.Errorf("readexploder boom") +} + +func TestSignRequest_SignStringError(t *testing.T) { + randReader := rand.Reader + rand.Reader = readexploder{} + defer func() { rand.Reader = randReader }() + s := New() + + err := s.SignRequest(&SignRequestInput{ + Request: newRequest(http.NoBody), + PayloadHash: v4.UnsignedPayload(), + }) + if err == nil { + t.Fatal("expect error but didn't get one") + } + if expect := "readexploder boom"; expect != err.Error() { + t.Errorf("error mismatch: %v != %v", expect, err.Error()) + } +} diff --git a/aws-http-auth/v4/v4.go b/aws-http-auth/v4/v4.go new file mode 100644 index 00000000..15df1ba3 --- /dev/null +++ b/aws-http-auth/v4/v4.go @@ -0,0 +1,44 @@ +// Package v4 exposes common APIs for AWS Signature Version 4. +package v4 + +// SignerOption applies configuration to a signer. +type SignerOption func(*SignerOptions) + +// SignerOptions configures SigV4. +type SignerOptions struct { + // Rules to determine what headers are signed. + // + // By default, the signer will only include the minimum required headers: + // - Host + // - X-Amz-* + HeaderRules SignedHeaderRules + + // Setting this flag will instead cause the signer to use the + // UNSIGNED-PAYLOAD sentinel if a hash is not explicitly provided. + DisableImplicitPayloadHashing bool + + // Disables the automatic escaping of the URI path of the request for the + // siganture's canonical string's path. + // + // Amazon S3 is an example of a service that requires this setting. + DisableDoublePathEscape bool + + // Adds the X-Amz-Content-Sha256 header to signed requests. + // + // Amazon S3 is an example of a service that requires this setting. + AddPayloadHashHeader bool +} + +// SignedHeaderRules determines whether a request header should be included in +// the calculated signature. +// +// By convention, ShouldSign is invoked with lowercase values. +type SignedHeaderRules interface { + IsSigned(string) bool +} + +// UnsignedPayload provides the sentinel value for a payload hash to indicate +// that a request's payload is unsigned. +func UnsignedPayload() []byte { + return []byte("UNSIGNED-PAYLOAD") +}