Skip to content

Commit

Permalink
chore: structure support encoding.TextUnmarshaler
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Sep 19, 2024
1 parent 3676d1b commit f020b20
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 27 deletions.
83 changes: 56 additions & 27 deletions common/structure/structure.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package structure
// references: https://github.com/mitchellh/mapstructure

import (
"encoding"
"encoding/base64"
"fmt"
"reflect"
Expand Down Expand Up @@ -86,35 +87,41 @@ func (d *Decoder) Decode(src map[string]any, dst any) error {
}

func (d *Decoder) decode(name string, data any, val reflect.Value) error {
kind := val.Kind()
switch {
case isInt(kind):
return d.decodeInt(name, data, val)
case isUint(kind):
return d.decodeUint(name, data, val)
case isFloat(kind):
return d.decodeFloat(name, data, val)
}
switch kind {
case reflect.Pointer:
if val.IsNil() {
for {
kind := val.Kind()
if kind == reflect.Pointer && val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
return d.decode(name, data, val.Elem())
case reflect.String:
return d.decodeString(name, data, val)
case reflect.Bool:
return d.decodeBool(name, data, val)
case reflect.Slice:
return d.decodeSlice(name, data, val)
case reflect.Map:
return d.decodeMap(name, data, val)
case reflect.Interface:
return d.setInterface(name, data, val)
case reflect.Struct:
return d.decodeStruct(name, data, val)
default:
return fmt.Errorf("type %s not support", val.Kind().String())
if ok, err := d.decodeTextUnmarshaller(name, data, val); ok {
return err
}
switch {
case isInt(kind):
return d.decodeInt(name, data, val)
case isUint(kind):
return d.decodeUint(name, data, val)
case isFloat(kind):
return d.decodeFloat(name, data, val)
}
switch kind {
case reflect.Pointer:
val = val.Elem()
continue
case reflect.String:
return d.decodeString(name, data, val)
case reflect.Bool:
return d.decodeBool(name, data, val)
case reflect.Slice:
return d.decodeSlice(name, data, val)
case reflect.Map:
return d.decodeMap(name, data, val)
case reflect.Interface:
return d.setInterface(name, data, val)
case reflect.Struct:
return d.decodeStruct(name, data, val)
default:
return fmt.Errorf("type %s not support", val.Kind().String())
}
}
}

Expand Down Expand Up @@ -553,3 +560,25 @@ func (d *Decoder) setInterface(name string, data any, val reflect.Value) (err er
val.Set(dataVal)
return nil
}

func (d *Decoder) decodeTextUnmarshaller(name string, data any, val reflect.Value) (bool, error) {
if !val.CanAddr() {
return false, nil
}
valAddr := val.Addr()
if !valAddr.CanInterface() {
return false, nil
}
unmarshaller, ok := valAddr.Interface().(encoding.TextUnmarshaler)
if !ok {
return false, nil
}
var str string
if err := d.decodeString(name, data, reflect.Indirect(reflect.ValueOf(&str))); err != nil {
return false, err
}
if err := unmarshaller.UnmarshalText([]byte(str)); err != nil {
return true, fmt.Errorf("cannot parse '%s' as %s: %s", name, val.Type(), err)
}
return true, nil
}
88 changes: 88 additions & 0 deletions common/structure/structure_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package structure

import (
"strconv"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -179,3 +180,90 @@ func TestStructure_SliceNilValueComplex(t *testing.T) {
err = decoder.Decode(rawMap, ss)
assert.NotNil(t, err)
}

func TestStructure_SliceCap(t *testing.T) {
rawMap := map[string]any{
"foo": []string{},
}

s := &struct {
Foo []string `test:"foo,omitempty"`
Bar []string `test:"bar,omitempty"`
}{}

err := decoder.Decode(rawMap, s)
assert.Nil(t, err)
assert.NotNil(t, s.Foo) // structure's Decode will ensure value not nil when input has value even it was set an empty array
assert.Nil(t, s.Bar)
}

func TestStructure_Base64(t *testing.T) {
rawMap := map[string]any{
"foo": "AQID",
}

s := &struct {
Foo []byte `test:"foo"`
}{}

err := decoder.Decode(rawMap, s)
assert.Nil(t, err)
assert.Equal(t, []byte{1, 2, 3}, s.Foo)
}

func TestStructure_Pointer(t *testing.T) {
rawMap := map[string]any{
"foo": "foo",
}

s := &struct {
Foo *string `test:"foo,omitempty"`
Bar *string `test:"bar,omitempty"`
}{}

err := decoder.Decode(rawMap, s)
assert.Nil(t, err)
assert.NotNil(t, s.Foo)
assert.Equal(t, "foo", *s.Foo)
assert.Nil(t, s.Bar)
}

type num struct {
a int
}

func (n *num) UnmarshalText(text []byte) (err error) {
n.a, err = strconv.Atoi(string(text))
return
}

func TestStructure_TextUnmarshaller(t *testing.T) {
rawMap := map[string]any{
"num": "255",
"num_p": "127",
}

s := &struct {
Num num `test:"num"`
NumP *num `test:"num_p"`
}{}

err := decoder.Decode(rawMap, s)
assert.Nil(t, err)
assert.Equal(t, 255, s.Num.a)
assert.NotNil(t, s.NumP)
assert.Equal(t, s.NumP.a, 127)

// test WeaklyTypedInput
rawMap["num"] = 256
err = decoder.Decode(rawMap, s)
assert.NotNilf(t, err, "should throw error: %#v", s)
err = weakTypeDecoder.Decode(rawMap, s)
assert.Nil(t, err)
assert.Equal(t, 256, s.Num.a)

// test invalid input
rawMap["num_p"] = "abc"
err = decoder.Decode(rawMap, s)
assert.NotNilf(t, err, "should throw error: %#v", s)
}

0 comments on commit f020b20

Please sign in to comment.