Skip to content

Commit

Permalink
InitialzePFlags doesn't add flags for sub-sections (#17)
Browse files Browse the repository at this point in the history
* Allow InitializePFlags to recursively discover sub-section flags

* Add unit test
  • Loading branch information
EngHabu authored May 8, 2019
1 parent a2fb1f1 commit 7fc3cee
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
30 changes: 30 additions & 0 deletions config/tests/accessor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,36 @@ func TestAccessor_InitializePflags(t *testing.T) {
assert.Equal(t, 4, otherC.IntValue)
assert.Equal(t, []string{"default value"}, otherC.StringArrayWithDefaults)
})

t.Run(fmt.Sprintf("[%v] Sub-sections", provider(config.Options{}).ID()), func(t *testing.T) {
reg := config.NewRootSection()
sec, err := reg.RegisterSection(MyComponentSectionKey, &MyComponentConfig{})
assert.NoError(t, err)

_, err = sec.RegisterSection("nested", &OtherComponentConfig{})
assert.NoError(t, err)

v := provider(config.Options{
SearchPaths: []string{filepath.Join("testdata", "nested_config.yaml")},
RootSection: reg,
})

set := pflag.NewFlagSet("test", pflag.ExitOnError)
v.InitializePflags(set)
assert.NoError(t, set.Parse([]string{"--my-component.nested.int-val=3"}))
assert.True(t, set.Parsed())

flagValue, err := set.GetInt("my-component.nested.int-val")
assert.NoError(t, err)
assert.Equal(t, 3, flagValue)

assert.NoError(t, v.UpdateConfig(context.TODO()))
r := reg.GetSection(MyComponentSectionKey).GetConfig().(*MyComponentConfig)
assert.Equal(t, "Hello World", r.StringValue)

nested := sec.GetSection("nested").GetConfig().(*OtherComponentConfig)
assert.Equal(t, 3, nested.IntValue)
})
}
}

Expand Down
13 changes: 11 additions & 2 deletions config/viper/viper.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,18 @@ func (v viperAccessor) InitializePflags(cmdFlags *pflag.FlagSet) {
}

func (v viperAccessor) addSectionsPFlags(flags *pflag.FlagSet) (err error) {
for key, section := range v.rootConfig.GetSections() {
return v.addSubsectionsPFlags(flags, "", v.rootConfig)
}

func (v viperAccessor) addSubsectionsPFlags(flags *pflag.FlagSet, rootKey string, root config.Section) error {
for key, section := range root.GetSections() {
prefix := rootKey + key + keyDelim
if asPFlagProvider, ok := section.GetConfig().(config.PFlagProvider); ok {
flags.AddFlagSet(asPFlagProvider.GetPFlagSet(key + keyDelim))
flags.AddFlagSet(asPFlagProvider.GetPFlagSet(prefix))
}

if err := v.addSubsectionsPFlags(flags, prefix, section); err != nil {
return err
}
}

Expand Down

0 comments on commit 7fc3cee

Please sign in to comment.