diff --git a/loader/loader.go b/loader/loader.go index 6b80ab23..19cffc34 100644 --- a/loader/loader.go +++ b/loader/loader.go @@ -17,8 +17,10 @@ package loader import ( + "bytes" "context" "fmt" + "io" "os" paths "path" "path/filepath" @@ -166,7 +168,9 @@ func WithProfiles(profiles []string) func(*Options) { // ParseYAML reads the bytes from a file, parses the bytes into a mapping // structure, and returns it. func ParseYAML(source []byte) (map[string]interface{}, error) { - m, _, err := parseYAML(source) + r := bytes.NewReader(source) + decoder := yaml.NewDecoder(r) + m, _, err := parseYAML(decoder) return m, err } @@ -179,11 +183,11 @@ type PostProcessor interface { Apply(config *types.Config) error } -func parseYAML(source []byte) (map[string]interface{}, PostProcessor, error) { +func parseYAML(decoder *yaml.Decoder) (map[string]interface{}, PostProcessor, error) { var cfg interface{} processor := ResetProcessor{target: &cfg} - if err := yaml.Unmarshal(source, &processor); err != nil { + if err := decoder.Decode(&processor); err != nil { return nil, nil, err } stringMap, ok := cfg.(map[string]interface{}) @@ -251,67 +255,86 @@ func load(ctx context.Context, configDetails types.ConfigDetails, opts *Options, loaded = append(loaded, mainFile) includeRefs := make(map[string][]types.IncludeConfig) - for i, file := range configDetails.ConfigFiles { + first := true + for _, file := range configDetails.ConfigFiles { var postProcessor PostProcessor configDict := file.Config - if configDict == nil { - if len(file.Content) == 0 { - content, err := os.ReadFile(file.Filename) - if err != nil { - return nil, err + + processYaml := func() error { + if !opts.SkipValidation { + if err := schema.Validate(configDict); err != nil { + return fmt.Errorf("validating %s: %w", file.Filename, err) } - file.Content = content } - dict, p, err := parseConfig(file.Content, opts) + + configDict = groupXFieldsIntoExtensions(configDict) + + cfg, err := loadSections(ctx, file.Filename, configDict, configDetails, opts) if err != nil { - return nil, fmt.Errorf("parsing %s: %w", file.Filename, err) + return err } - configDict = dict - file.Config = dict - configDetails.ConfigFiles[i] = file - postProcessor = p - } - if !opts.SkipValidation { - if err := schema.Validate(configDict); err != nil { - return nil, fmt.Errorf("validating %s: %w", file.Filename, err) + if !opts.SkipInclude { + var included map[string][]types.IncludeConfig + cfg, included, err = loadInclude(ctx, file.Filename, configDetails, cfg, opts, loaded) + if err != nil { + return err + } + for k, v := range included { + includeRefs[k] = append(includeRefs[k], v...) + } } - } - - configDict = groupXFieldsIntoExtensions(configDict) - - cfg, err := loadSections(ctx, file.Filename, configDict, configDetails, opts) - if err != nil { - return nil, err - } - if !opts.SkipInclude { - var included map[string][]types.IncludeConfig - cfg, included, err = loadInclude(ctx, file.Filename, configDetails, cfg, opts, loaded) + if first { + first = false + model = cfg + return nil + } + merged, err := merge([]*types.Config{model, cfg}) if err != nil { - return nil, err + return err } - for k, v := range included { - includeRefs[k] = append(includeRefs[k], v...) + if postProcessor != nil { + err = postProcessor.Apply(merged) + if err != nil { + return err + } } + model = merged + return nil } - if i == 0 { - model = cfg - continue - } + if configDict == nil { + if len(file.Content) == 0 { + content, err := os.ReadFile(file.Filename) + if err != nil { + return nil, err + } + file.Content = content + } - merged, err := merge([]*types.Config{model, cfg}) - if err != nil { - return nil, err - } - if postProcessor != nil { - err = postProcessor.Apply(merged) - if err != nil { + r := bytes.NewReader(file.Content) + decoder := yaml.NewDecoder(r) + for { + dict, p, err := parseConfig(decoder, opts) + if err != nil { + if err != io.EOF { + return nil, fmt.Errorf("parsing %s: %w", file.Filename, err) + } + break + } + configDict = dict + postProcessor = p + + if err := processYaml(); err != nil { + return nil, err + } + } + } else { + if err := processYaml(); err != nil { return nil, err } } - model = merged } project := &types.Project{ @@ -449,8 +472,8 @@ func NormalizeProjectName(s string) string { return strings.TrimLeft(s, "_-") } -func parseConfig(b []byte, opts *Options) (map[string]interface{}, PostProcessor, error) { - yml, postProcessor, err := parseYAML(b) +func parseConfig(decoder *yaml.Decoder, opts *Options) (map[string]interface{}, PostProcessor, error) { + yml, postProcessor, err := parseYAML(decoder) if err != nil { return nil, nil, err } @@ -757,7 +780,10 @@ func loadServiceWithExtends(ctx context.Context, filename, name string, services return nil, err } - baseFile, _, err := parseConfig(b, opts) + r := bytes.NewReader(b) + decoder := yaml.NewDecoder(r) + + baseFile, _, err := parseConfig(decoder, opts) if err != nil { return nil, err } diff --git a/loader/loader_test.go b/loader/loader_test.go index 6dad1e78..07e6baf8 100644 --- a/loader/loader_test.go +++ b/loader/loader_test.go @@ -2737,3 +2737,19 @@ services: }) assert.ErrorContains(t, err, "Circular reference") } + +func TestLoadMulmtiDocumentYaml(t *testing.T) { + project, err := loadYAML(` +name: load-multi-docs +services: + test: + image: nginx:latest +--- +services: + test: + image: nginx:override + +`) + assert.NilError(t, err) + assert.Equal(t, project.Services[0].Image, "nginx:override") +}