diff --git a/marshal.go b/marshal.go index 5bdbdd9..cc62164 100644 --- a/marshal.go +++ b/marshal.go @@ -218,6 +218,17 @@ func (m *marshaller) encodeValue(num protowire.Number, val reflect.Value) { return } + // If the pointer is to a struct + if deref(val.Type()).Kind() == reflect.Struct { + b, ok := tryEncodeFunc(val) + if ok { + putTag(m, num, protowire.BytesType) + putBytes(m, b) + + return + } + } + m.encodeValue(num, val.Elem()) case reflect.Interface: diff --git a/marshal_test.go b/marshal_test.go index 0ebde24..23a5dc2 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -489,8 +489,8 @@ func TestCustomEcnoders(t *testing.T) { }, "should use custom encoder on pointer": { testCustomEncodersDecoders( - encodeCustomEncoderStruct, - decodeCustomEncoderStruct, + encodeCustomEncoderStructPtr, + decodeCustomEncoderStructPtr, OneFieldStruct[*CustomEncoderStruct]{ Field: &CustomEncoderStruct{ Value: 150, @@ -498,7 +498,7 @@ func TestCustomEcnoders(t *testing.T) { }, OneFieldStruct[*CustomEncoderStruct]{ Field: &CustomEncoderStruct{ - Value: 152, + Value: 156, }, }, ), @@ -547,6 +547,21 @@ func decodeCustomEncoderStruct(slc []byte) (CustomEncoderStruct, error) { }, err } +func encodeCustomEncoderStructPtr(v *CustomEncoderStruct) ([]byte, error) { + return []byte(strconv.Itoa(v.Value + 3)), nil +} + +func decodeCustomEncoderStructPtr(slc []byte) (*CustomEncoderStruct, error) { + res, err := strconv.Atoi(string(slc)) + if err != nil { + return &CustomEncoderStruct{}, err + } + + return &CustomEncoderStruct{ + Value: res + 3, + }, err +} + func testCustomEncodersDecoders[V any, T any]( enc func(T) ([]byte, error), dec func([]byte) (T, error), diff --git a/unmarshal.go b/unmarshal.go index a305f29..00cf5c3 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -349,7 +349,7 @@ func instantiate(dst reflect.Value) error { return nil } -//nolint:cyclop,gocyclo +//nolint:cyclop,gocognit,gocyclo func unmarshalBytes(dst reflect.Value, value complexValue) (err error) { defer func() { if err != nil { @@ -413,6 +413,18 @@ func unmarshalBytes(dst reflect.Value, value complexValue) (err error) { } } + // If the pointer is to a struct + if deref(dst.Type()).Kind() == reflect.Struct { + ok, err := tryDecodeFunc(bytes, dst) + if err != nil { + return err + } + + if ok { + return nil + } + } + return unmarshalBytes(dst.Elem(), value) case reflect.String: