diff --git a/config/configmap.go b/config/configmap.go index 94400d23684..61b2c363bdd 100644 --- a/config/configmap.go +++ b/config/configmap.go @@ -62,23 +62,13 @@ func (l *Map) AllKeys() []string { // Unmarshal unmarshalls the config into a struct. // Tags on the fields of the structure must be properly set. -func (l *Map) Unmarshal(rawVal interface{}) error { - decoder, err := mapstructure.NewDecoder(decoderConfig(rawVal)) - if err != nil { - return err - } - return decoder.Decode(l.ToStringMap()) +func (l *Map) Unmarshal(result interface{}) error { + return decoderConfig(l, result, false) } // UnmarshalExact unmarshalls the config into a struct, erroring if a field is nonexistent. -func (l *Map) UnmarshalExact(rawVal interface{}) error { - dc := decoderConfig(rawVal) - dc.ErrorUnused = true - decoder, err := mapstructure.NewDecoder(dc) - if err != nil { - return err - } - return decoder.Decode(l.ToStringMap()) +func (l *Map) UnmarshalExact(result interface{}) error { + return decoderConfig(l, result, true) } // Get can retrieve any value given the key to use. @@ -133,10 +123,10 @@ func (l *Map) ToStringMap() map[string]interface{} { // whose values are nil pointer structs resolved to the zero value of the target struct (see // expandNilStructPointers). A decoder created from this mapstructure.DecoderConfig will decode // its contents to the result argument. -func decoderConfig(result interface{}) *mapstructure.DecoderConfig { - return &mapstructure.DecoderConfig{ +func decoderConfig(m *Map, result interface{}, errorUnused bool) error { + dc := &mapstructure.DecoderConfig{ + ErrorUnused: errorUnused, Result: result, - Metadata: nil, TagName: "mapstructure", WeaklyTypedInput: true, DecodeHook: mapstructure.ComposeDecodeHookFunc( @@ -145,8 +135,15 @@ func decoderConfig(result interface{}) *mapstructure.DecoderConfig { mapStringToMapComponentIDHookFunc, stringToTimeDurationHookFunc, textUnmarshallerHookFunc, + unmarshallableHookFunc(result), ), } + + decoder, err := mapstructure.NewDecoder(dc) + if err != nil { + return err + } + return decoder.Decode(m.ToStringMap()) } var ( @@ -219,3 +216,27 @@ var mapStringToMapComponentIDHookFunc = func(f reflect.Type, t reflect.Type, dat } return m, nil } + +func unmarshallableHookFunc(result interface{}) mapstructure.DecodeHookFuncValue { + return func(from reflect.Value, to reflect.Value) (interface{}, error) { + if _, ok := from.Interface().(map[string]interface{}); !ok { + return from.Interface(), nil + } + + toPtr := to.Addr().Interface() + if _, ok := toPtr.(Unmarshallable); !ok { + return from.Interface(), nil + } + + // Need to ignore the top structure to avoid circular dependency. + if toPtr == result { + return from.Interface(), nil + } + + unmarshaller := reflect.New(to.Type()).Interface().(Unmarshallable) + if err := unmarshaller.Unmarshal(NewMapFromStringMap(from.Interface().(map[string]interface{}))); err != nil { + return nil, err + } + return unmarshaller, nil + } +} diff --git a/config/configmap_test.go b/config/configmap_test.go index 1d34f3df45f..ff3fe63e40a 100644 --- a/config/configmap_test.go +++ b/config/configmap_test.go @@ -109,6 +109,44 @@ func TestToStringMap(t *testing.T) { } } +type testConfig struct { + Next nextConfig `mapstructure:"next"` + Another string `mapstructure:"another"` +} + +func (tc *testConfig) Unmarshal(component *Map) error { + if err := component.UnmarshalExact(tc); err != nil { + return err + } + tc.Another = tc.Another + " is not called" + return nil +} + +type nextConfig struct { + String string `mapstructure:"string"` +} + +func (nc *nextConfig) Unmarshal(component *Map) error { + if err := component.UnmarshalExact(nc); err != nil { + return err + } + nc.String = nc.String + " is called" + return nil +} + +func TestUnmarshallable(t *testing.T) { + cfgMap := NewMapFromStringMap(map[string]interface{}{ + "next": map[string]interface{}{ + "string": "make sure this", + }, + "another": "make sure this", + }) + tc := &testConfig{} + assert.NoError(t, cfgMap.UnmarshalExact(tc)) + assert.Equal(t, "make sure this", tc.Another) + assert.Equal(t, "make sure this is called", tc.Next.String) +} + // newMapFromFile creates a new config.Map by reading the given file. func newMapFromFile(fileName string) (*Map, error) { content, err := ioutil.ReadFile(filepath.Clean(fileName))