From 28527bdc9eff1325ffd44ba14beba3da2e18a540 Mon Sep 17 00:00:00 2001 From: cramja Date: Mon, 21 Jun 2021 17:16:11 -0700 Subject: [PATCH] extract parser function --- cmd/namespace/validate.go | 39 ++++++++------------- cmd/namespace/validate_test.go | 33 ++++++++++++++++- internal/driver/config/namespace_watcher.go | 28 +++++++++------ 3 files changed, 63 insertions(+), 37 deletions(-) diff --git a/cmd/namespace/validate.go b/cmd/namespace/validate.go index 84a697b7e..c33be048b 100644 --- a/cmd/namespace/validate.go +++ b/cmd/namespace/validate.go @@ -3,17 +3,14 @@ package namespace import ( "bytes" "context" - "encoding/json" "fmt" "io/ioutil" - "strings" "github.com/ory/jsonschema/v3" "github.com/ory/x/cmdx" "github.com/ory/x/configx" "github.com/ory/x/jsonschemax" "github.com/ory/x/logrusx" - "github.com/pelletier/go-toml" "github.com/segmentio/objconv/yaml" "github.com/spf13/cobra" @@ -90,10 +87,16 @@ func validateNamespaceFile(cmd *cobra.Command, fn string) (*namespace.Namespace, return nil, cmdx.FailSilently(cmd) } - return validateNamespaceBytes(cmd, fn, fc, yaml.Unmarshal) + parse, err := config.GetParser(fn) + if err != nil { + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Unable to infer file type from \"%s\": %+v\n", fn, err) + return nil, cmdx.FailSilently(cmd) + } + + return validateNamespaceBytes(cmd, fn, fc, parse) } -func validateNamespaceBytes(cmd *cobra.Command, name string, b []byte, parser func([]byte, interface{}) error) (*namespace.Namespace, error) { +func validateNamespaceBytes(cmd *cobra.Command, name string, b []byte, parser config.Parser) (*namespace.Namespace, error) { schema, err := getSchema(cmd) if err != nil { return nil, err @@ -112,7 +115,7 @@ func validateNamespaceBytes(cmd *cobra.Command, name string, b []byte, parser fu } var n namespace.Namespace - if err := yaml.Unmarshal(b, &n); err != nil { + if err := parser(b, &n); err != nil { _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Encountered unmarshal error for \"%s\": %+v\n", name, err) return nil, cmdx.FailSilently(cmd) } @@ -127,29 +130,15 @@ func validateConfigFile(cmd *cobra.Command, fn string) error { return cmdx.FailSilently(cmd) } - var val map[string]interface{} - dot := strings.LastIndex(fn, ".") - if dot == -1 { + parse, err := config.GetParser(fn) + if err != nil { _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Unable to infer file type from \"%s\": %+v\n", fn, err) return cmdx.FailSilently(cmd) } - var unmarshal func(b []byte, v interface{}) error - ext := fn[dot+1:] - switch ext { - case "yaml", "yml": - unmarshal = yaml.Unmarshal - case "json": - unmarshal = json.Unmarshal - case "toml": - unmarshal = toml.Unmarshal - default: - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Unhandled config file extension \"%s\"\n", ext) - return cmdx.FailSilently(cmd) - } - - if err := unmarshal(fc, &val); err != nil { - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Encountered unmarshal error for \"%s\": %+v\n", fn, err) + var val map[string]interface{} + if err := parse(fc, &val); err != nil { + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Encountered parse error for \"%s\": %+v\n", fn, err) return cmdx.FailSilently(cmd) } diff --git a/cmd/namespace/validate_test.go b/cmd/namespace/validate_test.go index dc5c6c386..d2a706605 100644 --- a/cmd/namespace/validate_test.go +++ b/cmd/namespace/validate_test.go @@ -63,13 +63,36 @@ func TestValidateConfigNamespaces(t *testing.T) { assert.Equal(t, "Congrats, all files are valid!\n", stdOut) }) + t.Run("case=supports 3 namespace file formats", func(t *testing.T) { + dir := t.TempDir() + files := map[string]string { + dir + "/ns.yaml": "name: testns0\nid: 0", + dir + "/ns.json": "{\"name\": \"testns0\",\"id\": 0}", + dir + "/ns.toml": "name = \"testns0\"\nid = 0", + } + for fn, contents := range files { + require.NoError(t, ioutil.WriteFile(fn, []byte(contents), fileMode)) + } + + params := append([]string{"validate"}, keys(files)...) + cmd.ExecNoErr(t, params...) + }) + + t.Run("case=unknown namespace format gives error", func(t *testing.T) { + fn := t.TempDir() + "/ns.txt" + require.NoError(t, ioutil.WriteFile(fn, []byte("name: ns\nid: 0"), fileMode)) + + stdOut := cmd.ExecExpectedErr(t, "validate", fn) + assert.Contains(t, stdOut, "Unable to infer file type") + }) + t.Run("case=config passed as varg fails", func(t *testing.T) { fn := t.TempDir() + "/keto.yaml" require.NoError(t, ioutil.WriteFile(fn, []byte(configEmbeddedYaml), fileMode)) // interprets config file as namespace file when `-c` flag is not passed stdOut := cmd.ExecExpectedErr(t, "validate", fn) - assert.Regexp(t, "additionalProperties ((\"namespaces\", \"dsn\")|(\"dsn\", \"namespaces\", \"dsn\")) not allowed", stdOut) + assert.Regexp(t, "additionalProperties ((\"namespaces\", \"dsn\")|(\"dsn\", \"namespaces\")) not allowed", stdOut) }) t.Run("case=read config with invalid embedded namespace", func(t *testing.T) { @@ -117,3 +140,11 @@ func validateCommand() *cobra.Command { cmd.AddCommand(NewValidateCmd()) return cmd } + +func keys(m map[string]string) []string { + rv := make([]string, 0, len(m)) + for k, _ := range m { + rv = append(rv, k) + } + return rv +} \ No newline at end of file diff --git a/internal/driver/config/namespace_watcher.go b/internal/driver/config/namespace_watcher.go index 04c4c11e6..e716fcb15 100644 --- a/internal/driver/config/namespace_watcher.go +++ b/internal/driver/config/namespace_watcher.go @@ -138,17 +138,9 @@ func eventHandler(ctx context.Context, nw *NamespaceWatcher, done <-chan int, in func readNamespaceFile(l *logrusx.Logger, r io.Reader, source string) *NamespaceFile { var parse Parser - knownFormats := stringsx.RegisteredCases{} - switch ext := filepath.Ext(source); ext { - case knownFormats.AddCase(".yaml"), knownFormats.AddCase(".yml"): - parse = yaml.Unmarshal - case knownFormats.AddCase(".json"): - parse = json.Unmarshal - case knownFormats.AddCase(".toml"): - parse = toml.Unmarshal - default: - l.WithError(knownFormats.ToUnknownCaseErr(ext)).WithField("file_name", source).Warn("could not infer format from file extension") - return nil + parse, err := GetParser(source) + if err != nil { + l.WithError(err).WithField("file_name", source).Warn("could not infer format from file extension") } raw, err := ioutil.ReadAll(r) @@ -202,3 +194,17 @@ func (n *NamespaceWatcher) NamespaceFiles() []*NamespaceFile { } return nsfs } + +func GetParser(fn string) (Parser, error) { + knownFormats := stringsx.RegisteredCases{} + switch ext := filepath.Ext(fn); ext { + case knownFormats.AddCase(".yaml"), knownFormats.AddCase(".yml"): + return yaml.Unmarshal, nil + case knownFormats.AddCase(".json"): + return json.Unmarshal, nil + case knownFormats.AddCase(".toml"): + return toml.Unmarshal, nil + default: + return nil, knownFormats.ToUnknownCaseErr(ext) + } +}