Skip to content

Commit

Permalink
private/protocol/eventstream/eventstreamapi: Add MessageSigner to Eve…
Browse files Browse the repository at this point in the history
…ntWriter (aws#3016)
  • Loading branch information
skmcgrail committed Dec 23, 2019
1 parent 84f860d commit 635b810
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 59 deletions.
65 changes: 65 additions & 0 deletions private/protocol/eventstream/eventstreamapi/shared_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package eventstreamapi

import (
"bytes"
"encoding/hex"
"time"

"github.com/aws/aws-sdk-go/private/protocol"
"github.com/aws/aws-sdk-go/private/protocol/eventstream"
)

type mockChunkSigner struct {
signature string
err error
}

func (m mockChunkSigner) GetSignature(_, _ []byte, _ time.Time) ([]byte, error) {
return mustDecodeHex(hex.DecodeString(m.signature)), m.err
}

type eventStructured struct {
_ struct{} `type:"structure"`

String *string `type:"string"`
Number *int64 `type:"long"`
Nested *eventStructured `type:"structure"`
}

func (e *eventStructured) MarshalEvent(pm protocol.PayloadMarshaler) (eventstream.Message, error) {
var msg eventstream.Message
msg.Headers.Set(MessageTypeHeader, eventstream.StringValue(EventMessageType))
msg.Headers.Set(EventTypeHeader, eventstream.StringValue("eventStructured"))

var buf bytes.Buffer
if err := pm.MarshalPayload(&buf, e); err != nil {
return eventstream.Message{}, err
}

msg.Payload = buf.Bytes()

return msg, nil
}

func (e *eventStructured) UnmarshalEvent(pm protocol.PayloadUnmarshaler, msg eventstream.Message) error {
return pm.UnmarshalPayload(bytes.NewReader(msg.Payload), e)
}

func mustDecodeHex(b []byte, err error) []byte {
if err != nil {
panic(err)
}

return b
}

func swapTimeNow(f func() time.Time) func() {
if f == nil {
return func() {}
}

timeNow = f
return func() {
timeNow = time.Now
}
}
9 changes: 3 additions & 6 deletions private/protocol/eventstream/eventstreamapi/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ import (
"github.com/aws/aws-sdk-go/private/protocol/eventstream"
)

const (
chunkSignatureHeader = ":chunk-signature"
chunkDateHeader = ":date"
)
var timeNow = time.Now

// StreamSigner defines an interface for the implementation of signing of event stream payloads
type StreamSigner interface {
Expand All @@ -25,7 +22,7 @@ type MessageSigner struct {
// SignMessage takes the given event stream message generates and adds signature information
// to the event stream message.
func (s MessageSigner) SignMessage(msg *eventstream.Message, date time.Time) error {
msg.Headers.Set(chunkDateHeader, eventstream.TimestampValue(date))
msg.Headers.Set(DateHeader, eventstream.TimestampValue(date))

var headers bytes.Buffer
if err := eventstream.EncodeHeaders(&headers, msg.Headers); err != nil {
Expand All @@ -37,7 +34,7 @@ func (s MessageSigner) SignMessage(msg *eventstream.Message, date time.Time) err
return err
}

msg.Headers.Set(chunkSignatureHeader, eventstream.BytesValue(sig))
msg.Headers.Set(ChunkSignatureHeader, eventstream.BytesValue(sig))

return nil
}
17 changes: 0 additions & 17 deletions private/protocol/eventstream/eventstreamapi/signer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,6 @@ import (
"github.com/aws/aws-sdk-go/private/protocol/eventstream"
)

type mockChunkSigner struct {
signature string
err error
}

func (m mockChunkSigner) GetSignature(_, _ []byte, _ time.Time) ([]byte, error) {
return mustDecodeHex(hex.DecodeString(m.signature)), m.err
}

func TestMessageSigner(t *testing.T) {
currentTime := time.Date(2019, 1, 27, 22, 37, 54, 0, time.UTC)

Expand Down Expand Up @@ -96,11 +87,3 @@ func TestMessageSigner(t *testing.T) {
})
}
}

func mustDecodeHex(b []byte, err error) []byte {
if err != nil {
panic(err)
}

return b
}
15 changes: 15 additions & 0 deletions private/protocol/eventstream/eventstreamapi/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Marshaler interface {
type EventWriter struct {
writer io.Writer
encoder *eventstream.Encoder
signer *MessageSigner

payloadMarshaler protocol.PayloadMarshaler
}
Expand All @@ -27,11 +28,13 @@ type EventWriter struct {
// writer provided. Use the WriteStream method to write an event to the stream.
func NewEventWriter(writer io.Writer,
payloadMarshaler protocol.PayloadMarshaler,
signer *MessageSigner,
) *EventWriter {
return &EventWriter{
writer: writer,
encoder: eventstream.NewEncoder(writer),
payloadMarshaler: payloadMarshaler,
signer: signer,
}
}

Expand All @@ -43,6 +46,14 @@ func (w *EventWriter) UseLogger(logger aws.Logger, logLevel aws.LogLevelType) {
}
}

func (w *EventWriter) signMessage(msg *eventstream.Message) error {
if w.signer == nil {
return nil
}

return w.signer.SignMessage(msg, timeNow())
}

// WriteEvent writes an event to the stream. Returns an error if the event
// fails to marshal into a message, or writing to the underlying writer fails.
func (w *EventWriter) WriteEvent(event Marshaler) error {
Expand All @@ -51,5 +62,9 @@ func (w *EventWriter) WriteEvent(event Marshaler) error {
return err
}

if err = w.signMessage(&msg); err != nil {
return err
}

return w.encoder.Encode(msg)
}
90 changes: 54 additions & 36 deletions private/protocol/eventstream/eventstreamapi/writer_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
// +build go1.7

package eventstreamapi

import (
"bytes"
"encoding/hex"
"reflect"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
Expand All @@ -14,8 +18,10 @@ import (

func TestEventWriter(t *testing.T) {
cases := map[string]struct {
Event Marshaler
Expect eventstream.Message
Event Marshaler
Signer *MessageSigner
TimeFunc func() time.Time
Expect eventstream.Message
}{
"structured event": {
Event: &eventStructured{
Expand All @@ -37,23 +43,61 @@ func TestEventWriter(t *testing.T) {
Payload: []byte(`{"String":"stringfield","Number":123,"Nested":{"String":"fieldstring","Number":321}}`),
},
},
"signed event": {
Event: &eventStructured{
String: aws.String("stringfield"),
Number: aws.Int64(123),
Nested: &eventStructured{
String: aws.String("fieldstring"),
Number: aws.Int64(321),
},
},
Signer: &MessageSigner{Signer: &mockChunkSigner{signature: "524f1d03d1d81e94a099042736d40bd9681b867321443ff58a4568e274dbd83bff"}},
TimeFunc: func() time.Time {
return time.Date(2019, 1, 27, 22, 37, 54, 0, time.UTC)
},
Expect: eventstream.Message{
Headers: eventstream.Headers{
eventMessageTypeHeader,
eventstream.Header{
Name: EventTypeHeader,
Value: eventstream.StringValue("eventStructured"),
},
{
Name: DateHeader,
Value: eventstream.TimestampValue(time.Date(2019, 1, 27, 22, 37, 54, 0, time.UTC)),
},
{
Name: ChunkSignatureHeader,
Value: eventstream.BytesValue(mustDecodeHex(hex.DecodeString("524f1d03d1d81e94a099042736d40bd9681b867321443ff58a4568e274dbd83bff"))),
},
},
Payload: []byte(`{"String":"stringfield","Number":123,"Nested":{"String":"fieldstring","Number":321}}`),
},
},
}

var marshalers request.HandlerList
marshalers.PushBackNamed(restjson.BuildHandler)

var stream bytes.Buffer
eventWriter := NewEventWriter(&stream,
protocol.HandlerPayloadMarshal{
Marshalers: marshalers,
},
)

decoder := eventstream.NewDecoder(&stream)

decodeBuf := make([]byte, 1024)
for name, c := range cases {
t.Run(name, func(t *testing.T) {
defer swapTimeNow(c.TimeFunc)()

stream.Reset()

eventWriter := NewEventWriter(&stream,
protocol.HandlerPayloadMarshal{
Marshalers: marshalers,
},
c.Signer,
)

decoder := eventstream.NewDecoder(&stream)

eventWriter.UseLogger(t, aws.LogDebugWithEventStreamBody)

if err := eventWriter.WriteEvent(c.Event); err != nil {
Expand Down Expand Up @@ -81,6 +125,7 @@ func BenchmarkEventWriter(b *testing.B) {
protocol.HandlerPayloadMarshal{
Marshalers: marshalers,
},
nil,
)

event := &eventStructured{
Expand All @@ -99,30 +144,3 @@ func BenchmarkEventWriter(b *testing.B) {
}
}
}

type eventStructured struct {
_ struct{} `type:"structure"`

String *string `type:"string"`
Number *int64 `type:"long"`
Nested *eventStructured `type:"structure"`
}

func (e *eventStructured) MarshalEvent(pm protocol.PayloadMarshaler) (eventstream.Message, error) {
var msg eventstream.Message
msg.Headers.Set(MessageTypeHeader, eventstream.StringValue(EventMessageType))
msg.Headers.Set(EventTypeHeader, eventstream.StringValue("eventStructured"))

var buf bytes.Buffer
if err := pm.MarshalPayload(&buf, e); err != nil {
return eventstream.Message{}, err
}

msg.Payload = buf.Bytes()

return msg, nil
}

func (e *eventStructured) UnmarshalEvent(pm protocol.PayloadUnmarshaler, msg eventstream.Message) error {
return pm.UnmarshalPayload(bytes.NewReader(msg.Payload), e)
}

0 comments on commit 635b810

Please sign in to comment.