Skip to content

Commit

Permalink
Add strict mode.
Browse files Browse the repository at this point in the history
  • Loading branch information
Cyprinus12138 committed Apr 7, 2024
1 parent 44e6468 commit 55b100d
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions viper.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ func (v *Viper) WatchConfig() {
(event.Has(fsnotify.Write) || event.Has(fsnotify.Create))) ||
(currentConfigFile != "" && currentConfigFile != realConfigFile) {
realConfigFile = currentConfigFile
err := v.ReadInConfig()
err := v.ReadInConfig(false)
if err != nil {
log.Printf("error reading config file: %v\n", err)
}
Expand Down Expand Up @@ -512,8 +512,10 @@ func (v *Viper) WatchConfig() {
initWG.Wait() // make sure that the go routine above fully ended before returning
}

// updateRegisteredConfig validate the registered config items in the new config, notify user with the hook functions.
func (v *Viper) updateRegisteredConfig(newConfig map[string]interface{}) (result map[string]interface{}) {
// updateRegisteredConfig validate the registered config items in the new config,
// notify user with the hook functions.
// Only when strict is true, err can be returned, only the last error will be returned.
func (v *Viper) updateRegisteredConfig(newConfig map[string]interface{}, strict bool) (result map[string]interface{}, err error) {
result = make(map[string]interface{})

for key, config := range v.registered {
Expand All @@ -522,6 +524,7 @@ func (v *Viper) updateRegisteredConfig(newConfig map[string]interface{}) (result
// Check exist
if newValue == nil && !config.CanBeNil {
newConfig[key] = oldValue
err = errors.New(fmt.Sprintf("%s is nil or not found", key))
if config.OnUpdateFailed != nil {
config.OnUpdateFailed(&Event{
name: v.name,
Expand All @@ -540,6 +543,7 @@ func (v *Viper) updateRegisteredConfig(newConfig map[string]interface{}) (result
if err != nil {
newConfig[key] = oldValue
if config.OnUpdateFailed != nil {
err = errors.New(fmt.Sprintf("%s is with a invalid type", key))
config.OnUpdateFailed(&Event{
name: v.name,
old: oldValue,
Expand All @@ -557,6 +561,7 @@ func (v *Viper) updateRegisteredConfig(newConfig map[string]interface{}) (result
if config.Validator != nil && !config.Validator(config.Schema) {
newConfig[key] = oldValue
if config.OnUpdateFailed != nil {
err = errors.New(fmt.Sprintf("%s validation failed", key))
config.OnUpdateFailed(&Event{
name: v.name,
old: oldValue,
Expand All @@ -577,7 +582,11 @@ func (v *Viper) updateRegisteredConfig(newConfig map[string]interface{}) (result
})
}
}
return result

if strict {
return result, err
}
return result, nil
}

// SetConfigFile explicitly defines the path, name and extension of the config file.
Expand Down Expand Up @@ -1633,9 +1642,13 @@ func (v *Viper) Set(key string, value interface{}) {

// ReadInConfig will discover and load the configuration file from disk
// and key/value stores, searching in one of the defined paths.
func ReadInConfig() error { return v.ReadInConfig() }
func ReadInConfig() error { return v.ReadInConfig(false) }

// ReadInConfigStrict register the configs with strict manner: any validation
// error will result in failure.
func ReadInConfigStrict() error { return v.ReadInConfig(true) }

func (v *Viper) ReadInConfig() error {
func (v *Viper) ReadInConfig(strict bool) error {
v.logger.Info("attempting to read in config file")
filename, err := v.getConfigFile()
if err != nil {
Expand All @@ -1657,7 +1670,10 @@ func (v *Viper) ReadInConfig() error {
if err != nil {
return err
}
config = v.updateRegisteredConfig(config)
config, err = v.updateRegisteredConfig(config, strict)
if strict && err != nil {
return err
}

v.config = config
return nil
Expand Down Expand Up @@ -2014,7 +2030,7 @@ func (v *Viper) getRemoteConfig(provider RemoteProvider) (map[string]interface{}
}
kvStore := make(map[string]interface{})
err = v.unmarshalReader(reader, kvStore)
v.kvstore = v.updateRegisteredConfig(kvStore)
v.kvstore, _ = v.updateRegisteredConfig(kvStore, false)
return v.kvstore, err
}

Expand All @@ -2033,7 +2049,7 @@ func (v *Viper) watchKeyValueConfigOnChannel() error {
reader := bytes.NewReader(b.Value)
kvStore := make(map[string]interface{})
v.unmarshalReader(reader, kvStore)
v.kvstore = v.updateRegisteredConfig(kvStore)
v.kvstore, _ = v.updateRegisteredConfig(kvStore, false)
}
}(respc)
return nil
Expand Down

0 comments on commit 55b100d

Please sign in to comment.