From 221fc98d0abb603ac75d71b08f60a506751d8665 Mon Sep 17 00:00:00 2001 From: "Dustin L. Howett" Date: Mon, 20 Mar 2017 14:41:50 -0700 Subject: [PATCH] Add a custom marshaler/unmarshaler interface. Fixes #13, #17. Closes #11. Ref #3, #21. --- common_data_for_test.go | 100 +++++++++++++++++++++++++++++++ example_custom_marshaler_test.go | 52 ++++++++++++++++ marshal.go | 41 +++++++++---- plist.go | 18 ++++++ unmarshal.go | 40 +++++++++---- 5 files changed, 229 insertions(+), 22 deletions(-) create mode 100644 example_custom_marshaler_test.go diff --git a/common_data_for_test.go b/common_data_for_test.go index 020a79b..ea01737 100644 --- a/common_data_for_test.go +++ b/common_data_for_test.go @@ -1,6 +1,7 @@ package plist import ( + "errors" "math" "reflect" "time" @@ -77,6 +78,72 @@ func (b *TextMarshalingBoolViaPointer) UnmarshalText(text []byte) error { return nil } +type ArrayThatSerializesAsOneObject struct { + values []uint64 +} + +func (f ArrayThatSerializesAsOneObject) MarshalPlist() (interface{}, error) { + if len(f.values) == 1 { + return f.values[0], nil + } + return f.values, nil +} + +func (f *ArrayThatSerializesAsOneObject) UnmarshalPlist(unmarshal func(interface{}) error) error { + var ui uint64 + if err := unmarshal(&ui); err == nil { + f.values = []uint64{ui} + return nil + } + + return unmarshal(&f.values) +} + +type PlistMarshalingBoolByPointer struct { + b bool +} + +func (b *PlistMarshalingBoolByPointer) MarshalPlist() (interface{}, error) { + if b.b { + return int64(-1), nil + } + return int64(-2), nil +} + +func (b *PlistMarshalingBoolByPointer) UnmarshalPlist(unmarshal func(interface{}) error) error { + var val int64 + err := unmarshal(&val) + if err != nil { + return err + } + + b.b = val == -1 + return nil +} + +type BothMarshaler struct{} + +func (b *BothMarshaler) MarshalPlist() (interface{}, error) { + return map[string]string{"a": "b"}, nil +} + +func (b *BothMarshaler) MarshalText() ([]byte, error) { + return []byte("shouldn't see this"), nil +} + +type BothUnmarshaler struct { + Blah int64 `plist:"blah,omitempty"` +} + +func (b *BothUnmarshaler) UnmarshalPlist(unmarshal func(interface{}) error) error { + // no error + return nil +} + +func (b *BothUnmarshaler) UnmarshalText(text []byte) error { + return errors.New("shouldn't hit this") +} + var xmlPreamble string = ` ` @@ -588,6 +655,39 @@ var tests = []TestData{ U: 1024, }, }, + { + Name: "Custom Marshaller/Unmarshaller by Value", + Data: []ArrayThatSerializesAsOneObject{ + ArrayThatSerializesAsOneObject{[]uint64{100}}, + ArrayThatSerializesAsOneObject{[]uint64{2, 4, 6, 8}}, + }, + Expected: map[int][]byte{ + GNUStepFormat: []byte(`(<*I100>,(<*I2>,<*I4>,<*I6>,<*I8>,),)`), + }, + }, + { + Name: "Custom Marshaller/Unmarshaller by Pointer", + Data: &PlistMarshalingBoolByPointer{true}, + Expected: map[int][]byte{ + OpenStepFormat: []byte(`-1`), + GNUStepFormat: []byte(`<*I-1>`), + }, + }, + { + Name: "Type implementing both Text and Plist Marshaler", + Data: &BothMarshaler{}, + Expected: map[int][]byte{ + GNUStepFormat: []byte(`{a=b;}`), + }, + }, + { + Name: "Type implementing both Text and Plist Unmarshaler", + Data: &BothUnmarshaler{int64(1024)}, + Expected: map[int][]byte{ + GNUStepFormat: []byte(`{blah=<*I1024>;}`), + }, + DecodeData: &BothUnmarshaler{int64(0)}, + }, } type EverythingTestData struct { diff --git a/example_custom_marshaler_test.go b/example_custom_marshaler_test.go new file mode 100644 index 0000000..d8e8f69 --- /dev/null +++ b/example_custom_marshaler_test.go @@ -0,0 +1,52 @@ +package plist_test + +import ( + "encoding/base64" + "fmt" + + "howett.net/plist" +) + +type Base64String string + +func (e Base64String) MarshalPlist() (interface{}, error) { + return base64.StdEncoding.EncodeToString([]byte(e)), nil +} + +func (e *Base64String) UnmarshalPlist(unmarshal func(interface{}) error) error { + var b64 string + if err := unmarshal(&b64); err != nil { + return err + } + + bytes, err := base64.StdEncoding.DecodeString(b64) + if err != nil { + return err + } + + *e = Base64String(bytes) + return nil +} + +func Example() { + s := Base64String("Dustin") + + data, err := plist.Marshal(&s, plist.OpenStepFormat) + if err != nil { + panic(err) + } + + fmt.Println("Property List:", string(data)) + + var decoded Base64String + _, err = plist.Unmarshal(data, &decoded) + if err != nil { + panic(err) + } + + fmt.Println("Raw Data:", string(decoded)) + + // Output: + // Property List: RHVzdGlu + // Raw Data: Dustin +} diff --git a/marshal.go b/marshal.go index e32dea6..b861172 100644 --- a/marshal.go +++ b/marshal.go @@ -25,10 +25,33 @@ func isEmptyValue(v reflect.Value) bool { } var ( - textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() - timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + plistMarshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() + textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + timeType = reflect.TypeOf((*time.Time)(nil)).Elem() ) +func implementsInterface(val reflect.Value, interfaceType reflect.Type) (interface{}, bool) { + if val.CanInterface() && val.Type().Implements(interfaceType) { + return val.Interface(), true + } + + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(interfaceType) { + return pv.Interface(), true + } + } + return nil, false +} + +func (p *Encoder) marshalPlistInterface(marshalable Marshaler) cfValue { + value, err := marshalable.MarshalPlist() + if err != nil { + panic(err) + } + return p.marshal(reflect.ValueOf(value)) +} + // marshalTextInterface marshals a TextMarshaler to a plist string. func (p *Encoder) marshalTextInterface(marshalable encoding.TextMarshaler) cfValue { s, err := marshalable.MarshalText() @@ -68,6 +91,10 @@ func (p *Encoder) marshal(val reflect.Value) cfValue { return nil } + if receiver, can := implementsInterface(val, plistMarshalerType); can { + return p.marshalPlistInterface(receiver.(Marshaler)) + } + // time.Time implements TextMarshaler, but we need to store it in RFC3339 if val.Type() == timeType { return p.marshalTime(val) @@ -80,14 +107,8 @@ func (p *Encoder) marshal(val reflect.Value) cfValue { } // Check for text marshaler. - if val.CanInterface() && val.Type().Implements(textMarshalerType) { - return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler)) - } - if val.CanAddr() { - pv := val.Addr() - if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { - return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler)) - } + if receiver, can := implementsInterface(val, textMarshalerType); can { + return p.marshalTextInterface(receiver.(encoding.TextMarshaler)) } // Descend into pointers or interfaces diff --git a/plist.go b/plist.go index e1484de..a4078b2 100644 --- a/plist.go +++ b/plist.go @@ -65,3 +65,21 @@ func (e plistParseError) Error() string { // // UIDs cannot be serialized in OpenStepFormat or GNUStepFormat property lists. type UID uint64 + +// Marshaler is the interface implemented by types that can marshal themselves into valid +// property list objects. The returned value is marshaled in place of the original value +// implementing Marshaler +// +// If an error is returned by MarshalPlist, marshaling stops and the error is returned. +type Marshaler interface { + MarshalPlist() (interface{}, error) +} + +// Unmarshaler is the interface implemented by types that can unmarshal themselves from +// property list objects. The UnmarshalPlist method receives a function that may +// be called to unmarshal the original property list value into a field or variable. +// +// It is safe to call the unmarshal function more than once. +type Unmarshaler interface { + UnmarshalPlist(unmarshal func(interface{}) error) error +} diff --git a/unmarshal.go b/unmarshal.go index 3dcda34..c3cb90f 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -4,6 +4,7 @@ import ( "encoding" "fmt" "reflect" + "runtime" "time" ) @@ -17,14 +18,34 @@ func (u *incompatibleDecodeTypeError) Error() string { } var ( - textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() - uidType = reflect.TypeOf(UID(0)) + plistUnmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem() + textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + uidType = reflect.TypeOf(UID(0)) ) func isEmptyInterface(v reflect.Value) bool { return v.Kind() == reflect.Interface && v.NumMethod() == 0 } +func (p *Decoder) unmarshalPlistInterface(pval cfValue, unmarshalable Unmarshaler) { + err := unmarshalable.UnmarshalPlist(func(i interface{}) (err error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + err = r.(error) + } + }() + p.unmarshal(pval, reflect.ValueOf(i)) + return + }) + + if err != nil { + panic(err) + } +} + func (p *Decoder) unmarshalTextInterface(pval cfString, unmarshalable encoding.TextUnmarshaler) { err := unmarshalable.UnmarshalText([]byte(pval)) if err != nil { @@ -98,20 +119,15 @@ func (p *Decoder) unmarshal(pval cfValue, val reflect.Value) { panic(incompatibleTypeError) } - if val.CanInterface() && val.Type().Implements(textUnmarshalerType) && val.Type() != timeType { - if str, ok := pval.(cfString); ok { - p.unmarshalTextInterface(str, val.Interface().(encoding.TextUnmarshaler)) - } else { - panic(incompatibleTypeError) - } + if receiver, can := implementsInterface(val, plistUnmarshalerType); can { + p.unmarshalPlistInterface(pval, receiver.(Unmarshaler)) return } - if val.CanAddr() { - pv := val.Addr() - if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) && val.Type() != timeType { + if val.Type() != timeType { + if receiver, can := implementsInterface(val, textUnmarshalerType); can { if str, ok := pval.(cfString); ok { - p.unmarshalTextInterface(str, pv.Interface().(encoding.TextUnmarshaler)) + p.unmarshalTextInterface(str, receiver.(encoding.TextUnmarshaler)) } else { panic(incompatibleTypeError) }