diff --git a/aws/signer/v4/stream.go b/aws/signer/v4/stream.go new file mode 100644 index 00000000000..02cbd97e234 --- /dev/null +++ b/aws/signer/v4/stream.go @@ -0,0 +1,63 @@ +package v4 + +import ( + "encoding/hex" + "strings" + "time" + + "github.com/aws/aws-sdk-go/aws/credentials" +) + +type credentialValueProvider interface { + Get() (credentials.Value, error) +} + +// StreamSigner implements signing of event stream encoded payloads +type StreamSigner struct { + region string + service string + + credentials credentialValueProvider + + prevSig []byte +} + +// NewStreamSigner creates a SigV4 signer used to sign Event Stream encoded messages +func NewStreamSigner(region, service string, seedSignature []byte, credentials *credentials.Credentials) *StreamSigner { + return &StreamSigner{ + region: region, + service: service, + credentials: credentials, + prevSig: seedSignature, + } +} + +// GetSignature takes an event stream encoded headers and payload and returns a signature +func (s *StreamSigner) GetSignature(headers, payload []byte, date time.Time) ([]byte, error) { + credValue, err := s.credentials.Get() + if err != nil { + return nil, err + } + + sigKey := deriveSigningKey(s.region, s.service, credValue.SecretAccessKey, date) + + keyPath := buildSigningScope(s.region, s.service, date) + + stringToSign := buildEventStreamStringToSign(headers, payload, s.prevSig, keyPath, date) + + signature := hmacSHA256(sigKey, []byte(stringToSign)) + s.prevSig = signature + + return signature, nil +} + +func buildEventStreamStringToSign(headers, payload, prevSig []byte, scope string, date time.Time) string { + return strings.Join([]string{ + "AWS4-HMAC-SHA256-PAYLOAD", + formatTime(date), + scope, + hex.EncodeToString(prevSig), + hex.EncodeToString(hashSHA256(headers)), + hex.EncodeToString(hashSHA256(payload)), + }, "\n") +} diff --git a/aws/signer/v4/stream_test.go b/aws/signer/v4/stream_test.go new file mode 100644 index 00000000000..1e974db52fb --- /dev/null +++ b/aws/signer/v4/stream_test.go @@ -0,0 +1,133 @@ +// +build go1.7 + +package v4 + +import ( + "encoding/hex" + "fmt" + "strings" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws/credentials" +) + +type periodicBadCredentials struct { + call int + credentials *credentials.Credentials +} + +func (p *periodicBadCredentials) Get() (credentials.Value, error) { + defer func() { + p.call++ + }() + + if p.call%2 == 0 { + return credentials.Value{}, fmt.Errorf("credentials error") + } + + return p.credentials.Get() +} + +type chunk struct { + headers, payload []byte +} + +func mustDecodeHex(b []byte, err error) []byte { + if err != nil { + panic(err) + } + + return b +} + +func TestStreamingChunkSigner(t *testing.T) { + const ( + region = "us-east-1" + service = "transcribe" + seedSignature = "9d9ab996c81f32c9d4e6fc166c92584f3741d1cb5ce325cd11a77d1f962c8de2" + ) + + staticCredentials := credentials.NewStaticCredentials("AKIDEXAMPLE", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", "") + currentTime := time.Date(2019, 1, 27, 22, 37, 54, 0, time.UTC) + + cases := map[string]struct { + credentials credentialValueProvider + chunks []chunk + expectedSignatures map[int]string + expectedErrors map[int]string + }{ + "signature calculation": { + credentials: staticCredentials, + chunks: []chunk{ + {headers: []byte("headers"), payload: []byte("payload")}, + {headers: []byte("more headers"), payload: []byte("more payload")}, + }, + expectedSignatures: map[int]string{ + 0: "681a7eaa82891536f24af7ec7e9219ee251ccd9bac2f1b981eab7c5ec8579115", + 1: "07633d9d4ab4d81634a2164934d1f648c7cbc6839a8cf0773d818127a267e4d6", + }, + }, + "signature calculation errors": { + credentials: &periodicBadCredentials{credentials: staticCredentials}, + chunks: []chunk{ + {headers: []byte("headers"), payload: []byte("payload")}, + {headers: []byte("headers"), payload: []byte("payload")}, + {headers: []byte("more headers"), payload: []byte("more payload")}, + {headers: []byte("more headers"), payload: []byte("more payload")}, + }, + expectedSignatures: map[int]string{ + 1: "681a7eaa82891536f24af7ec7e9219ee251ccd9bac2f1b981eab7c5ec8579115", + 3: "07633d9d4ab4d81634a2164934d1f648c7cbc6839a8cf0773d818127a267e4d6", + }, + expectedErrors: map[int]string{ + 0: "credentials error", + 2: "credentials error", + }, + }, + } + + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + chunkSigner := &StreamSigner{ + region: region, + service: service, + credentials: tt.credentials, + prevSig: mustDecodeHex(hex.DecodeString(seedSignature)), + } + + for i, chunk := range tt.chunks { + var expectedError string + if len(tt.expectedErrors) != 0 { + _, ok := tt.expectedErrors[i] + if ok { + expectedError = tt.expectedErrors[i] + } + } + + signature, err := chunkSigner.GetSignature(chunk.headers, chunk.payload, currentTime) + if err == nil && len(expectedError) > 0 { + t.Errorf("expected error, but got nil") + continue + } else if err != nil && len(expectedError) == 0 { + t.Errorf("expected no error, but got %v", err) + continue + } else if err != nil && len(expectedError) > 0 && !strings.Contains(err.Error(), expectedError) { + t.Errorf("expected %v, but got %v", expectedError, err) + continue + } else if len(expectedError) > 0 { + continue + } + + expectedSignature, ok := tt.expectedSignatures[i] + if !ok { + t.Fatalf("expected signature not provided for test case") + } + + if e, a := expectedSignature, hex.EncodeToString(signature); e != a { + t.Errorf("expected %v, got %v", e, a) + } + } + }) + } +} diff --git a/aws/signer/v4/v4.go b/aws/signer/v4/v4.go index 8104793aa5b..3d052c0f923 100644 --- a/aws/signer/v4/v4.go +++ b/aws/signer/v4/v4.go @@ -79,6 +79,7 @@ const ( authHeaderPrefix = "AWS4-HMAC-SHA256" timeFormat = "20060102T150405Z" shortTimeFormat = "20060102" + awsV4Request = "aws4_request" // emptyStringSHA256 is a SHA256 of an empty string emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855` @@ -229,11 +230,9 @@ type signingCtx struct { DisableURIPathEscaping bool - credValues credentials.Value - isPresign bool - formattedTime string - formattedShortTime string - unsignedPayload bool + credValues credentials.Value + isPresign bool + unsignedPayload bool bodyDigest string signedHeaders string @@ -546,25 +545,17 @@ func (ctx *signingCtx) build(disableHeaderHoisting bool) error { } func (ctx *signingCtx) buildTime() { - ctx.formattedTime = ctx.Time.UTC().Format(timeFormat) - ctx.formattedShortTime = ctx.Time.UTC().Format(shortTimeFormat) - if ctx.isPresign { duration := int64(ctx.ExpireTime / time.Second) - ctx.Query.Set("X-Amz-Date", ctx.formattedTime) + ctx.Query.Set("X-Amz-Date", formatTime(ctx.Time)) ctx.Query.Set("X-Amz-Expires", strconv.FormatInt(duration, 10)) } else { - ctx.Request.Header.Set("X-Amz-Date", ctx.formattedTime) + ctx.Request.Header.Set("X-Amz-Date", formatTime(ctx.Time)) } } func (ctx *signingCtx) buildCredentialString() { - ctx.credentialString = strings.Join([]string{ - ctx.formattedShortTime, - ctx.Region, - ctx.ServiceName, - "aws4_request", - }, "/") + ctx.credentialString = buildSigningScope(ctx.Region, ctx.ServiceName, ctx.Time) if ctx.isPresign { ctx.Query.Set("X-Amz-Credential", ctx.credValues.AccessKeyID+"/"+ctx.credentialString) @@ -653,19 +644,15 @@ func (ctx *signingCtx) buildCanonicalString() { func (ctx *signingCtx) buildStringToSign() { ctx.stringToSign = strings.Join([]string{ authHeaderPrefix, - ctx.formattedTime, + formatTime(ctx.Time), ctx.credentialString, - hex.EncodeToString(makeSha256([]byte(ctx.canonicalString))), + hex.EncodeToString(hashSHA256([]byte(ctx.canonicalString))), }, "\n") } func (ctx *signingCtx) buildSignature() { - secret := ctx.credValues.SecretAccessKey - date := makeHmac([]byte("AWS4"+secret), []byte(ctx.formattedShortTime)) - region := makeHmac(date, []byte(ctx.Region)) - service := makeHmac(region, []byte(ctx.ServiceName)) - credentials := makeHmac(service, []byte("aws4_request")) - signature := makeHmac(credentials, []byte(ctx.stringToSign)) + creds := deriveSigningKey(ctx.Region, ctx.ServiceName, ctx.credValues.SecretAccessKey, ctx.Time) + signature := hmacSHA256(creds, []byte(ctx.stringToSign)) ctx.signature = hex.EncodeToString(signature) } @@ -726,13 +713,13 @@ func (ctx *signingCtx) removePresign() { ctx.Query.Del("X-Amz-SignedHeaders") } -func makeHmac(key []byte, data []byte) []byte { +func hmacSHA256(key []byte, data []byte) []byte { hash := hmac.New(sha256.New, key) hash.Write(data) return hash.Sum(nil) } -func makeSha256(data []byte) []byte { +func hashSHA256(data []byte) []byte { hash := sha256.New() hash.Write(data) return hash.Sum(nil) @@ -804,3 +791,28 @@ func stripExcessSpaces(vals []string) { vals[i] = string(buf[:m]) } } + +func buildSigningScope(region, service string, dt time.Time) string { + return strings.Join([]string{ + formatShortTime(dt), + region, + service, + awsV4Request, + }, "/") +} + +func deriveSigningKey(region, service, secretKey string, dt time.Time) []byte { + kDate := hmacSHA256([]byte("AWS4"+secretKey), []byte(formatShortTime(dt))) + kRegion := hmacSHA256(kDate, []byte(region)) + kService := hmacSHA256(kRegion, []byte(service)) + signingKey := hmacSHA256(kService, []byte(awsV4Request)) + return signingKey +} + +func formatShortTime(dt time.Time) string { + return dt.UTC().Format(shortTimeFormat) +} + +func formatTime(dt time.Time) string { + return dt.UTC().Format(timeFormat) +} diff --git a/private/protocol/eventstream/encode.go b/private/protocol/eventstream/encode.go index e2f8666c21d..e2096a79b22 100644 --- a/private/protocol/eventstream/encode.go +++ b/private/protocol/eventstream/encode.go @@ -44,7 +44,7 @@ func (e *Encoder) Encode(msg Message) (err error) { }() } - if err = encodeHeaders(e.headersBuf, msg.Headers); err != nil { + if err = EncodeHeaders(e.headersBuf, msg.Headers); err != nil { return err } @@ -124,9 +124,9 @@ func encodePrelude(w io.Writer, crc hash.Hash32, headersLen, payloadLen uint32) return nil } -// encodeHeaders writes the header values to the writer encoded in the event +// EncodeHeaders writes the header values to the writer encoded in the event // stream format. Returns an error if a header fails to encode. -func encodeHeaders(w io.Writer, headers Headers) error { +func EncodeHeaders(w io.Writer, headers Headers) error { for _, h := range headers { hn := headerName{ Len: uint8(len(h.Name)), diff --git a/private/protocol/eventstream/eventstreamapi/signer.go b/private/protocol/eventstream/eventstreamapi/signer.go new file mode 100644 index 00000000000..25ecabe9f81 --- /dev/null +++ b/private/protocol/eventstream/eventstreamapi/signer.go @@ -0,0 +1,43 @@ +package eventstreamapi + +import ( + "bytes" + "time" + + "github.com/aws/aws-sdk-go/private/protocol/eventstream" +) + +const ( + chunkSignatureHeader = ":chunk-signature" + chunkDateHeader = ":date" +) + +// StreamSigner defines an interface for the implementation of signing of event stream payloads +type StreamSigner interface { + GetSignature(headers, payload []byte, date time.Time) ([]byte, error) +} + +// MessageSigner encapsulates signing and attaching signatures to event stream messages +type MessageSigner struct { + Signer StreamSigner +} + +// SignMessage takes the given event stream message generates and adds signature information +// to the event stream message. +func (s MessageSigner) SignMessage(msg *eventstream.Message, date time.Time) error { + msg.Headers.Set(chunkDateHeader, eventstream.TimestampValue(date)) + + var headers bytes.Buffer + if err := eventstream.EncodeHeaders(&headers, msg.Headers); err != nil { + return err + } + + sig, err := s.Signer.GetSignature(headers.Bytes(), msg.Payload, date) + if err != nil { + return err + } + + msg.Headers.Set(chunkSignatureHeader, eventstream.BytesValue(sig)) + + return nil +} diff --git a/private/protocol/eventstream/eventstreamapi/signer_test.go b/private/protocol/eventstream/eventstreamapi/signer_test.go new file mode 100644 index 00000000000..e949991f76b --- /dev/null +++ b/private/protocol/eventstream/eventstreamapi/signer_test.go @@ -0,0 +1,106 @@ +// +build go1.7 + +package eventstreamapi + +import ( + "encoding/hex" + "fmt" + "reflect" + "strings" + "testing" + "time" + + "github.com/aws/aws-sdk-go/private/protocol/eventstream" +) + +type mockChunkSigner struct { + signature string + err error +} + +func (m mockChunkSigner) GetSignature(_, _ []byte, _ time.Time) ([]byte, error) { + return mustDecodeHex(hex.DecodeString(m.signature)), m.err +} + +func TestMessageSigner(t *testing.T) { + currentTime := time.Date(2019, 1, 27, 22, 37, 54, 0, time.UTC) + + cases := map[string]struct { + signer StreamSigner + input eventstream.Message + expected eventstream.Message + expectedError string + }{ + "sign message": { + signer: mockChunkSigner{signature: "524f1d03d1d81e94a099042736d40bd9681b867321443ff58a4568e274dbd83bff"}, + input: eventstream.Message{ + Headers: []eventstream.Header{ + { + Name: "header_name", + Value: eventstream.StringValue("header value"), + }, + }, + Payload: []byte("payload"), + }, + expected: eventstream.Message{ + Headers: []eventstream.Header{ + { + Name: "header_name", + Value: eventstream.StringValue("header value"), + }, + { + Name: ":date", + Value: eventstream.TimestampValue(currentTime), + }, + { + Name: ":chunk-signature", + Value: eventstream.BytesValue(mustDecodeHex(hex.DecodeString("524f1d03d1d81e94a099042736d40bd9681b867321443ff58a4568e274dbd83bff"))), + }, + }, + Payload: []byte("payload"), + }, + }, + "signing error": { + signer: mockChunkSigner{err: fmt.Errorf("signing error")}, + input: eventstream.Message{ + Headers: []eventstream.Header{ + { + Name: "header_name", + Value: eventstream.StringValue("header value"), + }, + }, + Payload: []byte("payload"), + }, + expectedError: "signing error", + }, + } + + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + messageSigner := MessageSigner{Signer: tt.signer} + + err := messageSigner.SignMessage(&tt.input, currentTime) + if err == nil && len(tt.expectedError) > 0 { + t.Fatalf("expected error, but got nil") + } else if err != nil && len(tt.expectedError) == 0 { + t.Fatalf("expected no error, but got %v", err) + } else if err != nil && len(tt.expectedError) > 0 && !strings.Contains(err.Error(), tt.expectedError) { + t.Fatalf("expected %v, but got %v", tt.expectedError, err) + } else if len(tt.expectedError) > 0 { + return + } + + if e, a := tt.expected, tt.input; !reflect.DeepEqual(e, a) { + t.Errorf("expected %v, got %v", e, a) + } + }) + } +} + +func mustDecodeHex(b []byte, err error) []byte { + if err != nil { + panic(err) + } + + return b +} diff --git a/private/protocol/eventstream/message.go b/private/protocol/eventstream/message.go index 2dc012a66e2..25c9783cde6 100644 --- a/private/protocol/eventstream/message.go +++ b/private/protocol/eventstream/message.go @@ -27,7 +27,7 @@ func (m *Message) rawMessage() (rawMessage, error) { if len(m.Headers) > 0 { var headers bytes.Buffer - if err := encodeHeaders(&headers, m.Headers); err != nil { + if err := EncodeHeaders(&headers, m.Headers); err != nil { return rawMessage{}, err } raw.Headers = headers.Bytes()