From dfc9c9b38554b28bf172c11b4545b8f1cb0d5f31 Mon Sep 17 00:00:00 2001 From: Bogdan Drutu Date: Mon, 21 Nov 2022 18:00:14 -0800 Subject: [PATCH] Add recursive validation check for configs Signed-off-by: Bogdan Drutu --- component/config.go | 60 +++++++++++++++- component/config_test.go | 143 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 202 insertions(+), 1 deletion(-) create mode 100644 component/config_test.go diff --git a/component/config.go b/component/config.go index 619e9b45857..d7b44c58104 100644 --- a/component/config.go +++ b/component/config.go @@ -15,6 +15,10 @@ package component // import "go.opentelemetry.io/collector/component" import ( + "reflect" + + "go.uber.org/multierr" + "go.opentelemetry.io/collector/confmap" ) @@ -33,7 +37,61 @@ type ConfigValidator interface { // ValidateConfig validates a config, by doing this: // - Call Validate on the config itself if the config implements ConfigValidator. func ValidateConfig(cfg Config) error { - validator, ok := cfg.(ConfigValidator) + return validate(reflect.ValueOf(cfg)) +} + +func validate(v reflect.Value) error { + // Validate the value itself. + switch v.Kind() { + case reflect.Ptr: + // Consider valid any nil value. + if v.IsNil() { + return nil + } + return validate(v.Elem()) + case reflect.Struct: + var errs error + // If not addressable, then create a new *V pointer and set the value to current v. + if !v.CanAddr() { + pv := reflect.New(reflect.PtrTo(v.Type()).Elem()) + pv.Elem().Set(v) + v = pv.Elem() + } + errs = multierr.Append(errs, callValidate(v.Addr())) + // Reflect on the pointed data and check each of its fields. + for i := 0; i < v.NumField(); i++ { + if !v.Type().Field(i).IsExported() { + continue + } + errs = multierr.Append(errs, validate(v.Field(i))) + } + return errs + case reflect.Slice, reflect.Array: + var errs error + // Reflect on the pointed data and check each of its fields. + for i := 0; i < v.Len(); i++ { + errs = multierr.Append(errs, validate(v.Index(i))) + } + return errs + case reflect.Map: + var errs error + iter := v.MapRange() + for iter.Next() { + errs = multierr.Append(errs, validate(iter.Key())) + errs = multierr.Append(errs, validate(iter.Value())) + } + return errs + } + return nil +} + +func callValidate(v reflect.Value) error { + if !v.CanInterface() { + // Cannot retrieve the "Interface" just return, otherwise Interface() will panic + return nil + } + // If implements ConfigValidator then call Validate. + validator, ok := v.Interface().(ConfigValidator) if !ok { return nil } diff --git a/component/config_test.go b/component/config_test.go new file mode 100644 index 00000000000..3531be539d1 --- /dev/null +++ b/component/config_test.go @@ -0,0 +1,143 @@ +package component + +import ( + "errors" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +type configChildStruct struct { + Child errConfig + ChildPtr *errConfig +} + +type configChildSlice struct { + Child []errConfig + ChildPtr []*errConfig +} + +type configChildMapValue struct { + Child map[string]errConfig + ChildPtr map[string]*errConfig +} + +type configChildMapKey struct { + Child map[errType]string + ChildPtr map[*errType]string +} + +type configChildTypeDef struct { + Child errType + ChildPtr *errType +} + +type errConfig struct { + err error +} + +func (e *errConfig) Validate() error { + return e.err +} + +type errType string + +func (e errType) Validate() error { + return errors.New(string(e)) +} + +func newErrType(etStr string) *errType { + et := errType(etStr) + return &et +} + +func TestValidateConfig(t *testing.T) { + var tests = []struct { + name string + cfg any + expected error + }{ + { + name: "child struct", + cfg: configChildStruct{Child: errConfig{err: errors.New("child struct")}}, + expected: errors.New("child struct"), + }, + { + name: "pointer child struct", + cfg: &configChildStruct{Child: errConfig{err: errors.New("pointer child struct")}}, + expected: errors.New("pointer child struct"), + }, + { + name: "child struct pointer", + cfg: &configChildStruct{ChildPtr: &errConfig{err: errors.New("child struct pointer")}}, + expected: errors.New("child struct pointer"), + }, + { + name: "child slice", + cfg: configChildSlice{Child: []errConfig{{}, {err: errors.New("child slice")}}}, + expected: errors.New("child slice"), + }, + { + name: "pointer child slice", + cfg: &configChildSlice{Child: []errConfig{{}, {err: errors.New("pointer child slice")}}}, + expected: errors.New("pointer child slice"), + }, + { + name: "child slice pointer", + cfg: &configChildSlice{ChildPtr: []*errConfig{{}, {err: errors.New("child slice pointer")}}}, + expected: errors.New("child slice pointer"), + }, + { + name: "child map value", + cfg: configChildMapValue{Child: map[string]errConfig{"test": {err: errors.New("child map")}}}, + expected: errors.New("child map"), + }, + { + name: "pointer child map value", + cfg: &configChildMapValue{Child: map[string]errConfig{"test": {err: errors.New("pointer child map")}}}, + expected: errors.New("pointer child map"), + }, + { + name: "child map value pointer", + cfg: &configChildMapValue{ChildPtr: map[string]*errConfig{"test": {err: errors.New("child map pointer")}}}, + expected: errors.New("child map pointer"), + }, + { + name: "child map key", + cfg: configChildMapKey{Child: map[errType]string{"child map key": ""}}, + expected: errors.New("child map key"), + }, + { + name: "pointer child map key", + cfg: &configChildMapKey{Child: map[errType]string{"pointer child map key": ""}}, + expected: errors.New("pointer child map key"), + }, + { + name: "child map key pointer", + cfg: &configChildMapKey{ChildPtr: map[*errType]string{newErrType("child map key pointer"): ""}}, + expected: errors.New("child map key pointer"), + }, + { + name: "child type", + cfg: configChildTypeDef{Child: "child type"}, + expected: errors.New("child type"), + }, + { + name: "pointer child type", + cfg: &configChildTypeDef{Child: "pointer child type"}, + expected: errors.New("pointer child type"), + }, + { + name: "child type pointer", + cfg: &configChildTypeDef{ChildPtr: newErrType("child type pointer")}, + expected: errors.New("child type pointer"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, validate(reflect.ValueOf(tt.cfg))) + }) + } +}