diff --git a/internals/overlord/servstate/manager.go b/internals/overlord/servstate/manager.go index 25ab0926..1df9b820 100644 --- a/internals/overlord/servstate/manager.go +++ b/internals/overlord/servstate/manager.go @@ -174,6 +174,10 @@ func (m *ServiceManager) updatePlanLayers(layers []*plan.Layer) error { Checks: combined.Checks, LogTargets: combined.LogTargets, } + err = p.Validate() + if err != nil { + return err + } m.updatePlan(p) return nil } @@ -541,6 +545,11 @@ func (m *ServiceManager) SetServiceArgs(serviceArgs map[string][]string) error { } } + err = newLayer.Validate() + if err != nil { + return err + } + return m.appendLayer(newLayer) } diff --git a/internals/overlord/servstate/manager_test.go b/internals/overlord/servstate/manager_test.go index b01c4850..4c7763c7 100644 --- a/internals/overlord/servstate/manager_test.go +++ b/internals/overlord/servstate/manager_test.go @@ -752,6 +752,17 @@ services: command: /bin/b `[1:]) s.planLayersHasLen(c, s.manager, 3) + + // Make sure that layer validation is happening. + layer, err = plan.ParseLayer(0, "label4", []byte(` +checks: + bad-check: + override: replace + level: invalid + tcp: + port: 8080 +`)) + c.Check(err, ErrorMatches, `(?s).*plan check.*must be "alive" or "ready".*`) } func (s *S) TestSetServiceArgs(c *C) { diff --git a/internals/plan/plan.go b/internals/plan/plan.go index 6616c326..d3a58fec 100644 --- a/internals/plan/plan.go +++ b/internals/plan/plan.go @@ -552,6 +552,9 @@ func (e *FormatError) Error() string { // CombineLayers combines the given layers into a single layer, with the later // layers overriding earlier ones. +// Neither the individual layers nor the combined layer are validated here - the +// caller should have validated the individual layers prior to calling, and +// validate the combined output if required. func CombineLayers(layers ...*Layer) (*Layer, error) { combined := &Layer{ Services: make(map[string]*Service), @@ -641,87 +644,201 @@ func CombineLayers(layers ...*Layer) (*Layer, error) { } } - // Ensure fields in combined layers validate correctly (and set defaults). - for name, service := range combined.Services { - if service.Command == "" { - return nil, &FormatError{ - Message: fmt.Sprintf(`plan must define "command" for service %q`, name), + // Set defaults where required. + for _, service := range combined.Services { + if !service.BackoffDelay.IsSet { + service.BackoffDelay.Value = defaultBackoffDelay + } + if !service.BackoffFactor.IsSet { + service.BackoffFactor.Value = defaultBackoffFactor + } + if !service.BackoffLimit.IsSet { + service.BackoffLimit.Value = defaultBackoffLimit + } + } + + for _, check := range combined.Checks { + if !check.Period.IsSet { + check.Period.Value = defaultCheckPeriod + } + if !check.Timeout.IsSet { + check.Timeout.Value = defaultCheckTimeout + } + if check.Timeout.Value > check.Period.Value { + // The effective timeout will be the period, so make that clear. + // `.IsSet` remains false so that the capped value does not appear + // in the combined plan output - and it's not *user* set - the + // effective default timeout is the minimum of (check.Period.Value, + // default timeout). + check.Timeout.Value = check.Period.Value + } + if check.Threshold == 0 { + // Default number of failures in a row before check triggers + // action, default is >1 to avoid flapping due to glitches. For + // what it's worth, Kubernetes probes uses a default of 3 too. + check.Threshold = defaultCheckThreshold + } + } + + return combined, nil +} + +// Validate checks that the layer is valid. It returns nil if all the checks pass, or +// an error if there are validation errors. +// See also Plan.Validate, which does additional checks based on the combined +// layers. +func (layer *Layer) Validate() error { + for name, service := range layer.Services { + if name == "" { + return &FormatError{ + Message: fmt.Sprintf("cannot use empty string as service name"), + } + } + if name == "pebble" { + // Disallow service name "pebble" to avoid ambiguity (for example, + // in log output). + return &FormatError{ + Message: fmt.Sprintf("cannot use reserved service name %q", name), + } + } + // Deprecated service names + if name == "all" || name == "default" || name == "none" { + logger.Noticef("Using keyword %q as a service name is deprecated", name) + } + if strings.HasPrefix(name, "-") { + return &FormatError{ + Message: fmt.Sprintf(`cannot use service name %q: starting with "-" not allowed`, name), + } + } + if service == nil { + return &FormatError{ + Message: fmt.Sprintf("service object cannot be null for service %q", name), } } _, _, err := service.ParseCommand() if err != nil { - return nil, &FormatError{ + return &FormatError{ Message: fmt.Sprintf("plan service %q command invalid: %v", name, err), } } if !validServiceAction(service.OnSuccess, ActionFailureShutdown) { - return nil, &FormatError{ + return &FormatError{ Message: fmt.Sprintf("plan service %q on-success action %q invalid", name, service.OnSuccess), } } if !validServiceAction(service.OnFailure, ActionSuccessShutdown) { - return nil, &FormatError{ + return &FormatError{ Message: fmt.Sprintf("plan service %q on-failure action %q invalid", name, service.OnFailure), } } for _, action := range service.OnCheckFailure { if !validServiceAction(action, ActionSuccessShutdown) { - return nil, &FormatError{ + return &FormatError{ Message: fmt.Sprintf("plan service %q on-check-failure action %q invalid", name, action), } } } - if !service.BackoffDelay.IsSet { - service.BackoffDelay.Value = defaultBackoffDelay - } - if !service.BackoffFactor.IsSet { - service.BackoffFactor.Value = defaultBackoffFactor - } else if service.BackoffFactor.Value < 1 { - return nil, &FormatError{ + if service.BackoffFactor.IsSet && service.BackoffFactor.Value < 1 { + return &FormatError{ Message: fmt.Sprintf("plan service %q backoff-factor must be 1.0 or greater, not %g", name, service.BackoffFactor.Value), } } - if !service.BackoffLimit.IsSet { - service.BackoffLimit.Value = defaultBackoffLimit - } - } - for name, check := range combined.Checks { + for name, check := range layer.Checks { + if name == "" { + return &FormatError{ + Message: fmt.Sprintf("cannot use empty string as check name"), + } + } + if check == nil { + return &FormatError{ + Message: fmt.Sprintf("check object cannot be null for check %q", name), + } + } + if name == "" { + return &FormatError{ + Message: fmt.Sprintf("cannot use empty string as log target name"), + } + } if check.Level != UnsetLevel && check.Level != AliveLevel && check.Level != ReadyLevel { - return nil, &FormatError{ + return &FormatError{ Message: fmt.Sprintf(`plan check %q level must be "alive" or "ready"`, name), } } - if !check.Period.IsSet { - check.Period.Value = defaultCheckPeriod - } else if check.Period.Value == 0 { - return nil, &FormatError{ + if check.Period.IsSet && check.Period.Value == 0 { + return &FormatError{ Message: fmt.Sprintf("plan check %q period must not be zero", name), } } - if !check.Timeout.IsSet { - check.Timeout.Value = defaultCheckTimeout - } else if check.Timeout.Value == 0 { - return nil, &FormatError{ + if check.Timeout.IsSet && check.Timeout.Value == 0 { + return &FormatError{ Message: fmt.Sprintf("plan check %q timeout must not be zero", name), } - } else if check.Timeout.Value >= check.Period.Value { - return nil, &FormatError{ - Message: fmt.Sprintf("plan check %q timeout must be less than period", name), + } + + if check.Exec != nil { + _, err := shlex.Split(check.Exec.Command) + if err != nil { + return &FormatError{ + Message: fmt.Sprintf("plan check %q command invalid: %v", name, err), + } + } + _, _, err = osutil.NormalizeUidGid(check.Exec.UserID, check.Exec.GroupID, check.Exec.User, check.Exec.Group) + if err != nil { + return &FormatError{ + Message: fmt.Sprintf("plan check %q has invalid user/group: %v", name, err), + } } } - if check.Threshold == 0 { - // Default number of failures in a row before check triggers - // action, default is >1 to avoid flapping due to glitches. For - // what it's worth, Kubernetes probes uses a default of 3 too. - check.Threshold = defaultCheckThreshold + } + + for name, target := range layer.LogTargets { + if target == nil { + return &FormatError{ + Message: fmt.Sprintf("log target object cannot be null for log target %q", name), + } } + for labelName := range target.Labels { + // 'pebble_*' labels are reserved + if strings.HasPrefix(labelName, "pebble_") { + return &FormatError{ + Message: fmt.Sprintf(`log target %q: label %q uses reserved prefix "pebble_"`, name, labelName), + } + } + } + switch target.Type { + case LokiTarget, SyslogTarget: + // valid, continue + case UnsetLogTarget: + // will be checked when the layers are combined + default: + return &FormatError{ + Message: fmt.Sprintf(`log target %q has unsupported type %q, must be %q or %q`, + name, target.Type, LokiTarget, SyslogTarget), + } + } + } + + return nil +} + +// Validate checks that the combined layers form a valid plan. +// See also Layer.Validate, which checks that the individual layers are valid. +func (p *Plan) Validate() error { + for name, service := range p.Services { + if service.Command == "" { + return &FormatError{ + Message: fmt.Sprintf(`plan must define "command" for service %q`, name), + } + } + } + for name, check := range p.Checks { numTypes := 0 if check.HTTP != nil { if check.HTTP.URL == "" { - return nil, &FormatError{ + return &FormatError{ Message: fmt.Sprintf(`plan must set "url" for http check %q`, name), } } @@ -729,7 +846,7 @@ func CombineLayers(layers ...*Layer) (*Layer, error) { } if check.TCP != nil { if check.TCP.Port == 0 { - return nil, &FormatError{ + return &FormatError{ Message: fmt.Sprintf(`plan must set "port" for tcp check %q`, name), } } @@ -737,83 +854,65 @@ func CombineLayers(layers ...*Layer) (*Layer, error) { } if check.Exec != nil { if check.Exec.Command == "" { - return nil, &FormatError{ + return &FormatError{ Message: fmt.Sprintf(`plan must set "command" for exec check %q`, name), } } - _, err := shlex.Split(check.Exec.Command) - if err != nil { - return nil, &FormatError{ - Message: fmt.Sprintf("plan check %q command invalid: %v", name, err), - } - } - _, contextExists := combined.Services[check.Exec.ServiceContext] + _, contextExists := p.Services[check.Exec.ServiceContext] if check.Exec.ServiceContext != "" && !contextExists { - return nil, &FormatError{ + return &FormatError{ Message: fmt.Sprintf("plan check %q service context specifies non-existent service %q", name, check.Exec.ServiceContext), } } - _, _, err = osutil.NormalizeUidGid(check.Exec.UserID, check.Exec.GroupID, check.Exec.User, check.Exec.Group) - if err != nil { - return nil, &FormatError{ - Message: fmt.Sprintf("plan check %q has invalid user/group: %v", name, err), - } - } numTypes++ } if numTypes != 1 { - return nil, &FormatError{ + return &FormatError{ Message: fmt.Sprintf(`plan must specify one of "http", "tcp", or "exec" for check %q`, name), } } } - for name, target := range combined.LogTargets { + for name, target := range p.LogTargets { switch target.Type { case LokiTarget, SyslogTarget: // valid, continue case UnsetLogTarget: - return nil, &FormatError{ + return &FormatError{ Message: fmt.Sprintf(`plan must define "type" (%q or %q) for log target %q`, LokiTarget, SyslogTarget, name), } - default: - return nil, &FormatError{ - Message: fmt.Sprintf(`log target %q has unsupported type %q, must be %q or %q`, - name, target.Type, LokiTarget, SyslogTarget), - } } - // Validate service names specified in log target + // Validate service names specified in log target. for _, serviceName := range target.Services { serviceName = strings.TrimPrefix(serviceName, "-") if serviceName == "all" { continue } - if _, ok := combined.Services[serviceName]; ok { + if _, ok := p.Services[serviceName]; ok { continue } - return nil, &FormatError{ + return &FormatError{ Message: fmt.Sprintf(`log target %q specifies unknown service %q`, target.Name, serviceName), } } if target.Location == "" { - return nil, &FormatError{ + return &FormatError{ Message: fmt.Sprintf(`plan must define "location" for log target %q`, name), } } } // Ensure combined layers don't have cycles. - err := combined.checkCycles() + err := p.checkCycles() if err != nil { - return nil, err + return err } - - return combined, nil + return nil } // StartOrder returns the required services that must be started for the named @@ -906,12 +1005,12 @@ func order(services map[string]*Service, names []string, stop bool) ([]string, e return order, nil } -func (l *Layer) checkCycles() error { +func (p *Plan) checkCycles() error { var names []string - for name := range l.Services { + for name := range p.Services { names = append(names, name) } - _, err := order(l.Services, names, false) + _, err := order(p.Services, names, false) return err } @@ -933,75 +1032,29 @@ func ParseLayer(order int, label string, data []byte) (*Layer, error) { layer.Label = label for name, service := range layer.Services { - if name == "" { - return nil, &FormatError{ - Message: fmt.Sprintf("cannot use empty string as service name"), - } - } - if name == "pebble" { - // Disallow service name "pebble" to avoid ambiguity (for example, - // in log output). - return nil, &FormatError{ - Message: fmt.Sprintf("cannot use reserved service name %q", name), - } - } - // Deprecated service names - if name == "all" || name == "default" || name == "none" { - logger.Noticef("Using keyword %q as a service name is deprecated", name) - } - if strings.HasPrefix(name, "-") { - return nil, &FormatError{ - Message: fmt.Sprintf(`cannot use service name %q: starting with "-" not allowed`, name), - } - } - if service == nil { - return nil, &FormatError{ - Message: fmt.Sprintf("service object cannot be null for service %q", name), - } + // If service is nil, then the validation below will reject the layer, + // but we want the name set so that we can use easily use it in error + // messages during validation. + if service != nil { + service.Name = name } - service.Name = name } - for name, check := range layer.Checks { - if name == "" { - return nil, &FormatError{ - Message: fmt.Sprintf("cannot use empty string as check name"), - } + if check != nil { + check.Name = name } - if check == nil { - return nil, &FormatError{ - Message: fmt.Sprintf("check object cannot be null for check %q", name), - } - } - check.Name = name } - for name, target := range layer.LogTargets { - if name == "" { - return nil, &FormatError{ - Message: fmt.Sprintf("cannot use empty string as log target name"), - } + if target != nil { + target.Name = name } - if target == nil { - return nil, &FormatError{ - Message: fmt.Sprintf("log target object cannot be null for log target %q", name), - } - } - for labelName := range target.Labels { - // 'pebble_*' labels are reserved - if strings.HasPrefix(labelName, "pebble_") { - return nil, &FormatError{ - Message: fmt.Sprintf(`log target %q: label %q uses reserved prefix "pebble_"`, name, labelName), - } - } - } - target.Name = name } - err = layer.checkCycles() + err = layer.Validate() if err != nil { return nil, err } + return &layer, err } @@ -1107,6 +1160,10 @@ func ReadDir(dir string) (*Plan, error) { Checks: combined.Checks, LogTargets: combined.LogTargets, } + err = plan.Validate() + if err != nil { + return nil, err + } return plan, err } diff --git a/internals/plan/plan_test.go b/internals/plan/plan_test.go index bb22dd6a..5d47201a 100644 --- a/internals/plan/plan_test.go +++ b/internals/plan/plan_test.go @@ -779,7 +779,7 @@ var planTests = []planTest{{ Name: "chk-http", Override: plan.MergeOverride, Period: plan.OptionalDuration{Value: time.Second, IsSet: true}, - Timeout: plan.OptionalDuration{Value: defaultCheckTimeout}, + Timeout: plan.OptionalDuration{Value: time.Second}, Threshold: defaultCheckThreshold, HTTP: &plan.HTTPCheck{ URL: "https://example.com/bar", @@ -814,6 +814,63 @@ var planTests = []planTest{{ }, LogTargets: map[string]*plan.LogTarget{}, }, +}, { + summary: "Timeout is capped at period", + input: []string{` + checks: + chk1: + override: replace + period: 100ms + timeout: 2s + tcp: + host: foobar + port: 80 +`}, + result: &plan.Layer{ + Services: map[string]*plan.Service{}, + Checks: map[string]*plan.Check{ + "chk1": { + Name: "chk1", + Override: plan.ReplaceOverride, + Period: plan.OptionalDuration{Value: 100 * time.Millisecond, IsSet: true}, + Timeout: plan.OptionalDuration{Value: 100 * time.Millisecond, IsSet: true}, + Threshold: defaultCheckThreshold, + TCP: &plan.TCPCheck{ + Port: 80, + Host: "foobar", + }, + }, + }, + LogTargets: map[string]*plan.LogTarget{}, + }, +}, { + summary: "Unset timeout is capped at period", + input: []string{` + checks: + chk1: + override: replace + period: 100ms + tcp: + host: foobar + port: 80 +`}, + result: &plan.Layer{ + Services: map[string]*plan.Service{}, + Checks: map[string]*plan.Check{ + "chk1": { + Name: "chk1", + Override: plan.ReplaceOverride, + Period: plan.OptionalDuration{Value: 100 * time.Millisecond, IsSet: true}, + Timeout: plan.OptionalDuration{Value: 100 * time.Millisecond, IsSet: false}, + Threshold: defaultCheckThreshold, + TCP: &plan.TCPCheck{ + Port: 80, + Host: "foobar", + }, + }, + }, + LogTargets: map[string]*plan.LogTarget{}, + }, }, { summary: "One of http, tcp, or exec must be present for check", error: `plan must specify one of "http", "tcp", or "exec" for check "chk1"`, @@ -1286,6 +1343,64 @@ var planTests = []planTest{{ pebble_service: illegal `}, error: `log target "tgt1": label "pebble_service" uses reserved prefix "pebble_"`, +}, { + summary: "Required field two layers deep", + input: []string{` + services: + srv1: + override: replace + command: sleep 1000 + `, ` + services: + srv1: + override: merge + environment: + VAR1: foo + `, ` + services: + srv1: + override: merge + environment: + VAR2: bar + `}, + result: &plan.Layer{ + Services: map[string]*plan.Service{ + "srv1": { + Name: "srv1", + Command: "sleep 1000", + Override: plan.ReplaceOverride, + BackoffDelay: plan.OptionalDuration{Value: defaultBackoffDelay}, + BackoffFactor: plan.OptionalFloat{Value: defaultBackoffFactor}, + BackoffLimit: plan.OptionalDuration{Value: defaultBackoffLimit}, + Environment: map[string]string{ + "VAR1": "foo", + "VAR2": "bar", + }, + }, + }, + Checks: map[string]*plan.Check{}, + LogTargets: map[string]*plan.LogTarget{}, + }, +}, { + summary: "Three layers missing command", + input: []string{` + services: + srv1: + override: replace +`, ` + services: + srv1: + override: merge + environment: + VAR1: foo +`, ` + services: + srv1: + override: merge + environment: + VAR2: bar +`}, + error: `plan must define "command" for service "srv1"`, }} func (s *S) TestParseLayer(c *C) { @@ -1324,6 +1439,15 @@ func (s *S) TestParseLayer(c *C) { c.Assert(names, DeepEquals, order) } } + if err == nil { + p := &plan.Plan{ + Layers: sup.Layers, + Services: result.Services, + Checks: result.Checks, + LogTargets: result.LogTargets, + } + err = p.Validate() + } } if err != nil || test.error != "" { if test.error != "" { @@ -1355,7 +1479,16 @@ services: - srv1 `)) c.Assert(err, IsNil) - _, err = plan.CombineLayers(layer1, layer2) + combined, err := plan.CombineLayers(layer1, layer2) + c.Assert(err, IsNil) + layers := []*plan.Layer{layer1, layer2} + p := &plan.Plan{ + Layers: layers, + Services: combined.Services, + Checks: combined.Checks, + LogTargets: combined.LogTargets, + } + err = p.Validate() c.Assert(err, ErrorMatches, `services in before/after loop: .*`) _, ok := err.(*plan.FormatError) c.Assert(ok, Equals, true, Commentf("error must be *plan.FormatError, not %T", err)) @@ -1386,7 +1519,16 @@ services: override: merge `)) c.Assert(err, IsNil) - _, err = plan.CombineLayers(layer1, layer2) + combined, err := plan.CombineLayers(layer1, layer2) + c.Assert(err, IsNil) + layers := []*plan.Layer{layer1, layer2} + p := &plan.Plan{ + Layers: layers, + Services: combined.Services, + Checks: combined.Checks, + LogTargets: combined.LogTargets, + } + err = p.Validate() c.Check(err, ErrorMatches, `plan must define "command" for service "srv1"`) _, ok := err.(*plan.FormatError) c.Check(ok, Equals, true, Commentf("error must be *plan.FormatError, not %T", err)) @@ -1405,7 +1547,7 @@ services: override: merge `)) c.Assert(err, IsNil) - combined, err := plan.CombineLayers(layer1, layer2) + combined, err = plan.CombineLayers(layer1, layer2) c.Assert(err, IsNil) c.Assert(combined.Services["srv1"].Command, Equals, "foo --bar") }