diff --git a/CHANGELOG.md b/CHANGELOG.md index d30d835f..b4183d4a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,15 @@ ## [Unreleased] +## 0.16.0 (September 11, 2020) + +IMPROVEMENTS: + - Add support for `Un/MarshalAminoJSON` override: if a type implements + `Un/MarshalAminoJSON`, then amino will use these methods for JSON un/marshalling + ([#323]). + +[#323]: https://github.com/tendermint/go-amino/pull/323 + ## 0.15.1 (October 10, 2019) ### IMPROVEMENTS: diff --git a/binary-encode.go b/binary-encode.go index 6cd8e6f7..5b3ab2ac 100644 --- a/binary-encode.go +++ b/binary-encode.go @@ -37,7 +37,7 @@ func (cdc *Codec) encodeReflectBinary(w io.Writer, info *TypeInfo, rv reflect.Va }() } - // Handle override if rv implements json.Marshaler. + // Handle override if rv implements MarshalAmino. if info.IsAminoMarshaler { // First, encode rv into repr instance. var rrv, rinfo = reflect.Value{}, (*TypeInfo)(nil) diff --git a/codec.go b/codec.go index 146ddaea..3c81ff6a 100644 --- a/codec.go +++ b/codec.go @@ -84,10 +84,14 @@ type ConcreteInfo struct { // These fields get set for all concrete types, // even those not manually registered (e.g. are never interface values). - IsAminoMarshaler bool // Implements MarshalAmino() (, error). - AminoMarshalReprType reflect.Type // - IsAminoUnmarshaler bool // Implements UnmarshalAmino() (error). - AminoUnmarshalReprType reflect.Type // + IsAminoMarshaler bool // Implements MarshalAmino() (, error). + AminoMarshalReprType reflect.Type // + IsAminoUnmarshaler bool // Implements UnmarshalAmino() (error). + AminoUnmarshalReprType reflect.Type // + IsAminoJSONMarshaler bool // Implements MarshalAminoJSON() (, error). + AminoJSONMarshalReprType reflect.Type // + IsAminoJSONUnmarshaler bool // Implements UnmarshalAminoJSON() (error). + AminoJSONUnmarshalReprType reflect.Type // } type StructInfo struct { @@ -566,6 +570,14 @@ func (cdc *Codec) newTypeInfoUnregistered(rt reflect.Type) *TypeInfo { info.ConcreteInfo.IsAminoUnmarshaler = true info.ConcreteInfo.AminoUnmarshalReprType = unmarshalAminoReprType(rm) } + if rm, ok := rt.MethodByName("MarshalAminoJSON"); ok { + info.ConcreteInfo.IsAminoJSONMarshaler = true + info.ConcreteInfo.AminoJSONMarshalReprType = marshalAminoJSONReprType(rm) + } + if rm, ok := reflect.PtrTo(rt).MethodByName("UnmarshalAminoJSON"); ok { + info.ConcreteInfo.IsAminoJSONUnmarshaler = true + info.ConcreteInfo.AminoJSONUnmarshalReprType = unmarshalAminoJSONReprType(rm) + } return info } @@ -804,3 +816,42 @@ func unmarshalAminoReprType(rm reflect.Method) (rrt reflect.Type) { } return } + +func marshalAminoJSONReprType(rm reflect.Method) (rrt reflect.Type) { + // Verify form of this method. + if rm.Type.NumIn() != 1 { + panic(fmt.Sprintf("MarshalAminoJSON should have 1 input parameters (including receiver); got %v", rm.Type)) + } + if rm.Type.NumOut() != 2 { + panic(fmt.Sprintf("MarshalAminoJSON should have 2 output parameters; got %v", rm.Type)) + } + if out := rm.Type.Out(1); out != errorType { + panic(fmt.Sprintf("MarshalAminoJSON should have second output parameter of error type, got %v", out)) + } + rrt = rm.Type.Out(0) + if rrt.Kind() == reflect.Ptr { + panic(fmt.Sprintf("Representative objects cannot be pointers; got %v", rrt)) + } + return +} + +func unmarshalAminoJSONReprType(rm reflect.Method) (rrt reflect.Type) { + // Verify form of this method. + if rm.Type.NumIn() != 2 { + panic(fmt.Sprintf("UnmarshalAminoJSON should have 2 input parameters (including receiver); got %v", rm.Type)) + } + if in1 := rm.Type.In(0); in1.Kind() != reflect.Ptr { + panic(fmt.Sprintf("UnmarshalAminoJSON first input parameter should be pointer type but got %v", in1)) + } + if rm.Type.NumOut() != 1 { + panic(fmt.Sprintf("UnmarshalAminoJSON should have 1 output parameters; got %v", rm.Type)) + } + if out := rm.Type.Out(0); out != errorType { + panic(fmt.Sprintf("UnmarshalAminoJSON should have first output parameter of error type, got %v", out)) + } + rrt = rm.Type.In(1) + if rrt.Kind() == reflect.Ptr { + panic(fmt.Sprintf("Representative objects cannot be pointers; got %v", rrt)) + } + return +} diff --git a/json-decode.go b/json-decode.go index f9ff580c..2bb3bca6 100644 --- a/json-decode.go +++ b/json-decode.go @@ -60,6 +60,29 @@ func (cdc *Codec) decodeReflectJSON(bz []byte, info *TypeInfo, rv reflect.Value, } } + // Handle override if a pointer to rv implements UnmarshalAminoJSON. + if info.IsAminoJSONUnmarshaler { + // First, decode repr instance from JSON. + rrv := reflect.New(info.AminoJSONUnmarshalReprType).Elem() + var rinfo *TypeInfo + rinfo, err = cdc.getTypeInfo_wlock(info.AminoJSONUnmarshalReprType) + if err != nil { + return + } + err = cdc.decodeReflectJSON(bz, rinfo, rrv, fopts) + if err != nil { + return + } + // Then, decode from repr instance. + uwrm := rv.Addr().MethodByName("UnmarshalAminoJSON") + uwouts := uwrm.Call([]reflect.Value{rrv}) + erri := uwouts[0].Interface() + if erri != nil { + err = erri.(error) + } + return + } + // Handle override if a pointer to rv implements json.Unmarshaler. if rv.Addr().Type().Implements(jsonUnmarshalerType) { err = rv.Addr().Interface().(json.Unmarshaler).UnmarshalJSON(bz) @@ -401,7 +424,7 @@ func (cdc *Codec) decodeReflectJSONStruct(bz []byte, info *TypeInfo, rv reflect. // Set nil/zero on frv. frv.Set(reflect.Zero(frv.Type())) } - + continue } diff --git a/json-encode.go b/json-encode.go index 15b10f21..1fa07d1f 100644 --- a/json-encode.go +++ b/json-encode.go @@ -48,6 +48,27 @@ func (cdc *Codec) encodeReflectJSON(w io.Writer, info *TypeInfo, rv reflect.Valu ct := rv.Interface().(time.Time).Round(0).UTC() rv = reflect.ValueOf(ct) } + + // Handle override if rv implements MarshalAminoJSON. + if info.IsAminoJSONMarshaler { + // First, encode rv into repr instance. + var ( + rrv reflect.Value + rinfo *TypeInfo + ) + rrv, err = toReprJSONObject(rv) + if err != nil { + return + } + rinfo, err = cdc.getTypeInfo_wlock(info.AminoJSONMarshalReprType) + if err != nil { + return + } + // Then, encode the repr instance. + err = cdc.encodeReflectJSON(w, rinfo, rrv, fopts) + return + } + // Handle override if rv implements json.Marshaler. if rv.CanAddr() { // Try pointer first. if rv.Addr().Type().Implements(jsonMarshalerType) { @@ -59,7 +80,7 @@ func (cdc *Codec) encodeReflectJSON(w io.Writer, info *TypeInfo, rv reflect.Valu return } - // Handle override if rv implements json.Marshaler. + // Handle override if rv implements MarshalAmino. if info.IsAminoMarshaler { // First, encode rv into repr instance. var rrv, rinfo = reflect.Value{}, (*TypeInfo)(nil) diff --git a/reflect.go b/reflect.go index 76a86d66..3fd14671 100644 --- a/reflect.go +++ b/reflect.go @@ -243,3 +243,22 @@ func toReprObject(rv reflect.Value) (rrv reflect.Value, err error) { rrv = mwouts[0] return } + +func toReprJSONObject(rv reflect.Value) (rrv reflect.Value, err error) { + var mwrm reflect.Value + if rv.CanAddr() { + mwrm = rv.Addr().MethodByName("MarshalAminoJSON") + } else { + mwrm = rv.MethodByName("MarshalAminoJSON") + } + mwouts := mwrm.Call(nil) + if !mwouts[1].IsNil() { + erri := mwouts[1].Interface() + if erri != nil { + err = erri.(error) + return rrv, err + } + } + rrv = mwouts[0] + return +} diff --git a/repr_test.go b/repr_test.go index 497cdb0b..c214c164 100644 --- a/repr_test.go +++ b/repr_test.go @@ -96,3 +96,54 @@ func TestMarshalAminoJSON(t *testing.T) { assert.Equal(t, f, f2) assert.Equal(t, f.a, f2.a) // In case the above doesn't check private fields? } + +type Bar struct { + a string + b int + c []*Bar + D string // exposed +} + +func (b Bar) MarshalAminoJSON() ([]pair, error) { // nolint: golint + return []pair{ + {"a", b.a}, + {"b", b.b}, + {"c", b.c}, + {"D", b.D}, + }, nil +} + +func (b *Bar) UnmarshalAminoJSON(repr []pair) error { + b.a = repr[0].get("a").(string) + b.b = repr[1].get("b").(int) + b.c = repr[2].get("c").([]*Bar) + b.D = repr[3].get("D").(string) + return nil +} + +func TestMarshalAminoJSON_Override(t *testing.T) { + + cdc := NewCodec() + cdc.RegisterInterface((*interface{})(nil), nil) + cdc.RegisterConcrete(string(""), "string", nil) + cdc.RegisterConcrete(int(0), "int", nil) + cdc.RegisterConcrete(([]*Bar)(nil), "[]*Bar", nil) + + var f = Bar{ + a: "K", + b: 2, + c: []*Bar{nil, nil, nil}, + D: "J", + } + bz, err := cdc.MarshalJSON(f) + assert.Nil(t, err) + + t.Logf("bz %X", bz) + + var f2 Bar + err = cdc.UnmarshalJSON(bz, &f2) + assert.Nil(t, err) + + assert.Equal(t, f, f2) + assert.Equal(t, f.a, f2.a) // In case the above doesn't check private fields? +}