diff --git a/defaults.go b/defaults.go index 9a1acfc..1882ac9 100644 --- a/defaults.go +++ b/defaults.go @@ -33,20 +33,22 @@ func Set(ptr interface{}) error { for i := 0; i < t.NumField(); i++ { if defaultVal := t.Field(i).Tag.Get(fieldName); defaultVal != "-" { - setField(v.Field(i), defaultVal) + if err := setField(v.Field(i), defaultVal); err != nil { + return err + } } } return nil } -func setField(field reflect.Value, defaultVal string) { +func setField(field reflect.Value, defaultVal string) error { if !field.CanSet() { - return + return nil } if !shouldInitializeField(field.Kind(), defaultVal) { - return + return nil } if isInitialValue(field) { @@ -116,20 +118,26 @@ func setField(field reflect.Value, defaultVal string) { ref := reflect.New(field.Type()) ref.Elem().Set(reflect.MakeSlice(field.Type(), 0, 0)) if defaultVal != "" && defaultVal != "[]" { - json.Unmarshal([]byte(defaultVal), ref.Interface()) + if err := json.Unmarshal([]byte(defaultVal), ref.Interface()); err != nil { + return err + } } field.Set(ref.Elem().Convert(field.Type())) case reflect.Map: ref := reflect.New(field.Type()) ref.Elem().Set(reflect.MakeMap(field.Type())) if defaultVal != "" && defaultVal != "{}" { - json.Unmarshal([]byte(defaultVal), ref.Interface()) + if err := json.Unmarshal([]byte(defaultVal), ref.Interface()); err != nil { + return err + } } field.Set(ref.Elem().Convert(field.Type())) case reflect.Struct: ref := reflect.New(field.Type()) if defaultVal != "" && defaultVal != "{}" { - json.Unmarshal([]byte(defaultVal), ref.Interface()) + if err := json.Unmarshal([]byte(defaultVal), ref.Interface()); err != nil { + return err + } } field.Set(ref.Elem()) case reflect.Ptr: @@ -141,13 +149,17 @@ func setField(field reflect.Value, defaultVal string) { case reflect.Ptr: setField(field.Elem(), defaultVal) callSetter(field.Interface()) - default: + case reflect.Struct: ref := reflect.New(field.Type()) ref.Elem().Set(field) - Set(ref.Interface()) + if err := Set(ref.Interface()); err != nil { + return err + } callSetter(ref.Interface()) field.Set(ref.Elem()) } + + return nil } func isInitialValue(field reflect.Value) bool { diff --git a/defaults_test.go b/defaults_test.go index 2d238fc..8c7b897 100644 --- a/defaults_test.go +++ b/defaults_test.go @@ -117,10 +117,15 @@ func TestInit(t *testing.T) { } if err := Set(sample); err != nil { - t.Fatalf("it should return an error: %v", err) + t.Fatalf("it should not return an error: %v", err) } - if err := Set(1); err == nil { + nonPtrVal := 1 + + if err := Set(nonPtrVal); err == nil { + t.Fatalf("it should return an error when used for a non-pointer type") + } + if err := Set(&nonPtrVal); err == nil { t.Fatalf("it should return an error when used for a non-pointer type") } @@ -272,6 +277,36 @@ func TestInit(t *testing.T) { if len(sample.SliceWithJSON) == 0 || sample.SliceWithJSON[0] != "foo" { t.Errorf("it should initialize slice with json") } + + t.Run("invalid json", func(t *testing.T) { + if err := Set(&struct { + I []int `default:"[!]"` + }{}); err == nil { + t.Errorf("it should return error") + } + + if err := Set(&struct { + I map[string]int `default:"{1}"` + }{}); err == nil { + t.Errorf("it should return error") + } + + if err := Set(&struct { + S struct { + I []int + } `default:"{!}"` + }{}); err == nil { + t.Errorf("it should return error") + } + + if err := Set(&struct { + S struct { + I []int `default:"[!]"` + } + }{}); err == nil { + t.Errorf("it should return error") + } + }) }) t.Run("Setter interface", func(t *testing.T) {