diff --git a/overrides_test.go b/overrides_test.go index 8048204cf7..865475d3ad 100644 --- a/overrides_test.go +++ b/overrides_test.go @@ -46,10 +46,10 @@ func TestNestedOverrides(t *testing.T) { deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4) // Case 4: key:value overridden by a map - v = overrideDefault(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10} - assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable - assert.Equal(10, v.Get("tom.age")) // new value should be there - deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there + v = overrideDefault(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10, "size": 4}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10} + assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable + assert.Equal(10, v.Get("tom.age")) // new value should be there + deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there v = override(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10}) assert.Nil(v.Get("tom.size")) assert.Equal(10, v.Get("tom.age")) diff --git a/viper.go b/viper.go index 06610fc5a7..687713ddcb 100644 --- a/viper.go +++ b/viper.go @@ -887,6 +887,13 @@ func GetViper() *Viper { // Get returns an interface. For a specific value use one of the Get____ methods. func Get(key string) interface{} { return v.Get(key) } +func isStringMapInterface(val interface{}) bool { + vt := reflect.TypeOf(val) + return vt.Kind() == reflect.Map && + vt.Key().Kind() == reflect.String && + vt.Elem().Kind() == reflect.Interface +} + func (v *Viper) Get(key string) interface{} { lcaseKey := strings.ToLower(key) val := v.find(lcaseKey, true) @@ -894,6 +901,26 @@ func (v *Viper) Get(key string) interface{} { return nil } + // when section is partially overrided, + // make sure to return the complete map. + if isStringMapInterface(val) { + val := val.(map[string]interface{}) + prefix := lcaseKey + v.keyDelim + keys := v.AllKeys() + for _, key := range keys { + if !strings.HasPrefix(key, prefix) { + continue + } + mk := strings.TrimPrefix(key, prefix) + mk = strings.Split(mk, v.keyDelim)[0] + mv := v.Get(lcaseKey + v.keyDelim + mk) + if mv == nil { + continue + } + val[mk] = mv + } + } + if v.typeByDefValue { // TODO(bep) this branch isn't covered by a single test. valType := val