diff --git a/internal/push/push.go b/internal/push/push.go index 335bf854..9a42f292 100644 --- a/internal/push/push.go +++ b/internal/push/push.go @@ -3,7 +3,6 @@ package push import ( "context" "crypto/tls" - "encoding/json" "fmt" "net/http" "sync" @@ -40,6 +39,8 @@ type subscription struct { currentTokenLock sync.RWMutex unregisterTokenNotifier func(string) registerTokenNotifier func(string, func(string)) + readEncoding elemental.EncodingType + writeEncoding elemental.EncodingType } // NewSubscriber creates a new Subscription. @@ -50,9 +51,15 @@ func NewSubscriber( registerTokenNotifier func(string, func(string)), unregisterTokenNotifier func(string), tlsConfig *tls.Config, + headers http.Header, recursive bool, ) manipulate.Subscriber { + readEncoding, writeEncoding, err := elemental.EncodingFromHeaders(headers) + if err != nil { + panic(err) + } + return &subscription{ id: uuid.Must(uuid.NewV4()).String(), url: url, @@ -67,12 +74,15 @@ func NewSubscriber( status: make(chan manipulate.SubscriberStatus, statusChSize), filters: make(chan *elemental.PushFilter, filterChSize), currentFilterLock: sync.RWMutex{}, + readEncoding: readEncoding, + writeEncoding: writeEncoding, config: wsc.Config{ PongWait: 10 * time.Second, WriteWait: 10 * time.Second, PingPeriod: 5 * time.Second, ReadChanSize: 2048, TLSConfig: tlsConfig, + Headers: headers, }, } } @@ -181,7 +191,7 @@ func (s *subscription) listen(ctx context.Context) { case filter := <-s.filters: - filterData, err = json.Marshal(filter) + filterData, err = elemental.Encode(s.writeEncoding, filter) if err != nil { s.publishError(err) continue @@ -192,7 +202,7 @@ func (s *subscription) listen(ctx context.Context) { case data := <-s.conn.Read(): event := &elemental.Event{} - if err = json.Unmarshal(data, event); err != nil { + if err = elemental.Decode(s.readEncoding, data, event); err != nil { s.publishError(err) continue } diff --git a/maniphttp/manipulator.go b/maniphttp/manipulator.go index ed887969..57aad121 100644 --- a/maniphttp/manipulator.go +++ b/maniphttp/manipulator.go @@ -9,7 +9,6 @@ import ( "context" "crypto/tls" "crypto/x509" - "encoding/json" "fmt" "net/http" "net/url" @@ -44,6 +43,7 @@ type httpManipulator struct { tokenManager manipulate.TokenManager globalHeaders http.Header transport *http.Transport + encoding elemental.EncodingType } // New returns a maniphttp.Manipulator configured according to the given suite of Option. @@ -60,6 +60,7 @@ func New(ctx context.Context, url string, options ...Option) (manipulate.Manipul renewNotifiers: map[string]func(string){}, ctx: ctx, url: url, + encoding: elemental.EncodingTypeJSON, } // Apply the options. @@ -145,7 +146,7 @@ func (s *httpManipulator) RetrieveMany(mctx manipulate.Context, dest elemental.I if response.StatusCode != http.StatusNoContent { defer response.Body.Close() // nolint: errcheck - if err := decodeData(response, dest); err != nil { + if err := decodeData(response, s.encoding, dest); err != nil { sp.SetTag("error", true) sp.LogFields(log.Error(err)) return err @@ -187,7 +188,7 @@ func (s *httpManipulator) Retrieve(mctx manipulate.Context, object elemental.Ide if response.StatusCode != http.StatusNoContent { defer response.Body.Close() // nolint: errcheck - if err := decodeData(response, &object); err != nil { + if err := decodeData(response, s.encoding, object); err != nil { sp.SetTag("error", true) sp.LogFields(log.Error(err)) return err @@ -224,7 +225,7 @@ func (s *httpManipulator) Create(mctx manipulate.Context, object elemental.Ident return manipulate.NewErrCannotBuildQuery(err.Error()) } - data, err := json.Marshal(object) + data, err := elemental.Encode(s.encoding, object) if err != nil { sp.SetTag("error", true) sp.LogFields(log.Error(err)) @@ -240,7 +241,7 @@ func (s *httpManipulator) Create(mctx manipulate.Context, object elemental.Ident if response.StatusCode != http.StatusNoContent { defer response.Body.Close() // nolint: errcheck - if err := decodeData(response, &object); err != nil { + if err := decodeData(response, s.encoding, object); err != nil { sp.SetTag("error", true) sp.LogFields(log.Error(err)) return err @@ -286,7 +287,7 @@ func (s *httpManipulator) Update(mctx manipulate.Context, object elemental.Ident return manipulate.NewErrCannotBuildQuery(err.Error()) } - data, err := json.Marshal(object) + data, err := elemental.Encode(s.encoding, object) if err != nil { sp.SetTag("error", true) sp.LogFields(log.Error(err)) @@ -302,7 +303,7 @@ func (s *httpManipulator) Update(mctx manipulate.Context, object elemental.Ident if response.StatusCode != http.StatusNoContent { defer response.Body.Close() // nolint: errcheck - if err := decodeData(response, &object); err != nil { + if err := decodeData(response, s.encoding, object); err != nil { sp.SetTag("error", true) sp.LogFields(log.Error(err)) return err @@ -347,7 +348,7 @@ func (s *httpManipulator) Delete(mctx manipulate.Context, object elemental.Ident if response.StatusCode != http.StatusNoContent { defer response.Body.Close() // nolint: errcheck - if err := decodeData(response, &object); err != nil { + if err := decodeData(response, s.encoding, object); err != nil { sp.SetTag("error", true) sp.LogFields(log.Error(err)) return err @@ -407,7 +408,8 @@ func (s *httpManipulator) prepareHeaders(request *http.Request, mctx manipulate. request.Header[k] = v } - request.Header.Set("Content-Type", "application/json; charset=UTF-8") + request.Header.Set("Content-Type", string(s.encoding)) + request.Header.Set("Accept", string(s.encoding)) request.Header.Set("Accept-Encoding", "gzip") if ns != "" { @@ -585,10 +587,10 @@ func (s *httpManipulator) send(mctx manipulate.Context, method string, requrl st if response.StatusCode < 200 || response.StatusCode >= 300 { - es := []elemental.Error{} + es := elemental.Errors{} defer response.Body.Close() // nolint: errcheck - if err := decodeData(response, &es); err != nil { + if err := decodeData(response, s.encoding, &es); err != nil { return nil, err } diff --git a/maniphttp/manipulator_test.go b/maniphttp/manipulator_test.go index 7fbbd9ac..c94ecc74 100644 --- a/maniphttp/manipulator_test.go +++ b/maniphttp/manipulator_test.go @@ -12,12 +12,11 @@ import ( "testing" "time" - "go.aporeto.io/manipulate/internal/idempotency" - . "github.com/smartystreets/goconvey/convey" "go.aporeto.io/elemental" testmodel "go.aporeto.io/elemental/test/model" "go.aporeto.io/manipulate" + "go.aporeto.io/manipulate/internal/idempotency" "go.aporeto.io/manipulate/internal/tracing" "go.aporeto.io/manipulate/maniptest" ) @@ -54,9 +53,6 @@ func TestHTTP_NewSHTTPm(t *testing.T) { }) } -/* - Privates -*/ func TestHTTP_makeAuthorizationHeaders(t *testing.T) { Convey("Given I create a new HTTP manipulator", t, func() { @@ -399,7 +395,7 @@ func TestHTTP_Retrieve(t *testing.T) { Convey("Then error should not be nil", func() { So(err, ShouldNotBeNil) So(err.(elemental.Errors).Code(), ShouldEqual, 422) - So(err.(elemental.Errors)[0].(elemental.Error).Description, ShouldEqual, "nope.") + So(err.(elemental.Errors)[0].Description, ShouldEqual, "nope.") }) }) }) @@ -548,7 +544,7 @@ func TestHTTP_Delete(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, `[{"ID": "yyy"}]`) + w.Write([]byte(`{"ID":"yyy"}`)) // nolint })) defer ts.Close() @@ -559,10 +555,13 @@ func TestHTTP_Delete(t *testing.T) { list := testmodel.NewList() list.ID = "xxx" - _ = m.Delete(nil, list) + err := m.Delete(nil, list) + if err != nil { + panic(err) + } - Convey("Then ID should 'xxx'", func() { - So(list.Identifier(), ShouldEqual, "xxx") + Convey("Then ID should 'yyy'", func() { + So(list.Identifier(), ShouldEqual, "yyy") }) }) @@ -1212,7 +1211,7 @@ func TestHTTP_send(t *testing.T) { Convey("Then err should not be nil", func() { So(err, ShouldNotBeNil) So(err, ShouldHaveSameTypeAs, manipulate.ErrCannotUnmarshal{}) - So(err.Error(), ShouldEqual, `Unable to unmarshal data: invalid character '\n' in string literal. original data: + So(err.Error(), ShouldEqual, `Unable to unmarshal data: unable to decode application/json: EOF. original data: [{"code": 423, "] `) }) diff --git a/maniphttp/options.go b/maniphttp/options.go index 7c69952e..7e23bc89 100644 --- a/maniphttp/options.go +++ b/maniphttp/options.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "net/http" + "go.aporeto.io/elemental" "go.aporeto.io/manipulate" ) @@ -88,3 +89,10 @@ func OptionDisableBuiltInRetry() Option { m.disableAutoRetry = true } } + +// OptionEncoding sets the encoding/decoding type to use. +func OptionEncoding(enc elemental.EncodingType) Option { + return func(m *httpManipulator) { + m.encoding = enc + } +} diff --git a/maniphttp/options_test.go b/maniphttp/options_test.go index d5fd3515..fd6d4aa9 100644 --- a/maniphttp/options_test.go +++ b/maniphttp/options_test.go @@ -7,6 +7,7 @@ import ( "testing" . "github.com/smartystreets/goconvey/convey" + "go.aporeto.io/elemental" ) type testTokenManager struct{} @@ -76,4 +77,10 @@ func TestManipHttp_Optionions(t *testing.T) { OptionDisableBuiltInRetry()(m) So(m.disableAutoRetry, ShouldBeTrue) }) + + Convey("Calling OptionEncoding should work", t, func() { + m := &httpManipulator{} + OptionEncoding(elemental.EncodingTypeMSGPACK)(m) + So(m.encoding, ShouldEqual, elemental.EncodingTypeMSGPACK) + }) } diff --git a/maniphttp/subscriber.go b/maniphttp/subscriber.go index f71e6871..2098792b 100644 --- a/maniphttp/subscriber.go +++ b/maniphttp/subscriber.go @@ -3,6 +3,7 @@ package maniphttp import ( "crypto/tls" "fmt" + "net/http" "strings" "go.aporeto.io/manipulate" @@ -75,6 +76,10 @@ func NewSubscriber(manipulator manipulate.Manipulator, options ...SubscriberOpti m.registerRenewNotifier, m.unregisterRenewNotifier, cfg.tlsConfig, + http.Header{ + "Content-Type": []string{string(m.encoding)}, + "Accept": []string{string(m.encoding)}, + }, cfg.recursive, ) } diff --git a/maniphttp/utils.go b/maniphttp/utils.go index a2b2ae1c..bf246dc0 100644 --- a/maniphttp/utils.go +++ b/maniphttp/utils.go @@ -5,7 +5,6 @@ import ( "context" "crypto/tls" "crypto/x509" - "encoding/json" "fmt" "io" "io/ioutil" @@ -16,6 +15,8 @@ import ( "sync" "time" + "go.aporeto.io/elemental" + "go.aporeto.io/manipulate" "go.aporeto.io/manipulate/maniphttp/internal/compiler" ) @@ -67,7 +68,7 @@ func addQueryParameters(req *http.Request, ctx manipulate.Context) error { return nil } -func decodeData(r *http.Response, dest interface{}) (err error) { +func decodeData(r *http.Response, encodingType elemental.EncodingType, dest interface{}) (err error) { if r.Body == nil { return manipulate.NewErrCannotUnmarshal("nil reader") @@ -96,7 +97,7 @@ func decodeData(r *http.Response, dest interface{}) (err error) { return manipulate.NewErrCannotUnmarshal(fmt.Sprintf("unable to read data: %s", err.Error())) } - if err = json.Unmarshal(data, dest); err != nil { + if err = elemental.Decode(encodingType, data, dest); err != nil { return manipulate.NewErrCannotUnmarshal(fmt.Sprintf("%s. original data:\n%s", err.Error(), string(data))) } diff --git a/maniphttp/utils_test.go b/maniphttp/utils_test.go index af779c6f..85846548 100644 --- a/maniphttp/utils_test.go +++ b/maniphttp/utils_test.go @@ -171,7 +171,7 @@ func Test_decodeData(t *testing.T) { Convey("When I call decodeData", func() { dest := map[string]interface{}{} - err := decodeData(r, &dest) + err := decodeData(r, "", &dest) Convey("Then err should be nil", func() { So(err, ShouldBeNil) @@ -180,7 +180,7 @@ func Test_decodeData(t *testing.T) { Convey("Then the dest should be correct", func() { So(len(dest), ShouldEqual, 2) So(dest["name"].(string), ShouldEqual, "thename") - So(dest["age"].(float64), ShouldEqual, 2) + So(dest["age"].(uint64), ShouldEqual, 2) }) }) }) @@ -205,7 +205,7 @@ func Test_decodeData(t *testing.T) { Convey("When I call decodeData", func() { dest := map[string]interface{}{} - err := decodeData(r, &dest) + err := decodeData(r, "", &dest) Convey("Then err should be nil", func() { So(err, ShouldBeNil) @@ -214,7 +214,7 @@ func Test_decodeData(t *testing.T) { Convey("Then the dest should be correct", func() { So(len(dest), ShouldEqual, 2) So(dest["name"].(string), ShouldEqual, "thename") - So(dest["age"].(float64), ShouldEqual, 2) + So(dest["age"].(uint64), ShouldEqual, 2) }) }) }) @@ -228,11 +228,11 @@ func Test_decodeData(t *testing.T) { Convey("When I call decodeData", func() { dest := map[string]interface{}{} - err := decodeData(r, &dest) + err := decodeData(r, "", &dest) Convey("Then err should not be nil", func() { So(err, ShouldNotBeNil) - So(err.Error(), ShouldEqual, "Unable to unmarshal data: invalid character '<' looking for beginning of value. original data:\nnot json") + So(err.Error(), ShouldEqual, "Unable to unmarshal data: unable to decode application/json: json decode error [pos 1]: read map - expect char '{' but got char '<'. original data:\nnot json") }) Convey("Then the dest should be empty", func() { @@ -250,7 +250,7 @@ func Test_decodeData(t *testing.T) { } dest := map[string]interface{}{} - err := decodeData(r, &dest) + err := decodeData(r, "", &dest) Convey("Then err should not be nil", func() { So(err, ShouldNotBeNil) @@ -272,7 +272,7 @@ func Test_decodeData(t *testing.T) { Convey("When I call decodeData", func() { dest := map[string]interface{}{} - err := decodeData(r, &dest) + err := decodeData(r, "", &dest) Convey("Then err should not be nil", func() { So(err, ShouldNotBeNil)