Skip to content

Commit

Permalink
Allow defaults library to use UnmarshalText() and UnmarshalJSON() int…
Browse files Browse the repository at this point in the history
…erface to initialize values
  • Loading branch information
HRogge committed Feb 17, 2023
1 parent bf90154 commit d1d4e26
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 3 deletions.
23 changes: 23 additions & 0 deletions defaults.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package defaults

import (
"encoding"
"encoding/json"
"errors"
"reflect"
Expand Down Expand Up @@ -61,6 +62,10 @@ func setField(field reflect.Value, defaultVal string) error {

isInitial := isInitialValue(field)
if isInitial {
if unmarshalByInterface(field, defaultVal) {
return nil
}

switch field.Kind() {
case reflect.Bool:
if val, err := strconv.ParseBool(defaultVal); err == nil {
Expand Down Expand Up @@ -194,6 +199,24 @@ func setField(field reflect.Value, defaultVal string) error {
return nil
}

func unmarshalByInterface(field reflect.Value, defaultVal string) bool {
asText, ok := field.Addr().Interface().(encoding.TextUnmarshaler)
if ok && defaultVal != "" {
// if field implements encode.TextUnmarshaler, try to use it before decode by kind
if err := asText.UnmarshalText([]byte(defaultVal)); err == nil {
return true
}
}
asJSON, ok := field.Addr().Interface().(json.Unmarshaler)
if ok && defaultVal != "" && defaultVal != "{}" && defaultVal != "[]" {
// if field implements json.Unmarshaler, try to use it before decode by kind
if err := asJSON.UnmarshalJSON([]byte(defaultVal)); err == nil {
return true
}
}
return false
}

func isInitialValue(field reflect.Value) bool {
return reflect.DeepEqual(reflect.Zero(field.Type()).Interface(), field.Interface())
}
Expand Down
43 changes: 40 additions & 3 deletions defaults_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package defaults

import (
"encoding/json"
"errors"
"net"
"reflect"
"strconv"
"testing"
"time"

Expand Down Expand Up @@ -112,9 +116,12 @@ type Sample struct {
MyMap MyMap `default:"{}"`
MySlice MySlice `default:"[]"`

StructWithJSON Struct `default:"{\"Foo\": 123}"`
StructPtrWithJSON *Struct `default:"{\"Foo\": 123}"`
MapWithJSON map[string]int `default:"{\"foo\": 123}"`
StructWithText net.IP `default:"10.0.0.1"`
StructPtrWithText *net.IP `default:"10.0.0.1"`
StructWithJSON Struct `default:"{\"Foo\": 123}"`
StructPtrWithJSON *Struct `default:"{\"Foo\": 123}"`
MapWithJSON map[string]int `default:"{\"foo\": 123}"`
TypeWithUnmarshalJSON JSONOnlyType `default:"\"one\""`

MapOfPtrStruct map[string]*Struct
MapOfStruct map[string]Struct
Expand Down Expand Up @@ -155,6 +162,24 @@ type Embedded struct {
Int int `default:"1"`
}

type JSONOnlyType int

func (j *JSONOnlyType) UnmarshalJSON(b []byte) error {
var tmp string
if err := json.Unmarshal(b, &tmp); err != nil {
return err
}
if i, err := strconv.Atoi(tmp); err == nil {
*j = JSONOnlyType(i)
return nil
}
if tmp == "one" {
*j = 1
return nil
}
return errors.New("cannot unmarshal")
}

func TestMustSet(t *testing.T) {

t.Run("right way", func(t *testing.T) {
Expand Down Expand Up @@ -485,6 +510,14 @@ func TestInit(t *testing.T) {
}
})

t.Run("complex types with text unmarshal", func(t *testing.T) {
if !sample.StructWithText.Equal(net.ParseIP("10.0.0.1")) {
t.Errorf("it should initialize struct with text")
}
if !sample.StructPtrWithText.Equal(net.ParseIP("10.0.0.1")) {
t.Errorf("it should initialize struct with text")
}
})
t.Run("complex types with json", func(t *testing.T) {
if sample.StructWithJSON.Foo != 123 {
t.Errorf("it should initialize struct with json")
Expand All @@ -499,6 +532,10 @@ func TestInit(t *testing.T) {
t.Errorf("it should initialize slice with json")
}

if int(sample.TypeWithUnmarshalJSON) != 1 {
t.Errorf("it should initialize json unmarshaled value")
}

t.Run("invalid json", func(t *testing.T) {
if err := Set(&struct {
I []int `default:"[!]"`
Expand Down

0 comments on commit d1d4e26

Please sign in to comment.