diff --git a/resp/resp3/resp.go b/resp/resp3/resp.go index 0f96761..10b211d 100644 --- a/resp/resp3/resp.go +++ b/resp/resp3/resp.go @@ -2005,10 +2005,6 @@ func keyableReceiver(prefix Prefix, kv reflect.Value) (reflect.Value, error) { } func unmarshalAgg(prefix Prefix, br resp.BufferedReader, l int64, rcv interface{}, o *resp.Opts) error { - if prefix == MapHeaderPrefix { - l *= 2 - } - if !o.DisableErrorBubbling { if o == nil { o = resp.NewOpts() @@ -2021,6 +2017,9 @@ func unmarshalAgg(prefix Prefix, br resp.BufferedReader, l int64, rcv interface{ } size := int(l) + if prefix == MapHeaderPrefix { + size *= 2 + } stream := size < 0 if rcv == nil { return discardMulti(br, size, o) @@ -2041,6 +2040,9 @@ func unmarshalAgg(prefix Prefix, br resp.BufferedReader, l int64, rcv interface{ switch v.Kind() { case reflect.Slice: slice := v + if slice.Type().Elem().Kind() == reflect.Struct { + size = int(l) + } if size > slice.Cap() || slice.IsNil() { sliceSize := size if stream { diff --git a/resp/resp3/resp_test.go b/resp/resp3/resp_test.go index 4c65869..ae221ca 100644 --- a/resp/resp3/resp_test.go +++ b/resp/resp3/resp_test.go @@ -198,6 +198,43 @@ func TestRawMessage(t *testing.T) { } } +func TestMapIntoSliceOfStructs(t *testing.T) { + buf := new(bytes.Buffer) + opts := resp.NewOpts() + + _ = (MapHeader{NumPairs: 1}).MarshalRESP(buf, opts) + _ = (SimpleString{S: "key"}).MarshalRESP(buf, opts) + _ = (SimpleString{S: "value"}).MarshalRESP(buf, opts) + + var rcv []mapPair + reader := bufio.NewReader(buf) + err := Unmarshal(reader, &rcv, opts) + require.NoError(t, err) + require.Len(t, rcv, 1) + + expectedStruct := mapPair{Key: "key", Value: "value"} + assert.Equal(t, expectedStruct, rcv[0]) +} + +type mapPair struct { + Key string + Value string +} + +func (s *mapPair) UnmarshalRESP(b resp.BufferedReader, opts *resp.Opts) error { + var key SimpleString + if err := key.UnmarshalRESP(b, opts); err != nil { + return err + } + var val SimpleString + if err := val.UnmarshalRESP(b, opts); err != nil { + return err + } + s.Key = key.S + s.Value = val.S + return nil +} + func Example_streamedAggregatedType() { buf := new(bytes.Buffer) opts := resp.NewOpts()